Coverage for tests/test_postgresql.py: 40%
116 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-08 02:50 -0700
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-08 02:50 -0700
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
28from __future__ import annotations
30import gc
31import itertools
32import os
33import secrets
34import unittest
35import warnings
36from contextlib import contextmanager
38import astropy.time
40try:
41 # It's possible but silly to have testing.postgresql installed without
42 # having the postgresql server installed (because then nothing in
43 # testing.postgresql would work), so we use the presence of that module
44 # to test whether we can expect the server to be available.
45 import testing.postgresql
46except ImportError:
47 testing = None
49import sqlalchemy
50from lsst.daf.butler import Timespan, ddl
51from lsst.daf.butler.registry import _RegistryFactory
52from lsst.daf.butler.registry.sql_registry import SqlRegistry
54try:
55 from lsst.daf.butler.registry.databases.postgresql import PostgresqlDatabase, _RangeTimespanType
56except ImportError:
57 testing = None
58from lsst.daf.butler.registry.tests import DatabaseTests, RegistryTests
59from lsst.daf.butler.tests.utils import makeTestTempDir, removeTestTempDir
# Absolute path of the directory containing this test module; passed to
# makeTestTempDir() by the test classes below.
TESTDIR = os.path.abspath(os.path.dirname(__file__))
def _startServer(root):
    """Start a PostgreSQL server and create a database within it, returning
    an object encapsulating both.

    Parameters
    ----------
    root : `str`
        Passed to ``testing.postgresql.Postgresql`` as its ``base_dir``.

    Returns
    -------
    server : `testing.postgresql.Postgresql`
        The started server, after ``CREATE EXTENSION btree_gist`` has been
        executed against it.

    Notes
    -----
    If the extension cannot be created the server is stopped before the
    exception propagates, so a failed start does not leave a server process
    behind.
    """
    server = testing.postgresql.Postgresql(base_dir=root)
    try:
        engine = sqlalchemy.engine.create_engine(server.url())
        try:
            with engine.begin() as connection:
                connection.execute(sqlalchemy.text("CREATE EXTENSION btree_gist;"))
        finally:
            # This engine is only needed for the one-off setup statement;
            # dispose of it explicitly so its pooled connection is closed
            # now rather than whenever the garbage collector gets to it.
            engine.dispose()
    except Exception:
        # The caller never receives ``server`` on failure, so it cannot stop
        # it; do that here.
        server.stop()
        raise
    return server
