Coverage for tests/test_postgresql.py: 43%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22import os
23from contextlib import contextmanager
24import itertools
25import secrets
26import unittest
27import gc
28import warnings
30import astropy.time
31try:
32 # It's possible but silly to have testing.postgresql installed without
33 # having the postgresql server installed (because then nothing in
34 # testing.postgresql would work), so we use the presence of that module
35 # to test whether we can expect the server to be available.
36 import testing.postgresql
37except ImportError:
38 testing = None
40import sqlalchemy
42from lsst.daf.butler import ddl, Timespan
43from lsst.daf.butler.registry import Registry
44from lsst.daf.butler.registry.databases.postgresql import PostgresqlDatabase, _RangeTimespanType
45from lsst.daf.butler.registry.tests import DatabaseTests, RegistryTests
46from lsst.daf.butler.tests.utils import makeTestTempDir, removeTestTempDir
# Absolute path of the directory that contains this test module; the test
# classes below create their temporary server directories beneath it.
TESTDIR = os.path.dirname(os.path.abspath(__file__))
def _startServer(root):
    """Start a temporary PostgreSQL server and prepare its default database
    for the tests in this module.

    Parameters
    ----------
    root : `str`
        Directory passed to ``testing.postgresql.Postgresql`` as its
        ``base_dir``; the server keeps its files there.

    Returns
    -------
    server : `testing.postgresql.Postgresql`
        The running server. The caller is responsible for calling its
        ``stop()`` method.

    Notes
    -----
    The ``btree_gist`` extension is installed in the server's default
    database. It is presumably needed by `PostgresqlDatabase` for GiST-based
    constraints that mix scalar and range columns (confirm against
    ``lsst.daf.butler.registry.databases.postgresql``).
    """
    server = testing.postgresql.Postgresql(base_dir=root)
    try:
        engine = sqlalchemy.engine.create_engine(server.url())
        try:
            # ``engine.begin()`` commits on success; unlike the deprecated
            # ``Engine.execute`` this works on SQLAlchemy 1.3, 1.4 and 2.0.
            with engine.begin() as connection:
                connection.execute(sqlalchemy.text("CREATE EXTENSION btree_gist;"))
        finally:
            # Close pooled connections now instead of leaving them for the
            # garbage collector, so they do not keep the server busy.
            engine.dispose()
    except Exception:
        # Do not leave an orphaned server process behind if setup fails.
        server.stop()
        raise
    return server
