Coverage for tests/test_postgresql.py: 39%
120 statements
« prev ^ index » next coverage.py v6.4.4, created at 2022-09-27 08:58 +0000
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
import gc
import itertools
import os
import secrets
import unittest
import warnings
from contextlib import contextmanager
from typing import Iterator, Optional

import astropy.time

try:
    # It's possible but silly to have testing.postgresql installed without
    # having the postgresql server installed (because then nothing in
    # testing.postgresql would work), so we use the presence of that module
    # to test whether we can expect the server to be available.
    import testing.postgresql
except ImportError:
    testing = None

import sqlalchemy
from lsst.daf.butler import Timespan, ddl
from lsst.daf.butler.registry import Registry
from lsst.daf.butler.registry.databases.postgresql import PostgresqlDatabase, _RangeTimespanType
from lsst.daf.butler.registry.tests import DatabaseTests, RegistryTests
from lsst.daf.butler.tests.utils import makeTestTempDir, removeTestTempDir
# Absolute, normalized path of the directory holding this test module; used
# as the base directory passed to `makeTestTempDir`.
TESTDIR = os.path.dirname(os.path.abspath(__file__))
def _startServer(root):
    """Start a PostgreSQL server and enable the ``btree_gist`` extension in
    the database it exposes, returning the server handle.

    Parameters
    ----------
    root : `str`
        Base directory in which ``testing.postgresql`` keeps the server's
        data files.

    Returns
    -------
    server : `testing.postgresql.Postgresql`
        Running server; ``server.url()`` gives the connection URI and
        ``server.stop()`` shuts it down.

    Notes
    -----
    The extension is presumably needed by the registry schema (e.g. GiST
    indexes/constraints mixing range and scalar columns) — NOTE(review):
    confirm against `PostgresqlDatabase`.
    """
    server = testing.postgresql.Postgresql(base_dir=root)
    try:
        engine = sqlalchemy.engine.create_engine(server.url())
        try:
            with engine.begin() as connection:
                connection.execute(sqlalchemy.text("CREATE EXTENSION btree_gist;"))
        finally:
            # Close pooled connections now rather than relying on garbage
            # collection before the server is stopped.
            engine.dispose()
    except BaseException:
        # Callers never receive the server if setup fails, so stop it here
        # instead of leaking a running postgres process.
        server.stop()
        raise
    return server