@unittest.skipUnless(testing is not None, "testing.postgresql module not found")
class PostgresqlDatabaseTestCase(unittest.TestCase, DatabaseTests):
    """Tests for `PostgresqlDatabase`.

    A single PostgreSQL server is started for the whole class; each database
    made by `makeEmptyDatabase` uses its own randomly-named namespace on that
    server.
    """

    @classmethod
    def setUpClass(cls):
        """Create a temporary directory and start a server inside it."""
        cls.root = makeTestTempDir(TESTDIR)
        try:
            cls.server = _startServer(cls.root)
        except Exception:
            # unittest does not call tearDownClass when setUpClass raises,
            # so the temporary directory has to be removed here.
            removeTestTempDir(cls.root)
            raise

    @classmethod
    def tearDownClass(cls):
        """Stop the server and remove the temporary directory."""
        # Clean up any lingering SQLAlchemy engines/connections
        # so they're closed before we shut down the server.
        gc.collect()
        try:
            cls.server.stop()
        finally:
            # Remove the directory even if stopping the server failed.
            removeTestTempDir(cls.root)

    def makeEmptyDatabase(self, origin: int = 0) -> PostgresqlDatabase:
        """Return a `PostgresqlDatabase` connected to the class's server,
        using a new randomly-generated namespace.

        Parameters
        ----------
        origin : `int`, optional
            Forwarded to `PostgresqlDatabase.fromUri`.
        """
        namespace = f"namespace_{secrets.token_hex(8).lower()}"
        return PostgresqlDatabase.fromUri(origin=origin, uri=self.server.url(), namespace=namespace)

    def getNewConnection(self, database: PostgresqlDatabase, *, writeable: bool) -> PostgresqlDatabase:
        """Return a new `PostgresqlDatabase` with the same origin and
        namespace as ``database``.

        Parameters
        ----------
        database : `PostgresqlDatabase`
            Database whose ``origin`` and ``namespace`` are reused.
        writeable : `bool`
            Forwarded to `PostgresqlDatabase.fromUri`.
        """
        return PostgresqlDatabase.fromUri(
            origin=database.origin, uri=self.server.url(), namespace=database.namespace, writeable=writeable
        )

    @contextmanager
    def asReadOnly(self, database: PostgresqlDatabase) -> PostgresqlDatabase:
        """Context manager that yields a non-writeable connection obtained
        from `getNewConnection` for ``database``.
        """
        yield self.getNewConnection(database, writeable=False)

    def testNameShrinking(self):
        """Test that too-long names for database entities other than tables
        and columns (which we preserve, and just expect to fit) are shrunk.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            # Table and field names are each below the 63-char limit even when
            # accounting for the prefix, but their combination (which will
            # appear in sequences and constraints) is not.
            tableName = "a_table_with_a_very_very_long_42_char_name"
            fieldName1 = "a_column_with_a_very_very_long_43_char_name"
            fieldName2 = "another_column_with_a_very_very_long_49_char_name"
            context.addTable(
                tableName,
                ddl.TableSpec(
                    fields=[
                        ddl.FieldSpec(
                            fieldName1, dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True
                        ),
                        ddl.FieldSpec(
                            fieldName2,
                            dtype=sqlalchemy.String,
                            length=16,
                            nullable=False,
                        ),
                    ],
                    unique={(fieldName2,)},
                ),
            )
        # Add another table, this time dynamically, with a foreign key to the
        # first table.
        db.ensureTableExists(
            tableName + "_b",
            ddl.TableSpec(
                fields=[
                    ddl.FieldSpec(
                        fieldName1 + "_b", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True
                    ),
                    ddl.FieldSpec(
                        fieldName2 + "_b",
                        dtype=sqlalchemy.String,
                        length=16,
                        nullable=False,
                    ),
                ],
                foreignKeys=[
                    ddl.ForeignKeySpec(tableName, source=(fieldName2 + "_b",), target=(fieldName2,)),
                ],
            ),
        )

    def test_RangeTimespanType(self):
        """Test that `Timespan` values round-trip through a
        `_RangeTimespanType` column and that the database-side ``overlaps``
        expression agrees with `Timespan.overlaps` for every pair of rows.
        """
        # Build unbounded, half-bounded and fully-bounded timespans from
        # three timestamps spaced 60 seconds apart.
        start = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
        offset = astropy.time.TimeDelta(60, format="sec")
        timestamps = [start + offset * n for n in range(3)]
        timespans = [Timespan(begin=None, end=None)]
        timespans.extend(Timespan(begin=None, end=t) for t in timestamps)
        timespans.extend(Timespan(begin=t, end=None) for t in timestamps)
        timespans.extend(Timespan(begin=a, end=b) for a, b in itertools.combinations(timestamps, 2))
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tbl = context.addTable(
                "tbl",
                ddl.TableSpec(
                    fields=[
                        ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
                        ddl.FieldSpec(name="timespan", dtype=_RangeTimespanType),
                    ],
                ),
            )
        rows = [{"id": n, "timespan": t} for n, t in enumerate(timespans)]
        db.insert(tbl, *rows)

        # Test basic round-trip through database.
        with db.query(tbl.select().order_by(tbl.columns.id)) as sql_result:
            self.assertEqual(rows, [row._asdict() for row in sql_result])

        # Test that Timespan's Python methods are consistent with our usage of
        # half-open ranges and PostgreSQL operators on ranges.
        def subquery(alias: str) -> sqlalchemy.sql.FromClause:
            return (
                sqlalchemy.sql.select(tbl.columns.id.label("id"), tbl.columns.timespan.label("timespan"))
                .select_from(tbl)
                .alias(alias)
            )

        sq1 = subquery("sq1")
        sq2 = subquery("sq2")
        query = sqlalchemy.sql.select(
            sq1.columns.id.label("n1"),
            sq2.columns.id.label("n2"),
            sq1.columns.timespan.overlaps(sq2.columns.timespan).label("overlaps"),
        )

        # `columns` is deprecated since 1.4, but
        # `selected_columns` method did not exist in 1.3.
        if hasattr(query, "selected_columns"):
            columns = query.selected_columns
        else:
            columns = query.columns

        # SQLAlchemy issues a warning about cartesian product of two tables,
        # which we do intentionally. Disable that warning temporarily.
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore", message=".*cartesian product", category=sqlalchemy.exc.SAWarning
            )
            with db.query(query) as sql_result:
                dbResults = {
                    (row[columns.n1], row[columns.n2]): row[columns.overlaps] for row in sql_result.mappings()
                }

        # Compute the same pairwise overlap table in pure Python.
        pyResults = {
            (n1, n2): t1.overlaps(t2)
            for (n1, t1), (n2, t2) in itertools.product(enumerate(timespans), enumerate(timespans))
        }
        self.assertEqual(pyResults, dbResults)
@unittest.skipUnless(testing is not None, "testing.postgresql module not found")
class PostgresqlRegistryTests(RegistryTests):
    """Tests for `Registry` backed by a PostgreSQL database.

    Notes
    -----
    This is not a subclass of `unittest.TestCase` but to avoid repetition it
    defines methods that override `unittest.TestCase` methods. To make this
    work subclasses have to have this class first in the bases list.
    """

    @classmethod
    def setUpClass(cls):
        """Create a temporary directory and start a server inside it."""
        cls.root = makeTestTempDir(TESTDIR)
        try:
            cls.server = _startServer(cls.root)
        except Exception:
            # unittest does not call tearDownClass when setUpClass raises,
            # so the temporary directory has to be removed here.
            removeTestTempDir(cls.root)
            raise

    @classmethod
    def tearDownClass(cls):
        """Stop the server and remove the temporary directory."""
        # Clean up any lingering SQLAlchemy engines/connections
        # so they're closed before we shut down the server.
        gc.collect()
        try:
            cls.server.stop()
        finally:
            # Remove the directory even if stopping the server failed.
            removeTestTempDir(cls.root)

    @classmethod
    def getDataDir(cls) -> str:
        """Return the normalized path of the ``data/registry`` directory
        next to this test module.
        """
        return os.path.normpath(os.path.join(os.path.dirname(__file__), "data", "registry"))

    def makeRegistry(self, share_repo_with: SqlRegistry | None = None) -> SqlRegistry:
        """Return a registry connected to the class's server.

        Parameters
        ----------
        share_repo_with : `SqlRegistry`, optional
            If `None` (default), a new randomly-named namespace is used and
            the registry is made with ``create_from_config``. Otherwise the
            namespace of this registry's database is reused and the registry
            is made with ``from_config``.

        Returns
        -------
        registry : `SqlRegistry`
            The registry returned by `_RegistryFactory`.
        """
        if share_repo_with is None:
            namespace = f"namespace_{secrets.token_hex(8).lower()}"
        else:
            namespace = share_repo_with._db.namespace
        config = self.makeRegistryConfig()
        config["db"] = self.server.url()
        config["namespace"] = namespace
        if share_repo_with is None:
            return _RegistryFactory(config).create_from_config()
        else:
            return _RegistryFactory(config).from_config()
class PostgresqlRegistryNameKeyCollMgrUUIDTestCase(PostgresqlRegistryTests, unittest.TestCase):
    """Run the PostgreSQL-backed `Registry` tests with
    NameKeyCollectionManager as the collections manager and
    ByDimensionsDatasetRecordStorageManagerUUID as the datasets manager.
    """

    # Fully-qualified class names of the manager implementations under test.
    datasetsManager = (
        "lsst.daf.butler.registry.datasets.byDimensions.ByDimensionsDatasetRecordStorageManagerUUID"
    )
    collectionsManager = "lsst.daf.butler.registry.collections.nameKey.NameKeyCollectionManager"
class PostgresqlRegistrySynthIntKeyCollMgrUUIDTestCase(PostgresqlRegistryTests, unittest.TestCase):
    """Run the PostgreSQL-backed `Registry` tests with
    SynthIntKeyCollectionManager as the collections manager and
    ByDimensionsDatasetRecordStorageManagerUUID as the datasets manager.
    """

    # Fully-qualified class names of the manager implementations under test.
    datasetsManager = (
        "lsst.daf.butler.registry.datasets.byDimensions.ByDimensionsDatasetRecordStorageManagerUUID"
    )
    collectionsManager = "lsst.daf.butler.registry.collections.synthIntKey.SynthIntKeyCollectionManager"
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()