@unittest.skipUnless(testing is not None, "testing.postgresql module not found")
class PostgresqlDatabaseTestCase(unittest.TestCase, DatabaseTests):
    """Tests for `PostgresqlDatabase`, run against a temporary PostgreSQL
    server that is started once and shared by every test in this class.
    """

    @classmethod
    def setUpClass(cls):
        cls.root = makeTestTempDir(TESTDIR)
        try:
            cls.server = _startServer(cls.root)
        except Exception:
            # unittest does not call tearDownClass when setUpClass raises,
            # so remove the temporary directory here to avoid leaking it.
            removeTestTempDir(cls.root)
            raise

    @classmethod
    def tearDownClass(cls):
        try:
            # Clean up any lingering SQLAlchemy engines/connections
            # so they're closed before we shut down the server.
            gc.collect()
            cls.server.stop()
        finally:
            # Remove the temporary directory even if stopping the server
            # failed.
            removeTestTempDir(cls.root)

    def makeEmptyDatabase(self, origin: int = 0) -> PostgresqlDatabase:
        """Return a `PostgresqlDatabase` connected to the shared server and
        bound to a new, randomly named namespace.
        """
        namespace = f"namespace_{secrets.token_hex(8).lower()}"
        return PostgresqlDatabase.fromUri(origin=origin, uri=self.server.url(), namespace=namespace)

    def getNewConnection(self, database: PostgresqlDatabase, *, writeable: bool) -> PostgresqlDatabase:
        """Return a second `PostgresqlDatabase` that uses the same server,
        origin and namespace as ``database``, opened with the given
        ``writeable`` flag.
        """
        return PostgresqlDatabase.fromUri(origin=database.origin, uri=self.server.url(),
                                          namespace=database.namespace, writeable=writeable)

    @contextmanager
    def asReadOnly(self, database: PostgresqlDatabase) -> PostgresqlDatabase:
        """Context manager that yields a new read-only connection to the
        same namespace as ``database``.

        Note that despite the return annotation, this is a generator-based
        context manager: the `PostgresqlDatabase` is what the ``with``
        statement binds.
        """
        yield self.getNewConnection(database, writeable=False)

    def testNameShrinking(self):
        """Test that too-long names for database entities other than tables
        and columns (which we preserve, and just expect to fit) are shrunk.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            # Table and field names are each below the 63-char limit even when
            # accounting for the prefix, but their combination (which will
            # appear in sequences and constraints) is not.
            tableName = "a_table_with_a_very_very_long_42_char_name"
            fieldName1 = "a_column_with_a_very_very_long_43_char_name"
            fieldName2 = "another_column_with_a_very_very_long_49_char_name"
            context.addTable(
                tableName,
                ddl.TableSpec(
                    fields=[
                        ddl.FieldSpec(
                            fieldName1,
                            dtype=sqlalchemy.BigInteger,
                            autoincrement=True,
                            primaryKey=True
                        ),
                        ddl.FieldSpec(
                            fieldName2,
                            dtype=sqlalchemy.String,
                            length=16,
                            nullable=False,
                        ),
                    ],
                    unique={(fieldName2,)},
                )
            )
        # Add another table, this time dynamically, with a foreign key to the
        # first table.
        db.ensureTableExists(
            tableName + "_b",
            ddl.TableSpec(
                fields=[
                    ddl.FieldSpec(
                        fieldName1 + "_b",
                        dtype=sqlalchemy.BigInteger,
                        autoincrement=True,
                        primaryKey=True
                    ),
                    ddl.FieldSpec(
                        fieldName2 + "_b",
                        dtype=sqlalchemy.String,
                        length=16,
                        nullable=False,
                    ),
                ],
                foreignKeys=[
                    ddl.ForeignKeySpec(tableName, source=(fieldName2 + "_b",), target=(fieldName2,)),
                ]
            )
        )

    def test_RangeTimespanType(self):
        """Test that `Timespan` values round-trip through a
        `_RangeTimespanType` column, and that the database's overlap
        operator agrees with `Timespan.overlaps` for every pair.
        """
        start = astropy.time.Time('2020-01-01T00:00:00', format="isot", scale="tai")
        offset = astropy.time.TimeDelta(60, format="sec")
        timestamps = [start + offset*n for n in range(3)]
        # Fully unbounded, half-bounded on either side, and bounded timespans.
        timespans = [Timespan(begin=None, end=None)]
        timespans.extend(Timespan(begin=None, end=t) for t in timestamps)
        timespans.extend(Timespan(begin=t, end=None) for t in timestamps)
        timespans.extend(Timespan(begin=a, end=b) for a, b in itertools.combinations(timestamps, 2))
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tbl = context.addTable(
                "tbl",
                ddl.TableSpec(
                    fields=[
                        ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
                        ddl.FieldSpec(name="timespan", dtype=_RangeTimespanType),
                    ],
                )
            )
        rows = [{"id": n, "timespan": t} for n, t in enumerate(timespans)]
        db.insert(tbl, *rows)

        # Test basic round-trip through database.
        self.assertEqual(
            rows,
            [dict(row) for row in db.query(tbl.select().order_by(tbl.columns.id)).fetchall()]
        )

        # Test that Timespan's Python methods are consistent with our usage of
        # half-open ranges and PostgreSQL operators on ranges.
        def subquery(alias: str) -> sqlalchemy.sql.FromClause:
            return sqlalchemy.sql.select(
                [tbl.columns.id.label("id"), tbl.columns.timespan.label("timespan")]
            ).select_from(
                tbl
            ).alias(alias)
        sq1 = subquery("sq1")
        sq2 = subquery("sq2")
        query = sqlalchemy.sql.select([
            sq1.columns.id.label("n1"),
            sq2.columns.id.label("n2"),
            sq1.columns.timespan.overlaps(sq2.columns.timespan).label("overlaps"),
        ])

        # `columns` is deprecated since 1.4, but
        # `selected_columns` method did not exist in 1.3.
        if hasattr(query, "selected_columns"):
            columns = query.selected_columns
        else:
            columns = query.columns

        # SQLAlchemy issues a warning about cartesian product of two tables,
        # which we do intentionally. Disable that warning temporarily.
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", message=".*cartesian product",
                                    category=sqlalchemy.exc.SAWarning)
            dbResults = {
                (row[columns.n1], row[columns.n2]): row[columns.overlaps]
                for row in db.query(query)
            }

        pyResults = {
            (n1, n2): t1.overlaps(t2)
            for (n1, t1), (n2, t2) in itertools.product(enumerate(timespans), enumerate(timespans))
        }
        self.assertEqual(pyResults, dbResults)
@unittest.skipUnless(testing is not None, "testing.postgresql module not found")
class PostgresqlRegistryTests(RegistryTests):
    """Tests for `Registry` backed by a PostgreSQL database.

    Notes
    -----
    This is not a subclass of `unittest.TestCase` but to avoid repetition it
    defines methods that override `unittest.TestCase` methods. To make this
    work subclasses have to have this class first in the bases list.
    """

    @classmethod
    def setUpClass(cls):
        cls.root = makeTestTempDir(TESTDIR)
        try:
            cls.server = _startServer(cls.root)
        except Exception:
            # unittest does not call tearDownClass when setUpClass raises,
            # so remove the temporary directory here to avoid leaking it.
            removeTestTempDir(cls.root)
            raise

    @classmethod
    def tearDownClass(cls):
        try:
            # Clean up any lingering SQLAlchemy engines/connections
            # so they're closed before we shut down the server.
            gc.collect()
            cls.server.stop()
        finally:
            # Remove the temporary directory even if stopping the server
            # failed.
            removeTestTempDir(cls.root)

    @classmethod
    def getDataDir(cls) -> str:
        """Return the absolute path of the ``data/registry`` directory next to
        this test module.
        """
        return os.path.normpath(os.path.join(TESTDIR, "data", "registry"))

    def makeRegistry(self) -> Registry:
        """Create a `Registry` on the shared server in a new, randomly named
        namespace, so that each call starts from an empty schema.
        """
        namespace = f"namespace_{secrets.token_hex(8).lower()}"
        config = self.makeRegistryConfig()
        config["db"] = self.server.url()
        config["namespace"] = namespace
        return Registry.createFromConfig(config)
class PostgresqlRegistryNameKeyCollMgrTestCase(PostgresqlRegistryTests, unittest.TestCase):
    """PostgreSQL-backed `Registry` tests configured with
    ``NameKeyCollectionManager`` for collections and
    ``ByDimensionsDatasetRecordStorageManager`` for datasets.
    """

    collectionsManager = (
        "lsst.daf.butler.registry.collections.nameKey.NameKeyCollectionManager"
    )
    datasetsManager = (
        "lsst.daf.butler.registry.datasets.byDimensions.ByDimensionsDatasetRecordStorageManager"
    )
class PostgresqlRegistrySynthIntKeyCollMgrTestCase(PostgresqlRegistryTests, unittest.TestCase):
    """PostgreSQL-backed `Registry` tests configured with
    ``SynthIntKeyCollectionManager`` for collections and
    ``ByDimensionsDatasetRecordStorageManager`` for datasets.
    """

    collectionsManager = (
        "lsst.daf.butler.registry.collections.synthIntKey.SynthIntKeyCollectionManager"
    )
    datasetsManager = (
        "lsst.daf.butler.registry.datasets.byDimensions.ByDimensionsDatasetRecordStorageManager"
    )
class PostgresqlRegistryNameKeyCollMgrUUIDTestCase(PostgresqlRegistryTests, unittest.TestCase):
    """PostgreSQL-backed `Registry` tests configured with
    ``NameKeyCollectionManager`` for collections and
    ``ByDimensionsDatasetRecordStorageManagerUUID`` for datasets.
    """

    collectionsManager = (
        "lsst.daf.butler.registry.collections.nameKey.NameKeyCollectionManager"
    )
    datasetsManager = (
        "lsst.daf.butler.registry.datasets.byDimensions.ByDimensionsDatasetRecordStorageManagerUUID"
    )
class PostgresqlRegistrySynthIntKeyCollMgrUUIDTestCase(PostgresqlRegistryTests, unittest.TestCase):
    """PostgreSQL-backed `Registry` tests configured with
    ``SynthIntKeyCollectionManager`` for collections and
    ``ByDimensionsDatasetRecordStorageManagerUUID`` for datasets.
    """

    collectionsManager = (
        "lsst.daf.butler.registry.collections.synthIntKey.SynthIntKeyCollectionManager"
    )
    datasetsManager = (
        "lsst.daf.butler.registry.datasets.byDimensions.ByDimensionsDatasetRecordStorageManagerUUID"
    )
# Allow running this module directly (``python test_postgresql.py``).
if __name__ == "__main__":
    unittest.main()