Coverage for tests/test_postgresql.py: 39%
115 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-10-02 07:59 +0000
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
28from __future__ import annotations
30import gc
31import itertools
32import os
33import secrets
34import unittest
35import warnings
36from contextlib import contextmanager
38import astropy.time
40try:
41 # It's possible but silly to have testing.postgresql installed without
42 # having the postgresql server installed (because then nothing in
43 # testing.postgresql would work), so we use the presence of that module
44 # to test whether we can expect the server to be available.
45 import testing.postgresql
46except ImportError:
47 testing = None
49import sqlalchemy
50from lsst.daf.butler import Timespan, ddl
51from lsst.daf.butler.registry import _ButlerRegistry, _RegistryFactory
53try:
54 from lsst.daf.butler.registry.databases.postgresql import PostgresqlDatabase, _RangeTimespanType
55except ImportError:
56 testing = None
57from lsst.daf.butler.registry.tests import DatabaseTests, RegistryTests
58from lsst.daf.butler.tests.utils import makeTestTempDir, removeTestTempDir
# Absolute path of the directory containing this test module. It is passed to
# makeTestTempDir as the base location for each test class's temporary
# PostgreSQL server directory.
TESTDIR = os.path.abspath(os.path.dirname(__file__))
def _startServer(root):
    """Start a PostgreSQL server and create a database within it, returning
    an object encapsulating both.

    Parameters
    ----------
    root : `str`
        Directory passed as ``base_dir`` to `testing.postgresql.Postgresql`;
        the temporary server keeps its files there.

    Returns
    -------
    server : `testing.postgresql.Postgresql`
        The running server. The caller is responsible for calling
        ``server.stop()``.

    Notes
    -----
    The ``btree_gist`` extension is installed up front, presumably because
    the PostgreSQL backend's schema relies on it (e.g. GiST indexes or
    constraints mixing scalar and range columns) -- confirm against
    `PostgresqlDatabase`.
    """
    server = testing.postgresql.Postgresql(base_dir=root)
    try:
        engine = sqlalchemy.engine.create_engine(server.url())
        try:
            with engine.begin() as connection:
                connection.execute(sqlalchemy.text("CREATE EXTENSION btree_gist;"))
        finally:
            # Close the engine's pooled connection(s) now instead of leaving
            # them open against the server until garbage collection.
            engine.dispose()
    except BaseException:
        # Don't leak a running server process if post-start setup fails.
        server.stop()
        raise
    return server