@unittest.skipUnless(testing is not None, "testing.postgresql module not found")
class PostgresqlDatabaseTestCase(unittest.TestCase, DatabaseTests):
    """Run the generic `DatabaseTests` suite, plus PostgreSQL-specific tests,
    against a temporary PostgreSQL server.

    One server is started per test class; each database created by
    `makeEmptyDatabase` lives in its own randomly named namespace on it.
    """

    @classmethod
    def setUpClass(cls):
        cls.root = makeTestTempDir(TESTDIR)
        try:
            cls.server = _startServer(cls.root)
        except BaseException:
            # unittest does not call tearDownClass when setUpClass raises, so
            # remove the temporary directory here to avoid leaking it.
            removeTestTempDir(cls.root)
            raise

    @classmethod
    def tearDownClass(cls):
        try:
            # Clean up any lingering SQLAlchemy engines/connections
            # so they're closed before we shut down the server.
            gc.collect()
            cls.server.stop()
        finally:
            # Remove the temporary directory even if stopping the server fails.
            removeTestTempDir(cls.root)

    def makeEmptyDatabase(self, origin: int = 0) -> PostgresqlDatabase:
        """Return a `PostgresqlDatabase` in a fresh, random namespace."""
        # token_hex already produces lowercase hex digits.
        namespace = f"namespace_{secrets.token_hex(8)}"
        return PostgresqlDatabase.fromUri(origin=origin, uri=self.server.url(), namespace=namespace)

    def getNewConnection(self, database: PostgresqlDatabase, *, writeable: bool) -> PostgresqlDatabase:
        """Return a new `PostgresqlDatabase` connected to the same origin and
        namespace as ``database``.
        """
        return PostgresqlDatabase.fromUri(
            origin=database.origin, uri=self.server.url(), namespace=database.namespace, writeable=writeable
        )

    @contextmanager
    def asReadOnly(self, database: PostgresqlDatabase) -> Iterator[PostgresqlDatabase]:
        """Yield a new read-only connection to ``database``'s namespace."""
        yield self.getNewConnection(database, writeable=False)

    def testNameShrinking(self):
        """Test that too-long names for database entities other than tables
        and columns (which we preserve, and just expect to fit) are shrunk.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            # Table and field names are each below the 63-char limit even when
            # accounting for the prefix, but their combination (which will
            # appear in sequences and constraints) is not.
            tableName = "a_table_with_a_very_very_long_42_char_name"
            fieldName1 = "a_column_with_a_very_very_long_43_char_name"
            fieldName2 = "another_column_with_a_very_very_long_49_char_name"
            context.addTable(
                tableName,
                ddl.TableSpec(
                    fields=[
                        ddl.FieldSpec(
                            fieldName1, dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True
                        ),
                        ddl.FieldSpec(
                            fieldName2,
                            dtype=sqlalchemy.String,
                            length=16,
                            nullable=False,
                        ),
                    ],
                    unique={(fieldName2,)},
                ),
            )
        # Add another table, this time dynamically, with a foreign key to the
        # first table.
        db.ensureTableExists(
            tableName + "_b",
            ddl.TableSpec(
                fields=[
                    ddl.FieldSpec(
                        fieldName1 + "_b", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True
                    ),
                    ddl.FieldSpec(
                        fieldName2 + "_b",
                        dtype=sqlalchemy.String,
                        length=16,
                        nullable=False,
                    ),
                ],
                foreignKeys=[
                    ddl.ForeignKeySpec(tableName, source=(fieldName2 + "_b",), target=(fieldName2,)),
                ],
            ),
        )

    def test_RangeTimespanType(self):
        """Test that `Timespan` values round-trip through `_RangeTimespanType`
        columns, and that the database's range-overlap operator agrees with
        `Timespan.overlaps` for every pair of stored timespans.
        """
        start = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
        offset = astropy.time.TimeDelta(60, format="sec")
        timestamps = [start + offset * n for n in range(3)]
        # Cover fully unbounded, half-bounded, and bounded timespans.
        timespans = [Timespan(begin=None, end=None)]
        timespans.extend(Timespan(begin=None, end=t) for t in timestamps)
        timespans.extend(Timespan(begin=t, end=None) for t in timestamps)
        timespans.extend(Timespan(begin=a, end=b) for a, b in itertools.combinations(timestamps, 2))
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tbl = context.addTable(
                "tbl",
                ddl.TableSpec(
                    fields=[
                        ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
                        ddl.FieldSpec(name="timespan", dtype=_RangeTimespanType),
                    ],
                ),
            )
        rows = [{"id": n, "timespan": t} for n, t in enumerate(timespans)]
        db.insert(tbl, *rows)

        # Test basic round-trip through database.
        self.assertEqual(rows, [row._asdict() for row in db.query(tbl.select().order_by(tbl.columns.id))])

        # Test that Timespan's Python methods are consistent with our usage of
        # half-open ranges and PostgreSQL operators on ranges.
        def subquery(alias: str) -> sqlalchemy.sql.FromClause:
            return (
                sqlalchemy.sql.select(tbl.columns.id.label("id"), tbl.columns.timespan.label("timespan"))
                .select_from(tbl)
                .alias(alias)
            )

        sq1 = subquery("sq1")
        sq2 = subquery("sq2")
        query = sqlalchemy.sql.select(
            sq1.columns.id.label("n1"),
            sq2.columns.id.label("n2"),
            sq1.columns.timespan.overlaps(sq2.columns.timespan).label("overlaps"),
        )

        # This test already depends on SQLAlchemy >= 1.4 (positional
        # ``select()`` arguments and ``Result.mappings()``), so use
        # ``selected_columns`` directly; the deprecated ``Select.columns``
        # fallback for 1.3 could never be reached.
        columns = query.selected_columns

        # SQLAlchemy issues a warning about cartesian product of two tables,
        # which we do intentionally. Disable that warning temporarily.
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore", message=".*cartesian product", category=sqlalchemy.exc.SAWarning
            )
            dbResults = {
                (row[columns.n1], row[columns.n2]): row[columns.overlaps]
                for row in db.query(query).mappings()
            }

        pyResults = {
            (n1, n2): t1.overlaps(t2)
            for (n1, t1), (n2, t2) in itertools.product(enumerate(timespans), enumerate(timespans))
        }
        self.assertEqual(pyResults, dbResults)
@unittest.skipUnless(testing is not None, "testing.postgresql module not found")
class PostgresqlRegistryTests(RegistryTests):
    """Tests for `Registry` backed by a PostgreSQL database.

    Notes
    -----
    This is not a subclass of `unittest.TestCase` but to avoid repetition it
    defines methods that override `unittest.TestCase` methods. To make this
    work subclasses have to have this class first in the bases list.
    """

    @classmethod
    def setUpClass(cls):
        cls.root = makeTestTempDir(TESTDIR)
        try:
            cls.server = _startServer(cls.root)
        except BaseException:
            # unittest does not call tearDownClass when setUpClass raises, so
            # remove the temporary directory here to avoid leaking it.
            removeTestTempDir(cls.root)
            raise

    @classmethod
    def tearDownClass(cls):
        try:
            # Clean up any lingering SQLAlchemy engines/connections
            # so they're closed before we shut down the server.
            gc.collect()
            cls.server.stop()
        finally:
            # Remove the temporary directory even if stopping the server fails.
            removeTestTempDir(cls.root)

    @classmethod
    def getDataDir(cls) -> str:
        """Return the directory holding registry test data files."""
        # TESTDIR is absolute, so the result does not depend on the current
        # working directory the way a relative ``__file__`` could.
        return os.path.normpath(os.path.join(TESTDIR, "data", "registry"))

    def makeRegistry(self, share_repo_with: Optional[Registry] = None) -> Registry:
        """Create a registry on the test server.

        Parameters
        ----------
        share_repo_with : `Registry`, optional
            If given, open another registry on the same namespace as this one
            instead of creating a new repository.

        Returns
        -------
        registry : `Registry`
            A newly created registry in a fresh random namespace, or a registry
            opened on ``share_repo_with``'s namespace.
        """
        config = self.makeRegistryConfig()
        config["db"] = self.server.url()
        if share_repo_with is None:
            # token_hex already produces lowercase hex digits.
            config["namespace"] = f"namespace_{secrets.token_hex(8)}"
            return Registry.createFromConfig(config)
        config["namespace"] = share_repo_with._db.namespace
        return Registry.fromConfig(config)
class PostgresqlRegistryNameKeyCollMgrTestCase(PostgresqlRegistryTests, unittest.TestCase):
    """Run the PostgreSQL `Registry` tests with ``NameKeyCollectionManager``
    for collections and ``ByDimensionsDatasetRecordStorageManager`` for
    datasets.
    """

    collectionsManager = (
        "lsst.daf.butler.registry.collections.nameKey."
        "NameKeyCollectionManager"
    )
    datasetsManager = (
        "lsst.daf.butler.registry.datasets.byDimensions."
        "ByDimensionsDatasetRecordStorageManager"
    )

    def makeRegistry(self, share_repo_with: Optional[Registry] = None) -> Registry:
        # Opening an existing repository is passed straight through with no
        # warning check.
        if share_repo_with is not None:
            return super().makeRegistry(share_repo_with)
        # Creating a new repository with this manager combination must emit a
        # FutureWarning.
        with self.assertWarns(FutureWarning):
            return super().makeRegistry()
class PostgresqlRegistryNameKeyCollMgrUUIDTestCase(PostgresqlRegistryTests, unittest.TestCase):
    """Run the PostgreSQL `Registry` tests with ``NameKeyCollectionManager``
    for collections and ``ByDimensionsDatasetRecordStorageManagerUUID`` for
    datasets.
    """

    collectionsManager = (
        "lsst.daf.butler.registry.collections.nameKey."
        "NameKeyCollectionManager"
    )
    datasetsManager = (
        "lsst.daf.butler.registry.datasets.byDimensions."
        "ByDimensionsDatasetRecordStorageManagerUUID"
    )
class PostgresqlRegistrySynthIntKeyCollMgrUUIDTestCase(PostgresqlRegistryTests, unittest.TestCase):
    """Run the PostgreSQL `Registry` tests with
    ``SynthIntKeyCollectionManager`` for collections and
    ``ByDimensionsDatasetRecordStorageManagerUUID`` for datasets.
    """

    collectionsManager = (
        "lsst.daf.butler.registry.collections.synthIntKey."
        "SynthIntKeyCollectionManager"
    )
    datasetsManager = (
        "lsst.daf.butler.registry.datasets.byDimensions."
        "ByDimensionsDatasetRecordStorageManagerUUID"
    )
if __name__ == "__main__":  # pragma: no cover
    # Run every test case in this module when executed as a script.
    unittest.main()