Coverage for python/lsst/daf/butler/registry/tests/_database.py : 11%

# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

__all__ = ["DatabaseTests"]

from abc import ABC, abstractmethod
from collections import namedtuple
from typing import ContextManager

import sqlalchemy

from lsst.sphgeom import ConvexPolygon, UnitVector3d
from ..interfaces import (
    Database,
    ReadOnlyDatabaseError,
    DatabaseConflictError,
)
from ...core import ddl

StaticTablesTuple = namedtuple("StaticTablesTuple", ["a", "b", "c"])

STATIC_TABLE_SPECS = StaticTablesTuple(
    a=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("name", dtype=sqlalchemy.String, length=16, primaryKey=True),
            ddl.FieldSpec("region", dtype=ddl.Base64Region, nbytes=128),
        ]
    ),
    b=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("id", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True),
            ddl.FieldSpec("name", dtype=sqlalchemy.String, length=16, nullable=False),
            ddl.FieldSpec("value", dtype=sqlalchemy.SmallInteger, nullable=True),
        ],
        unique=[("name",)],
    ),
    c=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("id", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True),
            ddl.FieldSpec("origin", dtype=sqlalchemy.BigInteger, primaryKey=True),
            ddl.FieldSpec("b_id", dtype=sqlalchemy.BigInteger, nullable=True),
        ],
        foreignKeys=[
            ddl.ForeignKeySpec("b", source=("b_id",), target=("id",), onDelete="SET NULL"),
        ]
    ),
)

DYNAMIC_TABLE_SPEC = ddl.TableSpec(
    fields=[
        ddl.FieldSpec("c_id", dtype=sqlalchemy.BigInteger, primaryKey=True),
        ddl.FieldSpec("c_origin", dtype=sqlalchemy.BigInteger, primaryKey=True),
        ddl.FieldSpec("a_name", dtype=sqlalchemy.String, length=16, nullable=False),
    ],
    foreignKeys=[
        ddl.ForeignKeySpec("c", source=("c_id", "c_origin"), target=("id", "origin"), onDelete="CASCADE"),
        ddl.ForeignKeySpec("a", source=("a_name",), target=("name",), onDelete="CASCADE"),
    ]
)
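
# For orientation, the specs above map very roughly onto DDL like the
# following. This is an illustrative sketch only: the concrete column types,
# constraint names, and the storage used for the Base64Region field are
# chosen by each `Database` implementation and differ between backends.
#
#   CREATE TABLE a (
#       name VARCHAR(16) PRIMARY KEY,
#       region ...          -- base64-encoded region, up to 128 bytes
#   );
#   CREATE TABLE b (
#       id BIGINT PRIMARY KEY,            -- autoincrement
#       name VARCHAR(16) NOT NULL UNIQUE,
#       value SMALLINT
#   );
#   CREATE TABLE c (
#       id BIGINT,                        -- autoincrement
#       origin BIGINT,
#       b_id BIGINT REFERENCES b (id) ON DELETE SET NULL,
#       PRIMARY KEY (id, origin)
#   );
#
# DYNAMIC_TABLE_SPEC ("d") adds a table whose composite key (c_id, c_origin)
# references c (id, origin) with ON DELETE CASCADE, and whose a_name column
# references a (name) with ON DELETE CASCADE.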


class DatabaseTests(ABC):
    """Generic tests for the `Database` interface that can be subclassed to
    generate tests for concrete implementations.
    """

    @abstractmethod
    def makeEmptyDatabase(self, origin: int = 0) -> Database:
        """Return an empty `Database` with the given origin, or an
        automatically-generated one if ``origin`` is `None`.
        """
        raise NotImplementedError()

    @abstractmethod
    def asReadOnly(self, database: Database) -> ContextManager[Database]:
        """Return a context manager for a read-only connection into the given
        database.

        The original database should be considered unusable within the context
        but safe to use again afterwards (this allows the context manager to
        block write access by temporarily changing user permissions to really
        guarantee that write operations are not performed).
        """
        raise NotImplementedError()

    @abstractmethod
    def getNewConnection(self, database: Database, *, writeable: bool) -> Database:
        """Return a new `Database` instance that points to the same underlying
        storage as the given one.
        """
        raise NotImplementedError()

    def checkTable(self, spec: ddl.TableSpec, table: sqlalchemy.schema.Table):
        self.assertCountEqual(spec.fields.names, table.columns.keys())
        # Checking more than this currently seems fragile, as it might restrict
        # what Database implementations do; we don't care if the spec is
        # actually preserved in terms of types and constraints as long as we
        # can use the returned table as if it was.

    def checkStaticSchema(self, tables: StaticTablesTuple):
        self.checkTable(STATIC_TABLE_SPECS.a, tables.a)
        self.checkTable(STATIC_TABLE_SPECS.b, tables.b)
        self.checkTable(STATIC_TABLE_SPECS.c, tables.c)

    def testDeclareStaticTables(self):
        """Tests for `Database.declareStaticTables` and the methods it
        delegates to.
        """
        # Create the static schema in a new, empty database.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        self.checkStaticSchema(tables)
        # Check that we can load that schema even from a read-only connection.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                tables = context.addTableTuple(STATIC_TABLE_SPECS)
            self.checkStaticSchema(tables)

    def testDynamicTables(self):
        """Tests for `Database.ensureTableExists` and
        `Database.getExistingTable`.
        """
        # Need to start with the static schema.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            context.addTableTuple(STATIC_TABLE_SPECS)
        # Try to ensure the dynamic table exists in a read-only version of
        # that database, which should fail because we can't create it.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                context.addTableTuple(STATIC_TABLE_SPECS)
            with self.assertRaises(ReadOnlyDatabaseError):
                existingReadOnlyDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        # Just getting the dynamic table before it exists should return None.
        self.assertIsNone(newDatabase.getExistingTable("d", DYNAMIC_TABLE_SPEC))
        # Ensure the new table exists back in the original database, which
        # should create it.
        table = newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        self.checkTable(DYNAMIC_TABLE_SPEC, table)
        # Ensuring that it exists should just return the exact same table
        # instance again.
        self.assertIs(newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC), table)
        # Try again from the read-only database.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                context.addTableTuple(STATIC_TABLE_SPECS)
            # Just getting the dynamic table should now work...
            self.assertIsNotNone(existingReadOnlyDatabase.getExistingTable("d", DYNAMIC_TABLE_SPEC))
            # ...as should ensuring that it exists, since it now does.
            existingReadOnlyDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
            self.checkTable(DYNAMIC_TABLE_SPEC, table)
        # Trying to get the table with a different specification (at least
        # in terms of what columns are present) should raise.
        with self.assertRaises(DatabaseConflictError):
            newDatabase.ensureTableExists(
                "d",
                ddl.TableSpec(
                    fields=[ddl.FieldSpec("name", dtype=sqlalchemy.String, length=4, primaryKey=True)]
                )
            )
        # Calling ensureTableExists inside a transaction block is an error,
        # even if it would do nothing.
        with newDatabase.transaction():
            with self.assertRaises(AssertionError):
                newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)

    def testSchemaSeparation(self):
        """Test that creating two different `Database` instances allows us
        to create different tables with the same name in each.
        """
        db1 = self.makeEmptyDatabase(origin=1)
        with db1.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        self.checkStaticSchema(tables)

        db2 = self.makeEmptyDatabase(origin=2)
        # Make the DDL here intentionally different so we'll definitely
        # notice if db1 and db2 are pointing at the same schema.
        spec = ddl.TableSpec(fields=[ddl.FieldSpec("id", dtype=sqlalchemy.Integer, primaryKey=True)])
        with db2.declareStaticTables(create=True) as context:
            table = context.addTable("a", spec)
        self.checkTable(spec, table)

    def testInsertQueryDelete(self):
        """Test the `Database.insert`, `Database.query`, and `Database.delete`
        methods, as well as the `Base64Region` type and the ``onDelete``
        argument to `ddl.ForeignKeySpec`.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert a single, non-autoincrement row that contains a region and
        # query to get it back.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row = {"name": "a1", "region": region}
        db.insert(tables.a, row)
        self.assertEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row])
        # Insert multiple autoincrement rows but do not try to get the IDs
        # back immediately.
        db.insert(tables.b, {"name": "b1", "value": 10}, {"name": "b2", "value": 20})
        results = [dict(r) for r in db.query(tables.b.select().order_by("id")).fetchall()]
        self.assertEqual(len(results), 2)
        for row in results:
            self.assertIn(row["name"], ("b1", "b2"))
            self.assertIsInstance(row["id"], int)
        self.assertGreater(results[1]["id"], results[0]["id"])
        # Insert multiple autoincrement rows and get the IDs back from insert.
        rows = [{"name": "b3", "value": 30}, {"name": "b4", "value": 40}]
        ids = db.insert(tables.b, *rows, returnIds=True)
        results = [
            dict(r) for r in db.query(
                tables.b.select().where(tables.b.columns.id > results[1]["id"])
            ).fetchall()
        ]
        expected = [dict(row, id=id) for row, id in zip(rows, ids)]
        self.assertCountEqual(results, expected)
        self.assertTrue(all(result["id"] is not None for result in results))
        # Insert multiple rows into a table with an autoincrement+origin
        # primary key, then use the returned IDs to insert into a dynamic
        # table.
        rows = [{"origin": db.origin, "b_id": results[0]["id"]},
                {"origin": db.origin, "b_id": None}]
        ids = db.insert(tables.c, *rows, returnIds=True)
        results = [dict(r) for r in db.query(tables.c.select()).fetchall()]
        expected = [dict(row, id=id) for row, id in zip(rows, ids)]
        self.assertCountEqual(results, expected)
        self.assertTrue(all(result["id"] is not None for result in results))
        # Add the dynamic table.
        d = db.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        # Insert into it.
        rows = [{"c_origin": db.origin, "c_id": id, "a_name": "a1"} for id in ids]
        db.insert(d, *rows)
        results = [dict(r) for r in db.query(d.select()).fetchall()]
        self.assertCountEqual(rows, results)
        # Insert multiple rows into a table with an autoincrement+origin
        # primary key (this is especially tricky for SQLite, but good to test
        # for all DBs), but pass in a value for the autoincrement key.
        # For extra complexity, we re-use the autoincrement value with a
        # different value for origin.
        rows2 = [{"id": 700, "origin": db.origin, "b_id": None},
                 {"id": 700, "origin": 60, "b_id": None},
                 {"id": 1, "origin": 60, "b_id": None}]
        db.insert(tables.c, *rows2)
        results = [dict(r) for r in db.query(tables.c.select()).fetchall()]
        self.assertCountEqual(results, expected + rows2)
        self.assertTrue(all(result["id"] is not None for result in results))

        # Define 'SELECT COUNT(*)' query for later use.
        count = sqlalchemy.sql.select([sqlalchemy.sql.func.count()])
        # Get the values we inserted into table b.
        bValues = [dict(r) for r in db.query(tables.b.select()).fetchall()]
        # Remove two rows from table b by ID.
        n = db.delete(tables.b, ["id"], {"id": bValues[0]["id"]}, {"id": bValues[1]["id"]})
        self.assertEqual(n, 2)
        # Remove the other two rows from table b by name.
        n = db.delete(tables.b, ["name"], {"name": bValues[2]["name"]}, {"name": bValues[3]["name"]})
        self.assertEqual(n, 2)
        # There should now be no rows in table b.
        self.assertEqual(
            db.query(count.select_from(tables.b)).scalar(),
            0
        )
        # All b_id values in table c should now be NULL, because there's an
        # onDelete='SET NULL' foreign key.
        self.assertEqual(
            db.query(count.select_from(tables.c).where(tables.c.columns.b_id != None)).scalar(),  # noqa:E711
            0
        )
        # Remove all rows in table a (there's only one); this should remove
        # all rows in d due to onDelete='CASCADE'.
        n = db.delete(tables.a, [])
        self.assertEqual(n, 1)
        self.assertEqual(db.query(count.select_from(tables.a)).scalar(), 0)
        self.assertEqual(db.query(count.select_from(d)).scalar(), 0)

    def testUpdate(self):
        """Tests for `Database.update`.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert two rows into table a, both without regions.
        db.insert(tables.a, {"name": "a1"}, {"name": "a2"})
        # Update one of the rows with a region.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        n = db.update(tables.a, {"name": "k"}, {"k": "a2", "region": region})
        self.assertEqual(n, 1)
        sql = sqlalchemy.sql.select([tables.a.columns.name, tables.a.columns.region]).select_from(tables.a)
        self.assertCountEqual(
            [dict(r) for r in db.query(sql).fetchall()],
            [{"name": "a1", "region": None}, {"name": "a2", "region": region}]
        )

    def testSync(self):
        """Tests for `Database.sync`.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert a row with sync, because it doesn't exist yet.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10}, returning=["id"])
        self.assertTrue(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [dict(r) for r in db.query(tables.b.select()).fetchall()])
        # Repeat that operation, which should do nothing but return the
        # requested values.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10}, returning=["id"])
        self.assertFalse(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [dict(r) for r in db.query(tables.b.select()).fetchall()])
        # Repeat the operation without the 'extra' arg, which should also just
        # return the existing row.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, returning=["id"])
        self.assertFalse(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [dict(r) for r in db.query(tables.b.select()).fetchall()])
        # Repeat the operation with a different value in 'extra'. That still
        # shouldn't be an error, because 'extra' is only used if we really do
        # insert. Also drop the 'returning' argument.
        _, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 20})
        self.assertFalse(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [dict(r) for r in db.query(tables.b.select()).fetchall()])
        # Repeat the operation with the correct value in 'compared' instead of
        # 'extra'.
        _, inserted = db.sync(tables.b, keys={"name": "b1"}, compared={"value": 10})
        self.assertFalse(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [dict(r) for r in db.query(tables.b.select()).fetchall()])
        # Repeat the operation with an incorrect value in 'compared'; this
        # should raise.
        with self.assertRaises(DatabaseConflictError):
            db.sync(tables.b, keys={"name": "b1"}, compared={"value": 20})
        # Try to sync inside a transaction. That's always an error, regardless
        # of whether there would be an insertion or not.
        with self.assertRaises(AssertionError):
            with db.transaction():
                db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10})
        with self.assertRaises(AssertionError):
            with db.transaction():
                db.sync(tables.b, keys={"name": "b2"}, extra={"value": 20})
        # Try to sync in a read-only database. This should work if and only
        # if the matching row already exists.
        with self.asReadOnly(db) as rodb:
            with rodb.declareStaticTables(create=False) as context:
                tables = context.addTableTuple(STATIC_TABLE_SPECS)
            _, inserted = rodb.sync(tables.b, keys={"name": "b1"})
            self.assertFalse(inserted)
            self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                             [dict(r) for r in rodb.query(tables.b.select()).fetchall()])
            with self.assertRaises(ReadOnlyDatabaseError):
                rodb.sync(tables.b, keys={"name": "b2"}, extra={"value": 20})

    def testReplace(self):
        """Tests for `Database.replace`.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Use 'replace' to insert a single row that contains a region and
        # query to get it back.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row1 = {"name": "a1", "region": region}
        db.replace(tables.a, row1)
        self.assertEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1])
        # Insert another row without a region.
        row2 = {"name": "a2", "region": None}
        db.replace(tables.a, row2)
        self.assertCountEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1, row2])
        # Use replace to re-insert both of those rows again, which should do
        # nothing.
        db.replace(tables.a, row1, row2)
        self.assertCountEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1, row2])
        # Replace row1 with a row with no region, while reinserting row2.
        row1a = {"name": "a1", "region": None}
        db.replace(tables.a, row1a, row2)
        self.assertCountEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1a, row2])
        # Replace both rows, returning row1 to its original state, while
        # adding a new one. Pass them in a different order.
        row2a = {"name": "a2", "region": region}
        row3 = {"name": "a3", "region": None}
        db.replace(tables.a, row3, row2a, row1)
        self.assertCountEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1, row2a, row3])
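

# How a concrete backend is expected to hook into these tests: subclass
# `DatabaseTests` together with `unittest.TestCase` and implement the three
# abstract factory methods declared above. The sketch below is illustrative
# only -- `MySqliteDatabase`, its constructor arguments, and the `uri`/`origin`
# attributes read from `database` are hypothetical placeholders for a real
# `Database` subclass, not part of the daf_butler API.
#
#     import unittest
#     from contextlib import contextmanager
#     from typing import Iterator
#
#     class MySqliteDatabaseTestCase(DatabaseTests, unittest.TestCase):
#
#         def makeEmptyDatabase(self, origin: int = 0) -> Database:
#             # Hypothetical: construct a fresh, writeable, empty database.
#             return MySqliteDatabase(uri="sqlite://", origin=origin, writeable=True)
#
#         @contextmanager
#         def asReadOnly(self, database: Database) -> Iterator[Database]:
#             # Hypothetical: yield a read-only connection to the same storage;
#             # the original `database` must not be used inside this context.
#             yield MySqliteDatabase(uri=database.uri, origin=database.origin,
#                                    writeable=False)
#
#         def getNewConnection(self, database: Database, *, writeable: bool) -> Database:
#             # Hypothetical: a second connection to the same underlying storage.
#             return MySqliteDatabase(uri=database.uri, origin=database.origin,
#                                     writeable=writeable)
#
#     if __name__ == "__main__":
#         unittest.main()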