@unittest.skipUnless(testing is not None, "testing.postgresql module not found")
class PostgresqlDatabaseTestCase(unittest.TestCase, DatabaseTests):
    """Run the generic `DatabaseTests` suite against `PostgresqlDatabase`.

    A single throwaway PostgreSQL server is started for the whole class; each
    database created by `makeEmptyDatabase` gets its own randomly named
    namespace within that server.
    """

    @classmethod
    def setUpClass(cls):
        cls.root = makeTestTempDir(TESTDIR)
        try:
            cls.server = _startServer(cls.root)
        except BaseException:
            # unittest does not call tearDownClass when setUpClass raises, so
            # remove the temporary directory here rather than leaking it.
            removeTestTempDir(cls.root)
            raise

    @classmethod
    def tearDownClass(cls):
        try:
            # Clean up any lingering SQLAlchemy engines/connections
            # so they're closed before we shut down the server.
            gc.collect()
            cls.server.stop()
        finally:
            # Remove the temporary directory even if stopping the server fails.
            removeTestTempDir(cls.root)

    def makeEmptyDatabase(self, origin: int = 0) -> PostgresqlDatabase:
        """Return a `PostgresqlDatabase` bound to a new, unique namespace.

        Parameters
        ----------
        origin : `int`, optional
            Origin value passed through to `PostgresqlDatabase.fromUri`.
        """
        namespace = f"namespace_{secrets.token_hex(8).lower()}"
        return PostgresqlDatabase.fromUri(origin=origin, uri=self.server.url(), namespace=namespace)

    def getNewConnection(self, database: PostgresqlDatabase, *, writeable: bool) -> PostgresqlDatabase:
        """Open another `PostgresqlDatabase` on the same server, origin and
        namespace as ``database``.

        Parameters
        ----------
        database : `PostgresqlDatabase`
            Existing database whose origin and namespace are reused.
        writeable : `bool`
            Passed through to `PostgresqlDatabase.fromUri`.
        """
        return PostgresqlDatabase.fromUri(
            origin=database.origin, uri=self.server.url(), namespace=database.namespace, writeable=writeable
        )

    @contextmanager
    def asReadOnly(self, database: PostgresqlDatabase) -> PostgresqlDatabase:
        """Yield a new connection to ``database``'s namespace opened with
        ``writeable=False``.
        """
        # NOTE(review): as a @contextmanager generator the precise return
        # annotation would be Iterator[PostgresqlDatabase].
        yield self.getNewConnection(database, writeable=False)

    def testNameShrinking(self):
        """Test that too-long names for database entities other than tables
        and columns (which we preserve, and just expect to fit) are shrunk.

        The test has no explicit assertions; it passes if creating both
        tables raises no error.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            # Table and field names are each below the 63-char limit even when
            # accounting for the prefix, but their combination (which will
            # appear in sequences and constraints) is not.
            tableName = "a_table_with_a_very_very_long_42_char_name"
            fieldName1 = "a_column_with_a_very_very_long_43_char_name"
            fieldName2 = "another_column_with_a_very_very_long_49_char_name"
            context.addTable(
                tableName,
                ddl.TableSpec(
                    fields=[
                        ddl.FieldSpec(
                            fieldName1, dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True
                        ),
                        ddl.FieldSpec(
                            fieldName2,
                            dtype=sqlalchemy.String,
                            length=16,
                            nullable=False,
                        ),
                    ],
                    unique={(fieldName2,)},
                ),
            )
        # Add another table, this time dynamically, with a foreign key to the
        # first table.
        db.ensureTableExists(
            tableName + "_b",
            ddl.TableSpec(
                fields=[
                    ddl.FieldSpec(
                        fieldName1 + "_b", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True
                    ),
                    ddl.FieldSpec(
                        fieldName2 + "_b",
                        dtype=sqlalchemy.String,
                        length=16,
                        nullable=False,
                    ),
                ],
                foreignKeys=[
                    ddl.ForeignKeySpec(tableName, source=(fieldName2 + "_b",), target=(fieldName2,)),
                ],
            ),
        )

    def test_RangeTimespanType(self):
        """Round-trip `Timespan` values through a `_RangeTimespanType` column
        and check that the database's overlap test agrees with
        `Timespan.overlaps` for every pair.
        """
        start = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
        offset = astropy.time.TimeDelta(60, format="sec")
        timestamps = [start + offset * n for n in range(3)]
        # Cover fully unbounded, half-bounded on each side, and bounded spans.
        timespans = [Timespan(begin=None, end=None)]
        timespans.extend(Timespan(begin=None, end=t) for t in timestamps)
        timespans.extend(Timespan(begin=t, end=None) for t in timestamps)
        timespans.extend(Timespan(begin=a, end=b) for a, b in itertools.combinations(timestamps, 2))
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tbl = context.addTable(
                "tbl",
                ddl.TableSpec(
                    fields=[
                        ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
                        ddl.FieldSpec(name="timespan", dtype=_RangeTimespanType),
                    ],
                ),
            )
        rows = [{"id": n, "timespan": t} for n, t in enumerate(timespans)]
        db.insert(tbl, *rows)

        # Test basic round-trip through database.
        with db.query(tbl.select().order_by(tbl.columns.id)) as sql_result:
            self.assertEqual(rows, [row._asdict() for row in sql_result])

        # Test that Timespan's Python methods are consistent with our usage of
        # half-open ranges and PostgreSQL operators on ranges.
        def subquery(alias: str) -> sqlalchemy.sql.FromClause:
            return (
                sqlalchemy.sql.select(tbl.columns.id.label("id"), tbl.columns.timespan.label("timespan"))
                .select_from(tbl)
                .alias(alias)
            )

        sq1 = subquery("sq1")
        sq2 = subquery("sq2")
        query = sqlalchemy.sql.select(
            sq1.columns.id.label("n1"),
            sq2.columns.id.label("n2"),
            sq1.columns.timespan.overlaps(sq2.columns.timespan).label("overlaps"),
        )

        # This test already requires SQLAlchemy 1.4+ (positional columns in
        # select() and Result.mappings()), so `selected_columns` always
        # exists; the old 1.3 `columns` fallback was unreachable.
        columns = query.selected_columns

        # SQLAlchemy issues a warning about cartesian product of two tables,
        # which we do intentionally. Disable that warning temporarily.
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore", message=".*cartesian product", category=sqlalchemy.exc.SAWarning
            )
            with db.query(query) as sql_result:
                dbResults = {
                    (row[columns.n1], row[columns.n2]): row[columns.overlaps] for row in sql_result.mappings()
                }

        pyResults = {
            (n1, n2): t1.overlaps(t2)
            for (n1, t1), (n2, t2) in itertools.product(enumerate(timespans), enumerate(timespans))
        }
        self.assertEqual(pyResults, dbResults)
@unittest.skipUnless(testing is not None, "testing.postgresql module not found")
class PostgresqlRegistryTests(RegistryTests):
    """Tests for `Registry` backed by a PostgreSQL database.

    Notes
    -----
    This is not a subclass of `unittest.TestCase` but to avoid repetition it
    defines methods that override `unittest.TestCase` methods. To make this
    work subclasses have to have this class first in the bases list.
    """

    @classmethod
    def setUpClass(cls):
        cls.root = makeTestTempDir(TESTDIR)
        try:
            cls.server = _startServer(cls.root)
        except BaseException:
            # unittest does not call tearDownClass when setUpClass raises, so
            # remove the temporary directory here rather than leaking it.
            removeTestTempDir(cls.root)
            raise

    @classmethod
    def tearDownClass(cls):
        try:
            # Clean up any lingering SQLAlchemy engines/connections
            # so they're closed before we shut down the server.
            gc.collect()
            cls.server.stop()
        finally:
            # Remove the temporary directory even if stopping the server fails.
            removeTestTempDir(cls.root)

    @classmethod
    def getDataDir(cls) -> str:
        """Return the absolute path of the ``data/registry`` directory next to
        this test module.
        """
        # Build from the absolute TESTDIR so the result does not depend on the
        # current working directory.
        return os.path.join(TESTDIR, "data", "registry")

    def makeRegistry(self, share_repo_with: _ButlerRegistry | None = None) -> _ButlerRegistry:
        """Create a registry on the class's PostgreSQL server.

        Parameters
        ----------
        share_repo_with : `_ButlerRegistry`, optional
            If `None`, a new randomly named namespace is used and the registry
            is built with `_RegistryFactory.create_from_config`. Otherwise the
            new registry reuses this registry's namespace and is built with
            `_RegistryFactory.from_config` (presumably connecting to the
            existing tables rather than creating them -- confirm).

        Returns
        -------
        registry : `_ButlerRegistry`
            The new registry.
        """
        if share_repo_with is None:
            namespace = f"namespace_{secrets.token_hex(8).lower()}"
        else:
            namespace = share_repo_with._db.namespace
        config = self.makeRegistryConfig()
        config["db"] = self.server.url()
        config["namespace"] = namespace
        factory = _RegistryFactory(config)
        if share_repo_with is None:
            return factory.create_from_config()
        return factory.from_config()
class PostgresqlRegistryNameKeyCollMgrUUIDTestCase(PostgresqlRegistryTests, unittest.TestCase):
    """Run the PostgreSQL `Registry` tests with name-keyed collections.

    Collection manager: NameKeyCollectionManager.
    Dataset manager: ByDimensionsDatasetRecordStorageManagerUUID.
    """

    collectionsManager = (
        "lsst.daf.butler.registry.collections."
        "nameKey.NameKeyCollectionManager"
    )
    datasetsManager = (
        "lsst.daf.butler.registry.datasets.byDimensions."
        "ByDimensionsDatasetRecordStorageManagerUUID"
    )
class PostgresqlRegistrySynthIntKeyCollMgrUUIDTestCase(PostgresqlRegistryTests, unittest.TestCase):
    """Run the PostgreSQL `Registry` tests with synthetic-integer-keyed
    collections.

    Collection manager: SynthIntKeyCollectionManager.
    Dataset manager: ByDimensionsDatasetRecordStorageManagerUUID.
    """

    collectionsManager = (
        "lsst.daf.butler.registry.collections."
        "synthIntKey.SynthIntKeyCollectionManager"
    )
    datasetsManager = (
        "lsst.daf.butler.registry.datasets.byDimensions."
        "ByDimensionsDatasetRecordStorageManagerUUID"
    )
if __name__ == "__main__":
    # Allow running this module directly, outside of a test runner.
    unittest.main()