Coverage for tests/test_postgresql.py: 39%
115 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-10-25 15:13 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2023-10-25 15:13 +0000
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22from __future__ import annotations
24import gc
25import itertools
26import os
27import secrets
28import unittest
29import warnings
30from contextlib import contextmanager
32import astropy.time
34try:
35 # It's possible but silly to have testing.postgresql installed without
36 # having the postgresql server installed (because then nothing in
37 # testing.postgresql would work), so we use the presence of that module
38 # to test whether we can expect the server to be available.
39 import testing.postgresql
40except ImportError:
41 testing = None
43import sqlalchemy
44from lsst.daf.butler import Timespan, ddl
45from lsst.daf.butler.registry import _ButlerRegistry, _RegistryFactory
47try:
48 from lsst.daf.butler.registry.databases.postgresql import PostgresqlDatabase, _RangeTimespanType
49except ImportError:
50 testing = None
51from lsst.daf.butler.registry.tests import DatabaseTests, RegistryTests
52from lsst.daf.butler.tests.utils import makeTestTempDir, removeTestTempDir
# Absolute path of the directory containing this test module; temporary test
# directories are created relative to it.
TESTDIR = os.path.dirname(os.path.abspath(__file__))
def _startServer(root):
    """Start a temporary PostgreSQL server and enable ``btree_gist`` in it.

    Parameters
    ----------
    root : `str`
        Directory handed to `testing.postgresql.Postgresql` as ``base_dir``.

    Returns
    -------
    server : `testing.postgresql.Postgresql`
        The running server.  The caller is responsible for calling its
        ``stop`` method when done.

    Notes
    -----
    The ``btree_gist`` extension is created through a short-lived SQLAlchemy
    engine that is disposed of before returning, so no connection from this
    function is left open against the server.  If the extension cannot be
    created the server is stopped again before the exception propagates.
    """
    server = testing.postgresql.Postgresql(base_dir=root)
    try:
        engine = sqlalchemy.engine.create_engine(server.url())
        try:
            with engine.begin() as connection:
                connection.execute(sqlalchemy.text("CREATE EXTENSION btree_gist;"))
        finally:
            # Close the pooled connection explicitly instead of waiting for
            # the garbage collector to find the engine.
            engine.dispose()
    except BaseException:
        # Do not leave a server running that nobody holds a reference to.
        server.stop()
        raise
    return server
