Coverage for tests/test_postgresql.py : 45%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
import gc
import itertools
import os
import secrets
import unittest
from contextlib import contextmanager
from typing import Iterator
29import astropy.time
30try:
31 # It's possible but silly to have testing.postgresql installed without
32 # having the postgresql server installed (because then nothing in
33 # testing.postgresql would work), so we use the presence of that module
34 # to test whether we can expect the server to be available.
35 import testing.postgresql
36except ImportError:
37 testing = None
39import sqlalchemy
41from lsst.daf.butler import ddl, Timespan
42from lsst.daf.butler.registry import Registry
43from lsst.daf.butler.registry.databases.postgresql import PostgresqlDatabase, _RangeTimespanType
44from lsst.daf.butler.registry.tests import DatabaseTests, RegistryTests
45from lsst.daf.butler.tests.utils import makeTestTempDir, removeTestTempDir
# Absolute path of the directory holding this test module; it is used as the
# parent directory for the per-class temporary server directories.
TESTDIR = os.path.dirname(os.path.abspath(__file__))
def _startServer(root):
    """Start a temporary PostgreSQL server and prepare it for the tests.

    The server itself (including its default database) is created by
    ``testing.postgresql``; this function then installs the ``btree_gist``
    extension in that database.  NOTE(review): the extension is presumably
    required by the registry's range/timespan constraints — confirm.

    Parameters
    ----------
    root : `str`
        Directory in which the server keeps its data files.

    Returns
    -------
    server : `testing.postgresql.Postgresql`
        The running server; callers are responsible for calling ``stop()``.
    """
    server = testing.postgresql.Postgresql(base_dir=root)
    try:
        engine = sqlalchemy.engine.create_engine(server.url())
        try:
            # engine.begin() commits on success; text() avoids relying on the
            # deprecated raw-string Engine.execute API.
            with engine.begin() as connection:
                connection.execute(sqlalchemy.text("CREATE EXTENSION btree_gist;"))
        finally:
            # Close pooled connections now instead of leaving them for the
            # garbage collector to clean up before the server is stopped.
            engine.dispose()
    except BaseException:
        # Callers start us from setUpClass; if we raise, tearDownClass never
        # runs, so stop the server here rather than orphaning the process.
        server.stop()
        raise
    return server
