Coverage for python/lsst/daf/butler/registry/tests/_database.py : 11%

# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

__all__ = ["DatabaseTests"]

from abc import ABC, abstractmethod
from collections import namedtuple
from typing import ContextManager

import sqlalchemy

from lsst.sphgeom import ConvexPolygon, UnitVector3d
from ..interfaces import (
    Database,
    ReadOnlyDatabaseError,
    DatabaseConflictError,
    SchemaAlreadyDefinedError
)
from ...core import ddl

StaticTablesTuple = namedtuple("StaticTablesTuple", ["a", "b", "c"])

STATIC_TABLE_SPECS = StaticTablesTuple(
    a=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("name", dtype=sqlalchemy.String, length=16, primaryKey=True),
            ddl.FieldSpec("region", dtype=ddl.Base64Region, nbytes=128),
        ]
    ),
    b=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("id", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True),
            ddl.FieldSpec("name", dtype=sqlalchemy.String, length=16, nullable=False),
            ddl.FieldSpec("value", dtype=sqlalchemy.SmallInteger, nullable=True),
        ],
        unique=[("name",)],
    ),
    c=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("id", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True),
            ddl.FieldSpec("origin", dtype=sqlalchemy.BigInteger, primaryKey=True),
            ddl.FieldSpec("b_id", dtype=sqlalchemy.BigInteger, nullable=True),
        ],
        foreignKeys=[
            ddl.ForeignKeySpec("b", source=("b_id",), target=("id",), onDelete="SET NULL"),
        ]
    ),
)

DYNAMIC_TABLE_SPEC = ddl.TableSpec(
    fields=[
        ddl.FieldSpec("c_id", dtype=sqlalchemy.BigInteger, primaryKey=True),
        ddl.FieldSpec("c_origin", dtype=sqlalchemy.BigInteger, primaryKey=True),
        ddl.FieldSpec("a_name", dtype=sqlalchemy.String, length=16, nullable=False),
    ],
    foreignKeys=[
        ddl.ForeignKeySpec("c", source=("c_id", "c_origin"), target=("id", "origin"), onDelete="CASCADE"),
        ddl.ForeignKeySpec("a", source=("a_name",), target=("name",), onDelete="CASCADE"),
    ]
)
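
# The tests below exercise these specs exclusively through the `Database` API.
# As a quick orientation, the typical call pattern looks roughly like the
# sketch below (illustrative only; it assumes a concrete ``db`` instance and
# is not itself part of the test suite):
#
#   with db.declareStaticTables(create=True) as context:
#       tables = context.addTableTuple(STATIC_TABLE_SPECS)
#   d = db.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
#   db.insert(tables.a, {"name": "a1", "region": None})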


class DatabaseTests(ABC):
    """Generic tests for the `Database` interface that can be subclassed to
    generate tests for concrete implementations.
    """

    @abstractmethod
    def makeEmptyDatabase(self, origin: int = 0) -> Database:
        """Return an empty `Database` with the given origin, or an
        automatically-generated one if ``origin`` is `None`.
        """
        raise NotImplementedError()

    @abstractmethod
    def asReadOnly(self, database: Database) -> ContextManager[Database]:
        """Return a context manager for a read-only connection into the given
        database.

        The original database should be considered unusable within the context
        but safe to use again afterwards (this allows the context manager to
        block write access by temporarily changing user permissions to really
        guarantee that write operations are not performed).
        """
        raise NotImplementedError()

    @abstractmethod
    def getNewConnection(self, database: Database, *, writeable: bool) -> Database:
        """Return a new `Database` instance that points to the same underlying
        storage as the given one.
        """
        raise NotImplementedError()
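
    # A sketch of how a concrete backend is expected to plug into this suite.
    # ``SqliteDatabase`` and its ``fromUri`` constructor are illustrative
    # assumptions, not requirements of this module; subclasses typically also
    # inherit from `unittest.TestCase`, which supplies the ``assert*`` methods
    # used below:
    #
    #   class SqliteDatabaseTestCase(unittest.TestCase, DatabaseTests):
    #
    #       def makeEmptyDatabase(self, origin: int = 0) -> Database:
    #           return SqliteDatabase.fromUri("sqlite://", origin=origin)
    #
    #       @contextlib.contextmanager
    #       def asReadOnly(self, database: Database) -> Iterator[Database]:
    #           yield self.getNewConnection(database, writeable=False)
    #
    #       def getNewConnection(self, database: Database, *, writeable: bool) -> Database:
    #           ...  # open another connection to the same underlying storage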

    def checkTable(self, spec: ddl.TableSpec, table: sqlalchemy.schema.Table):
        self.assertCountEqual(spec.fields.names, table.columns.keys())
        # Checking more than this currently seems fragile, as it might restrict
        # what Database implementations do; we don't care if the spec is
        # actually preserved in terms of types and constraints as long as we
        # can use the returned table as if it was.

    def checkStaticSchema(self, tables: StaticTablesTuple):
        self.checkTable(STATIC_TABLE_SPECS.a, tables.a)
        self.checkTable(STATIC_TABLE_SPECS.b, tables.b)
        self.checkTable(STATIC_TABLE_SPECS.c, tables.c)

    def testDeclareStaticTables(self):
        """Tests for `Database.declareStaticTables` and the methods it
        delegates to.
        """
        # Create the static schema in a new, empty database.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        self.checkStaticSchema(tables)
        # Check that we can load that schema even from a read-only connection.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                tables = context.addTableTuple(STATIC_TABLE_SPECS)
            self.checkStaticSchema(tables)

    def testDeclareStaticTablesTwice(self):
        """Tests for `Database.declareStaticTables` being called twice.
        """
        # Create the static schema in a new, empty database.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        self.checkStaticSchema(tables)
        # The second time it should raise.
        with self.assertRaises(SchemaAlreadyDefinedError):
            with newDatabase.declareStaticTables(create=True) as context:
                tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Check the schema; it should still contain all tables, and maybe some
        # extra.
        with newDatabase.declareStaticTables(create=False) as context:
            self.assertLessEqual(frozenset(STATIC_TABLE_SPECS._fields), context._tableNames)

    def testRepr(self):
        """Test that repr does not return a generic thing."""
        newDatabase = self.makeEmptyDatabase()
        rep = repr(newDatabase)
        # Check that stringification works and gives us something different.
        self.assertNotEqual(rep, str(newDatabase))
        self.assertNotIn("object at 0x", rep, "Check default repr was not used")
        self.assertIn("://", rep)

    def testDynamicTables(self):
        """Tests for `Database.ensureTableExists` and
        `Database.getExistingTable`.
        """
        # Need to start with the static schema.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            context.addTableTuple(STATIC_TABLE_SPECS)
        # Try to ensure the dynamic table exists in a read-only version of that
        # database, which should fail because we can't create it.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                context.addTableTuple(STATIC_TABLE_SPECS)
            with self.assertRaises(ReadOnlyDatabaseError):
                existingReadOnlyDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        # Just getting the dynamic table before it exists should return None.
        self.assertIsNone(newDatabase.getExistingTable("d", DYNAMIC_TABLE_SPEC))
        # Ensure the new table exists back in the original database, which
        # should create it.
        table = newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        self.checkTable(DYNAMIC_TABLE_SPEC, table)
        # Ensuring that it exists should just return the exact same table
        # instance again.
        self.assertIs(newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC), table)
        # Try again from the read-only database.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                context.addTableTuple(STATIC_TABLE_SPECS)
            # Just getting the dynamic table should now work...
            self.assertIsNotNone(existingReadOnlyDatabase.getExistingTable("d", DYNAMIC_TABLE_SPEC))
            # ...as should ensuring that it exists, since it now does.
            existingReadOnlyDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
            self.checkTable(DYNAMIC_TABLE_SPEC, table)
        # Trying to get the table with a different specification (at least
        # in terms of what columns are present) should raise.
        with self.assertRaises(DatabaseConflictError):
            newDatabase.ensureTableExists(
                "d",
                ddl.TableSpec(
                    fields=[ddl.FieldSpec("name", dtype=sqlalchemy.String, length=4, primaryKey=True)]
                )
            )
        # Calling ensureTableExists inside a transaction block is an error,
        # even if it would do nothing.
        with newDatabase.transaction():
            with self.assertRaises(AssertionError):
                newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)

    def testSchemaSeparation(self):
        """Test that creating two different `Database` instances allows us
        to create different tables with the same name in each.
        """
        db1 = self.makeEmptyDatabase(origin=1)
        with db1.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        self.checkStaticSchema(tables)

        db2 = self.makeEmptyDatabase(origin=2)
        # Make the DDL here intentionally different so we'll definitely
        # notice if db1 and db2 are pointing at the same schema.
        spec = ddl.TableSpec(fields=[ddl.FieldSpec("id", dtype=sqlalchemy.Integer, primaryKey=True)])
        with db2.declareStaticTables(create=True) as context:
            table = context.addTable("a", spec)
        self.checkTable(spec, table)

    def testInsertQueryDelete(self):
        """Test the `Database.insert`, `Database.query`, and `Database.delete`
        methods, as well as the `Base64Region` type and the ``onDelete``
        argument to `ddl.ForeignKeySpec`.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert a single, non-autoincrement row that contains a region and
        # query to get it back.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row = {"name": "a1", "region": region}
        db.insert(tables.a, row)
        self.assertEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row])
        # Insert multiple autoincrement rows but do not try to get the IDs
        # back immediately.
        db.insert(tables.b, {"name": "b1", "value": 10}, {"name": "b2", "value": 20})
        results = [dict(r) for r in db.query(tables.b.select().order_by("id")).fetchall()]
        self.assertEqual(len(results), 2)
        for row in results:
            self.assertIn(row["name"], ("b1", "b2"))
            self.assertIsInstance(row["id"], int)
        self.assertGreater(results[1]["id"], results[0]["id"])
        # Insert multiple autoincrement rows and get the IDs back from insert.
        rows = [{"name": "b3", "value": 30}, {"name": "b4", "value": 40}]
        ids = db.insert(tables.b, *rows, returnIds=True)
        results = [
            dict(r) for r in db.query(
                tables.b.select().where(tables.b.columns.id > results[1]["id"])
            ).fetchall()
        ]
        expected = [dict(row, id=id) for row, id in zip(rows, ids)]
        self.assertCountEqual(results, expected)
        self.assertTrue(all(result["id"] is not None for result in results))
        # Insert multiple rows into a table with an autoincrement+origin
        # primary key, then use the returned IDs to insert into a dynamic
        # table.
        rows = [{"origin": db.origin, "b_id": results[0]["id"]},
                {"origin": db.origin, "b_id": None}]
        ids = db.insert(tables.c, *rows, returnIds=True)
        results = [dict(r) for r in db.query(tables.c.select()).fetchall()]
        expected = [dict(row, id=id) for row, id in zip(rows, ids)]
        self.assertCountEqual(results, expected)
        self.assertTrue(all(result["id"] is not None for result in results))
        # Add the dynamic table.
        d = db.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        # Insert into it.
        rows = [{"c_origin": db.origin, "c_id": id, "a_name": "a1"} for id in ids]
        db.insert(d, *rows)
        results = [dict(r) for r in db.query(d.select()).fetchall()]
        self.assertCountEqual(rows, results)
        # Insert multiple rows into a table with an autoincrement+origin
        # primary key (this is especially tricky for SQLite, but good to test
        # for all DBs), but pass in a value for the autoincrement key.
        # For extra complexity, we re-use the autoincrement value with a
        # different value for origin.
        rows2 = [{"id": 700, "origin": db.origin, "b_id": None},
                 {"id": 700, "origin": 60, "b_id": None},
                 {"id": 1, "origin": 60, "b_id": None}]
        db.insert(tables.c, *rows2)
        results = [dict(r) for r in db.query(tables.c.select()).fetchall()]
        self.assertCountEqual(results, expected + rows2)
        self.assertTrue(all(result["id"] is not None for result in results))

        # Define a 'SELECT COUNT(*)' query for later use.
        count = sqlalchemy.sql.select([sqlalchemy.sql.func.count()])
        # Get the values we inserted into table b.
        bValues = [dict(r) for r in db.query(tables.b.select()).fetchall()]
        # Remove two rows from table b by ID.
        n = db.delete(tables.b, ["id"], {"id": bValues[0]["id"]}, {"id": bValues[1]["id"]})
        self.assertEqual(n, 2)
        # Remove the other two rows from table b by name.
        n = db.delete(tables.b, ["name"], {"name": bValues[2]["name"]}, {"name": bValues[3]["name"]})
        self.assertEqual(n, 2)
        # There should now be no rows in table b.
        self.assertEqual(
            db.query(count.select_from(tables.b)).scalar(),
            0
        )
        # All b_id values in table c should now be NULL, because there's an
        # onDelete='SET NULL' foreign key.
        self.assertEqual(
            db.query(count.select_from(tables.c).where(tables.c.columns.b_id != None)).scalar(),  # noqa:E711
            0
        )
        # Remove all rows in table a (there's only one); this should remove all
        # rows in d due to onDelete='CASCADE'.
        n = db.delete(tables.a, [])
        self.assertEqual(n, 1)
        self.assertEqual(db.query(count.select_from(tables.a)).scalar(), 0)
        self.assertEqual(db.query(count.select_from(d)).scalar(), 0)

    def testUpdate(self):
        """Tests for `Database.update`.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert two rows into table a, both without regions.
        db.insert(tables.a, {"name": "a1"}, {"name": "a2"})
        # Update one of the rows with a region.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        n = db.update(tables.a, {"name": "k"}, {"k": "a2", "region": region})
        self.assertEqual(n, 1)
        sql = sqlalchemy.sql.select([tables.a.columns.name, tables.a.columns.region]).select_from(tables.a)
        self.assertCountEqual(
            [dict(r) for r in db.query(sql).fetchall()],
            [{"name": "a1", "region": None}, {"name": "a2", "region": region}]
        )

    def testSync(self):
        """Tests for `Database.sync`.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert a row with sync, because it doesn't exist yet.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10}, returning=["id"])
        self.assertTrue(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [dict(r) for r in db.query(tables.b.select()).fetchall()])
        # Repeat that operation, which should do nothing but return the
        # requested values.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10}, returning=["id"])
        self.assertFalse(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [dict(r) for r in db.query(tables.b.select()).fetchall()])
        # Repeat the operation without the 'extra' arg, which should also just
        # return the existing row.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, returning=["id"])
        self.assertFalse(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [dict(r) for r in db.query(tables.b.select()).fetchall()])
        # Repeat the operation with a different value in 'extra'. That still
        # shouldn't be an error, because 'extra' is only used if we really do
        # insert. Also drop the 'returning' argument.
        _, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 20})
        self.assertFalse(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [dict(r) for r in db.query(tables.b.select()).fetchall()])
        # Repeat the operation with the correct value in 'compared' instead of
        # 'extra'.
        _, inserted = db.sync(tables.b, keys={"name": "b1"}, compared={"value": 10})
        self.assertFalse(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [dict(r) for r in db.query(tables.b.select()).fetchall()])
        # Repeat the operation with an incorrect value in 'compared'; this
        # should raise.
        with self.assertRaises(DatabaseConflictError):
            db.sync(tables.b, keys={"name": "b1"}, compared={"value": 20})
        # Try to sync inside a transaction. That's always an error, regardless
        # of whether there would be an insertion or not.
        with self.assertRaises(AssertionError):
            with db.transaction():
                db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10})
        with self.assertRaises(AssertionError):
            with db.transaction():
                db.sync(tables.b, keys={"name": "b2"}, extra={"value": 20})
        # Try to sync in a read-only database. This should work if and only
        # if the matching row already exists.
        with self.asReadOnly(db) as rodb:
            with rodb.declareStaticTables(create=False) as context:
                tables = context.addTableTuple(STATIC_TABLE_SPECS)
            _, inserted = rodb.sync(tables.b, keys={"name": "b1"})
            self.assertFalse(inserted)
            self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                             [dict(r) for r in rodb.query(tables.b.select()).fetchall()])
            with self.assertRaises(ReadOnlyDatabaseError):
                rodb.sync(tables.b, keys={"name": "b2"}, extra={"value": 20})

    def testReplace(self):
        """Tests for `Database.replace`.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Use 'replace' to insert a single row that contains a region and
        # query to get it back.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row1 = {"name": "a1", "region": region}
        db.replace(tables.a, row1)
        self.assertEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1])
        # Insert another row without a region.
        row2 = {"name": "a2", "region": None}
        db.replace(tables.a, row2)
        self.assertCountEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1, row2])
        # Use replace to re-insert both of those rows again, which should do
        # nothing.
        db.replace(tables.a, row1, row2)
        self.assertCountEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1, row2])
        # Replace row1 with a row with no region, while reinserting row2.
        row1a = {"name": "a1", "region": None}
        db.replace(tables.a, row1a, row2)
        self.assertCountEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1a, row2])
        # Replace both rows, returning row1 to its original state, while
        # adding a new one. Pass them in in a different order.
        row2a = {"name": "a2", "region": region}
        row3 = {"name": "a3", "region": None}
        db.replace(tables.a, row3, row2a, row1)
        self.assertCountEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1, row2a, row3])