@unittest.skipUnless(testing is not None, "testing.postgresql module not found")
class PostgresqlDatabaseTestCase(unittest.TestCase, DatabaseTests):
    """Test the `PostgresqlDatabase` implementation against a temporary
    PostgreSQL server.

    One server is started per test class; every test obtains its own
    randomly-named namespace within it via `makeEmptyDatabase`.
    """

    @classmethod
    def setUpClass(cls):
        cls.root = makeTestTempDir(TESTDIR)
        try:
            cls.server = _startServer(cls.root)
        except BaseException:
            # unittest does not run tearDownClass when setUpClass fails, so
            # the temporary directory has to be removed here.
            removeTestTempDir(cls.root)
            raise

    @classmethod
    def tearDownClass(cls):
        # Clean up any lingering SQLAlchemy engines/connections
        # so they're closed before we shut down the server.
        gc.collect()
        try:
            cls.server.stop()
        finally:
            # Remove the temporary directory even if stopping the server
            # raised.
            removeTestTempDir(cls.root)

    def makeEmptyDatabase(self, origin: int = 0) -> PostgresqlDatabase:
        """Return a `PostgresqlDatabase` on the class-wide server that uses a
        new, randomly-named namespace.
        """
        namespace = f"namespace_{secrets.token_hex(8).lower()}"
        return PostgresqlDatabase.fromUri(origin=origin, uri=self.server.url(), namespace=namespace)

    def getNewConnection(self, database: PostgresqlDatabase, *, writeable: bool) -> PostgresqlDatabase:
        """Return a new `PostgresqlDatabase` with the same origin and
        namespace as ``database``, with the requested ``writeable`` flag.
        """
        return PostgresqlDatabase.fromUri(
            origin=database.origin, uri=self.server.url(), namespace=database.namespace, writeable=writeable
        )

    @contextmanager
    def asReadOnly(self, database: PostgresqlDatabase) -> PostgresqlDatabase:
        """Context manager yielding a non-writeable connection to the same
        namespace as ``database``.
        """
        yield self.getNewConnection(database, writeable=False)

    def testNameShrinking(self):
        """Test that too-long names for database entities other than tables
        and columns (which we preserve, and just expect to fit) are shrunk.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            # Table and field names are each below the 63-char limit even when
            # accounting for the prefix, but their combination (which will
            # appear in sequences and constraints) is not.
            tableName = "a_table_with_a_very_very_long_42_char_name"
            fieldName1 = "a_column_with_a_very_very_long_43_char_name"
            fieldName2 = "another_column_with_a_very_very_long_49_char_name"
            context.addTable(
                tableName,
                ddl.TableSpec(
                    fields=[
                        ddl.FieldSpec(
                            fieldName1, dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True
                        ),
                        ddl.FieldSpec(
                            fieldName2,
                            dtype=sqlalchemy.String,
                            length=16,
                            nullable=False,
                        ),
                    ],
                    unique={(fieldName2,)},
                ),
            )
        # Add another table, this time dynamically, with a foreign key to the
        # first table.
        db.ensureTableExists(
            tableName + "_b",
            ddl.TableSpec(
                fields=[
                    ddl.FieldSpec(
                        fieldName1 + "_b", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True
                    ),
                    ddl.FieldSpec(
                        fieldName2 + "_b",
                        dtype=sqlalchemy.String,
                        length=16,
                        nullable=False,
                    ),
                ],
                foreignKeys=[
                    ddl.ForeignKeySpec(tableName, source=(fieldName2 + "_b",), target=(fieldName2,)),
                ],
            ),
        )

    def test_RangeTimespanType(self):
        """Test that `Timespan` values round-trip through a
        `_RangeTimespanType` column and that the database's ``overlaps``
        results agree with `Timespan.overlaps` in Python.
        """
        start = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
        offset = astropy.time.TimeDelta(60, format="sec")
        timestamps = [start + offset * n for n in range(3)]
        # Cover fully unbounded, half-bounded and fully bounded timespans.
        timespans = [Timespan(begin=None, end=None)]
        timespans.extend(Timespan(begin=None, end=t) for t in timestamps)
        timespans.extend(Timespan(begin=t, end=None) for t in timestamps)
        timespans.extend(Timespan(begin=a, end=b) for a, b in itertools.combinations(timestamps, 2))
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tbl = context.addTable(
                "tbl",
                ddl.TableSpec(
                    fields=[
                        ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
                        ddl.FieldSpec(name="timespan", dtype=_RangeTimespanType),
                    ],
                ),
            )
        rows = [{"id": n, "timespan": t} for n, t in enumerate(timespans)]
        db.insert(tbl, *rows)

        # Test basic round-trip through database.
        with db.query(tbl.select().order_by(tbl.columns.id)) as sql_result:
            self.assertEqual(rows, [row._asdict() for row in sql_result])

        # Test that Timespan's Python methods are consistent with our usage of
        # half-open ranges and PostgreSQL operators on ranges.
        def subquery(alias: str) -> sqlalchemy.sql.FromClause:
            return (
                sqlalchemy.sql.select(tbl.columns.id.label("id"), tbl.columns.timespan.label("timespan"))
                .select_from(tbl)
                .alias(alias)
            )

        sq1 = subquery("sq1")
        sq2 = subquery("sq2")
        query = sqlalchemy.sql.select(
            sq1.columns.id.label("n1"),
            sq2.columns.id.label("n2"),
            sq1.columns.timespan.overlaps(sq2.columns.timespan).label("overlaps"),
        )

        # `columns` is deprecated since 1.4, but
        # `selected_columns` method did not exist in 1.3.
        if hasattr(query, "selected_columns"):
            columns = query.selected_columns
        else:
            columns = query.columns

        # SQLAlchemy issues a warning about cartesian product of two tables,
        # which we do intentionally. Disable that warning temporarily.
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore", message=".*cartesian product", category=sqlalchemy.exc.SAWarning
            )
            with db.query(query) as sql_result:
                dbResults = {
                    (row[columns.n1], row[columns.n2]): row[columns.overlaps] for row in sql_result.mappings()
                }

        # The same pairwise comparison, computed purely in Python.
        pyResults = {
            (n1, n2): t1.overlaps(t2)
            for (n1, t1), (n2, t2) in itertools.product(enumerate(timespans), enumerate(timespans))
        }
        self.assertEqual(pyResults, dbResults)
@unittest.skipUnless(testing is not None, "testing.postgresql module not found")
class PostgresqlRegistryTests(RegistryTests):
    """Tests for `Registry` backed by a PostgreSQL database.

    Notes
    -----
    This is not a subclass of `unittest.TestCase` but to avoid repetition it
    defines methods that override `unittest.TestCase` methods. To make this
    work subclasses have to have this class first in the bases list.
    """

    @classmethod
    def setUpClass(cls):
        cls.root = makeTestTempDir(TESTDIR)
        try:
            cls.server = _startServer(cls.root)
        except BaseException:
            # unittest does not run tearDownClass when setUpClass fails, so
            # the temporary directory has to be removed here.
            removeTestTempDir(cls.root)
            raise

    @classmethod
    def tearDownClass(cls):
        # Clean up any lingering SQLAlchemy engines/connections
        # so they're closed before we shut down the server.
        gc.collect()
        try:
            cls.server.stop()
        finally:
            # Remove the temporary directory even if stopping the server
            # raised.
            removeTestTempDir(cls.root)

    @classmethod
    def getDataDir(cls) -> str:
        """Return the normalized path of the ``data/registry`` directory
        located next to this test module.
        """
        return os.path.normpath(os.path.join(os.path.dirname(__file__), "data", "registry"))

    def makeRegistry(self, share_repo_with: _ButlerRegistry | None = None) -> _ButlerRegistry:
        """Return a registry connected to the class-wide PostgreSQL server.

        Parameters
        ----------
        share_repo_with : `_ButlerRegistry`, optional
            If `None` (default), a registry is created in a new,
            randomly-named namespace.  Otherwise the returned registry is
            opened on the namespace used by the given registry.

        Returns
        -------
        registry : `_ButlerRegistry`
            The newly created or newly opened registry.
        """
        config = self.makeRegistryConfig()
        config["db"] = self.server.url()
        if share_repo_with is None:
            config["namespace"] = f"namespace_{secrets.token_hex(8).lower()}"
            return _RegistryFactory(config).create_from_config()
        config["namespace"] = share_repo_with._db.namespace
        return _RegistryFactory(config).from_config()
class PostgresqlRegistryNameKeyCollMgrUUIDTestCase(PostgresqlRegistryTests, unittest.TestCase):
    """Run the PostgreSQL-backed `Registry` tests with name-keyed
    collections.

    The managers selected for this test case are NameKeyCollectionManager
    and ByDimensionsDatasetRecordStorageManagerUUID.
    """

    datasetsManager = (
        "lsst.daf.butler.registry.datasets.byDimensions.ByDimensionsDatasetRecordStorageManagerUUID"
    )
    collectionsManager = "lsst.daf.butler.registry.collections.nameKey.NameKeyCollectionManager"
class PostgresqlRegistrySynthIntKeyCollMgrUUIDTestCase(PostgresqlRegistryTests, unittest.TestCase):
    """Run the PostgreSQL-backed `Registry` tests with synthetic-integer-keyed
    collections.

    The managers selected for this test case are SynthIntKeyCollectionManager
    and ByDimensionsDatasetRecordStorageManagerUUID.
    """

    datasetsManager = (
        "lsst.daf.butler.registry.datasets.byDimensions.ByDimensionsDatasetRecordStorageManagerUUID"
    )
    collectionsManager = "lsst.daf.butler.registry.collections.synthIntKey.SynthIntKeyCollectionManager"
# Allow the tests to be run by executing this file directly.
if __name__ == "__main__":
    unittest.main()