@unittest.skipUnless(testing is not None, "testing.postgresql module not found")
class PostgresqlDatabaseTestCase(unittest.TestCase, DatabaseTests):
    """Tests for `PostgresqlDatabase` against a temporary PostgreSQL server.

    One server is started per test class (in a temporary directory under
    ``TESTDIR``) and shared by all of its tests; every call to
    `makeEmptyDatabase` uses a fresh, randomly named namespace on it.
    """

    @classmethod
    def setUpClass(cls):
        cls.root = makeTestTempDir(TESTDIR)
        try:
            cls.server = _startServer(cls.root)
        except BaseException:
            # unittest does not call tearDownClass when setUpClass raises, so
            # remove the temporary directory here instead of leaking it.
            removeTestTempDir(cls.root)
            raise

    @classmethod
    def tearDownClass(cls):
        try:
            # Clean up any lingering SQLAlchemy engines/connections
            # so they're closed before we shut down the server.
            gc.collect()
            cls.server.stop()
        finally:
            # Remove the temporary directory even if stopping the server fails.
            removeTestTempDir(cls.root)

    def makeEmptyDatabase(self, origin: int = 0) -> PostgresqlDatabase:
        """Return a `PostgresqlDatabase` in a new, randomly named namespace.

        Parameters
        ----------
        origin : `int`, optional
            Origin value passed through to `PostgresqlDatabase.fromUri`.
        """
        # secrets.token_hex already returns lowercase hexadecimal digits.
        namespace = f"namespace_{secrets.token_hex(8)}"
        return PostgresqlDatabase.fromUri(origin=origin, uri=self.server.url(), namespace=namespace)

    def getNewConnection(self, database: PostgresqlDatabase, *, writeable: bool) -> PostgresqlDatabase:
        """Return a new `PostgresqlDatabase` for the same server, origin and
        namespace as ``database``, with the requested ``writeable`` flag.
        """
        return PostgresqlDatabase.fromUri(origin=database.origin, uri=self.server.url(),
                                          namespace=database.namespace, writeable=writeable)

    @contextmanager
    def asReadOnly(self, database: PostgresqlDatabase) -> Iterator[PostgresqlDatabase]:
        """Context manager yielding a new read-only connection to the
        namespace of ``database``.
        """
        yield self.getNewConnection(database, writeable=False)

    def testNameShrinking(self):
        """Test that too-long names for database entities other than tables
        and columns (which we preserve, and just expect to fit) are shrunk.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            # Table and field names are each below the 63-char limit even when
            # accounting for the prefix, but their combination (which will
            # appear in sequences and constraints) is not.
            tableName = "a_table_with_a_very_very_long_42_char_name"
            fieldName1 = "a_column_with_a_very_very_long_43_char_name"
            fieldName2 = "another_column_with_a_very_very_long_49_char_name"
            context.addTable(
                tableName,
                ddl.TableSpec(
                    fields=[
                        ddl.FieldSpec(
                            fieldName1,
                            dtype=sqlalchemy.BigInteger,
                            autoincrement=True,
                            primaryKey=True
                        ),
                        ddl.FieldSpec(
                            fieldName2,
                            dtype=sqlalchemy.String,
                            length=16,
                            nullable=False,
                        ),
                    ],
                    unique={(fieldName2,)},
                )
            )
        # Add another table, this time dynamically, with a foreign key to the
        # first table.
        db.ensureTableExists(
            tableName + "_b",
            ddl.TableSpec(
                fields=[
                    ddl.FieldSpec(
                        fieldName1 + "_b",
                        dtype=sqlalchemy.BigInteger,
                        autoincrement=True,
                        primaryKey=True
                    ),
                    ddl.FieldSpec(
                        fieldName2 + "_b",
                        dtype=sqlalchemy.String,
                        length=16,
                        nullable=False,
                    ),
                ],
                foreignKeys=[
                    ddl.ForeignKeySpec(tableName, source=(fieldName2 + "_b",), target=(fieldName2,)),
                ]
            )
        )

    def test_RangeTimespanType(self):
        """Test that `Timespan` values round-trip through `_RangeTimespanType`
        and that PostgreSQL's range overlap test agrees with
        `Timespan.overlaps`.
        """
        start = astropy.time.Time('2020-01-01T00:00:00', format="isot", scale="tai")
        offset = astropy.time.TimeDelta(60, format="sec")
        timestamps = [start + offset*n for n in range(3)]
        # Unbounded, half-bounded (either side) and fully bounded timespans.
        timespans = [Timespan(begin=None, end=None)]
        timespans.extend(Timespan(begin=None, end=t) for t in timestamps)
        timespans.extend(Timespan(begin=t, end=None) for t in timestamps)
        timespans.extend(Timespan(begin=a, end=b) for a, b in itertools.combinations(timestamps, 2))
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tbl = context.addTable(
                "tbl",
                ddl.TableSpec(
                    fields=[
                        ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
                        ddl.FieldSpec(name="timespan", dtype=_RangeTimespanType),
                    ],
                )
            )
        rows = [{"id": n, "timespan": t} for n, t in enumerate(timespans)]
        db.insert(tbl, *rows)

        # Test basic round-trip through database.
        self.assertEqual(
            rows,
            [dict(row) for row in db.query(tbl.select().order_by(tbl.columns.id)).fetchall()]
        )

        # Test that Timespan's Python methods are consistent with our usage of
        # half-open ranges and PostgreSQL operators on ranges.
        def subquery(alias: str) -> sqlalchemy.sql.FromClause:
            return sqlalchemy.sql.select(
                [tbl.columns.id.label("id"), tbl.columns.timespan.label("timespan")]
            ).select_from(
                tbl
            ).alias(alias)
        sq1 = subquery("sq1")
        sq2 = subquery("sq2")
        query = sqlalchemy.sql.select([
            sq1.columns.id.label("n1"),
            sq2.columns.id.label("n2"),
            sq1.columns.timespan.overlaps(sq2.columns.timespan).label("overlaps"),
        ]).select_from(
            # Every pair is wanted (mirroring itertools.product below); an
            # explicit "JOIN ... ON true" states that and avoids SQLAlchemy's
            # implicit cartesian-product lint warning.
            sq1.join(sq2, sqlalchemy.true())
        )

        dbResults = {
            (row[query.columns.n1], row[query.columns.n2]): row[query.columns.overlaps]
            for row in db.query(query)
        }
        pyResults = {
            (n1, n2): t1.overlaps(t2)
            for (n1, t1), (n2, t2) in itertools.product(enumerate(timespans), enumerate(timespans))
        }
        self.assertEqual(pyResults, dbResults)
@unittest.skipUnless(testing is not None, "testing.postgresql module not found")
class PostgresqlRegistryTests(RegistryTests):
    """Tests for `Registry` backed by a PostgreSQL database.

    Notes
    -----
    This is not a subclass of `unittest.TestCase` but to avoid repetition it
    defines methods that override `unittest.TestCase` methods. To make this
    work subclasses have to have this class first in the bases list.
    """

    @classmethod
    def setUpClass(cls):
        cls.root = makeTestTempDir(TESTDIR)
        try:
            cls.server = _startServer(cls.root)
        except BaseException:
            # unittest does not call tearDownClass when setUpClass raises, so
            # remove the temporary directory here instead of leaking it.
            removeTestTempDir(cls.root)
            raise

    @classmethod
    def tearDownClass(cls):
        try:
            # Clean up any lingering SQLAlchemy engines/connections
            # so they're closed before we shut down the server.
            gc.collect()
            cls.server.stop()
        finally:
            # Remove the temporary directory even if stopping the server fails.
            removeTestTempDir(cls.root)

    @classmethod
    def getDataDir(cls) -> str:
        """Return the absolute path of the registry test-data directory."""
        # TESTDIR is already absolute and normalized, so the result does not
        # depend on the current working directory.
        return os.path.join(TESTDIR, "data", "registry")

    def makeRegistry(self) -> Registry:
        """Create a new `Registry` in a fresh, randomly named namespace on
        the shared server, using this class's registry configuration.
        """
        # secrets.token_hex already returns lowercase hexadecimal digits.
        namespace = f"namespace_{secrets.token_hex(8)}"
        config = self.makeRegistryConfig()
        config["db"] = self.server.url()
        config["namespace"] = namespace
        return Registry.createFromConfig(config)
class PostgresqlRegistryNameKeyCollMgrTestCase(PostgresqlRegistryTests, unittest.TestCase):
    """Tests for `Registry` backed by a PostgreSQL database.

    Manager combination exercised here: `NameKeyCollectionManager` for
    collections together with `ByDimensionsDatasetRecordStorageManager` for
    datasets.
    """

    collectionsManager = (
        "lsst.daf.butler.registry.collections.nameKey.NameKeyCollectionManager"
    )
    datasetsManager = (
        "lsst.daf.butler.registry.datasets.byDimensions.ByDimensionsDatasetRecordStorageManager"
    )
class PostgresqlRegistrySynthIntKeyCollMgrTestCase(PostgresqlRegistryTests, unittest.TestCase):
    """Tests for `Registry` backed by a PostgreSQL database.

    Manager combination exercised here: `SynthIntKeyCollectionManager` for
    collections together with `ByDimensionsDatasetRecordStorageManager` for
    datasets.
    """

    collectionsManager = (
        "lsst.daf.butler.registry.collections.synthIntKey.SynthIntKeyCollectionManager"
    )
    datasetsManager = (
        "lsst.daf.butler.registry.datasets.byDimensions.ByDimensionsDatasetRecordStorageManager"
    )
class PostgresqlRegistryNameKeyCollMgrUUIDTestCase(PostgresqlRegistryTests, unittest.TestCase):
    """Tests for `Registry` backed by a PostgreSQL database.

    Manager combination exercised here: `NameKeyCollectionManager` for
    collections together with the UUID-based
    `ByDimensionsDatasetRecordStorageManagerUUID` for datasets.
    """

    collectionsManager = (
        "lsst.daf.butler.registry.collections.nameKey.NameKeyCollectionManager"
    )
    datasetsManager = (
        "lsst.daf.butler.registry.datasets.byDimensions.ByDimensionsDatasetRecordStorageManagerUUID"
    )
class PostgresqlRegistrySynthIntKeyCollMgrUUIDTestCase(PostgresqlRegistryTests, unittest.TestCase):
    """Tests for `Registry` backed by a PostgreSQL database.

    Manager combination exercised here: `SynthIntKeyCollectionManager` for
    collections together with the UUID-based
    `ByDimensionsDatasetRecordStorageManagerUUID` for datasets.
    """

    collectionsManager = (
        "lsst.daf.butler.registry.collections.synthIntKey.SynthIntKeyCollectionManager"
    )
    datasetsManager = (
        "lsst.daf.butler.registry.datasets.byDimensions.ByDimensionsDatasetRecordStorageManagerUUID"
    )
if __name__ == "__main__":
    # Run this module's tests with unittest's command-line runner when the
    # file is executed directly as a script.
    unittest.main()