Coverage for python/lsst/daf/butler/registry/tests/_database.py : 7%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ["DatabaseTests"]
25from abc import ABC, abstractmethod
26import asyncio
27from collections import namedtuple
28from concurrent.futures import ThreadPoolExecutor
29import itertools
30from typing import ContextManager, Iterable, Set, Tuple
31import warnings
33import astropy.time
34import sqlalchemy
36from lsst.sphgeom import ConvexPolygon, UnitVector3d
37from ..interfaces import (
38 Database,
39 ReadOnlyDatabaseError,
40 DatabaseConflictError,
41 SchemaAlreadyDefinedError
42)
43from ...core import ddl, Timespan
# Named tuple holding the three static test tables; the field names ("a",
# "b", "c") double as the table names used throughout the tests.
StaticTablesTuple = namedtuple("StaticTablesTuple", ["a", "b", "c"])

# Static test schema:
#  - "a": string primary key plus a nullable Base64-encoded region column.
#  - "b": autoincrement integer primary key, non-null "name" with a UNIQUE
#    constraint, and a nullable small-integer "value".
#  - "c": compound (autoincrement "id", "origin") primary key and a nullable
#    foreign key into "b" whose rows are set NULL when the target is deleted.
STATIC_TABLE_SPECS = StaticTablesTuple(
    a=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("name", dtype=sqlalchemy.String, length=16, primaryKey=True),
            ddl.FieldSpec("region", dtype=ddl.Base64Region, nbytes=128),
        ]
    ),
    b=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("id", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True),
            ddl.FieldSpec("name", dtype=sqlalchemy.String, length=16, nullable=False),
            ddl.FieldSpec("value", dtype=sqlalchemy.SmallInteger, nullable=True),
        ],
        unique=[("name",)],
    ),
    c=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("id", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True),
            ddl.FieldSpec("origin", dtype=sqlalchemy.BigInteger, primaryKey=True),
            ddl.FieldSpec("b_id", dtype=sqlalchemy.BigInteger, nullable=True),
        ],
        foreignKeys=[
            ddl.ForeignKeySpec("b", source=("b_id",), target=("id",), onDelete="SET NULL"),
        ]
    ),
)

# Spec for a table created at runtime via `Database.ensureTableExists`;
# it references both "c" (via its compound key) and "a", with deletes in
# either parent cascading into this table.
DYNAMIC_TABLE_SPEC = ddl.TableSpec(
    fields=[
        ddl.FieldSpec("c_id", dtype=sqlalchemy.BigInteger, primaryKey=True),
        ddl.FieldSpec("c_origin", dtype=sqlalchemy.BigInteger, primaryKey=True),
        ddl.FieldSpec("a_name", dtype=sqlalchemy.String, length=16, nullable=False),
    ],
    foreignKeys=[
        ddl.ForeignKeySpec("c", source=("c_id", "c_origin"), target=("id", "origin"), onDelete="CASCADE"),
        ddl.ForeignKeySpec("a", source=("a_name",), target=("name",), onDelete="CASCADE"),
    ]
)

# Spec used for temporary tables in the tests.  NOTE(review): it declares no
# foreign keys — presumably so temporary tables stay independent of the main
# schema (and usable on read-only connections); confirm against the
# makeTemporaryTable implementation.
TEMPORARY_TABLE_SPEC = ddl.TableSpec(
    fields=[
        ddl.FieldSpec("a_name", dtype=sqlalchemy.String, length=16, primaryKey=True),
        ddl.FieldSpec("b_id", dtype=sqlalchemy.BigInteger, primaryKey=True),
    ],
)
94class DatabaseTests(ABC):
95 """Generic tests for the `Database` interface that can be subclassed to
96 generate tests for concrete implementations.
97 """
    @abstractmethod
    def makeEmptyDatabase(self, origin: int = 0) -> Database:
        """Return an empty `Database` with the given origin, or an
        automatically-generated one if ``origin`` is `None`.

        NOTE(review): the signature declares ``origin: int = 0``, so the
        `None` case described above may be stale — confirm against concrete
        implementations before relying on it.
        """
        raise NotImplementedError()
    @abstractmethod
    def asReadOnly(self, database: Database) -> ContextManager[Database]:
        """Return a context manager for a read-only connection into the given
        database.

        The original database should be considered unusable within the context
        but safe to use again afterwards (this allows the context manager to
        block write access by temporarily changing user permissions to really
        guarantee that write operations are not performed).

        Parameters
        ----------
        database : `Database`
            The writeable database to wrap; yielded back as a read-only view.
        """
        raise NotImplementedError()
    @abstractmethod
    def getNewConnection(self, database: Database, *, writeable: bool) -> Database:
        """Return a new `Database` instance that points to the same underlying
        storage as the given one.

        Parameters
        ----------
        database : `Database`
            Existing database whose storage the new connection shares.
        writeable : `bool`
            Whether the new connection permits write operations.
        """
        raise NotImplementedError()
125 def checkTable(self, spec: ddl.TableSpec, table: sqlalchemy.schema.Table):
126 self.assertCountEqual(spec.fields.names, table.columns.keys())
127 # Checking more than this currently seems fragile, as it might restrict
128 # what Database implementations do; we don't care if the spec is
129 # actually preserved in terms of types and constraints as long as we
130 # can use the returned table as if it was.
132 def checkStaticSchema(self, tables: StaticTablesTuple):
133 self.checkTable(STATIC_TABLE_SPECS.a, tables.a)
134 self.checkTable(STATIC_TABLE_SPECS.b, tables.b)
135 self.checkTable(STATIC_TABLE_SPECS.c, tables.c)
137 def testDeclareStaticTables(self):
138 """Tests for `Database.declareStaticSchema` and the methods it
139 delegates to.
140 """
141 # Create the static schema in a new, empty database.
142 newDatabase = self.makeEmptyDatabase()
143 with newDatabase.declareStaticTables(create=True) as context:
144 tables = context.addTableTuple(STATIC_TABLE_SPECS)
145 self.checkStaticSchema(tables)
146 # Check that we can load that schema even from a read-only connection.
147 with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
148 with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
149 tables = context.addTableTuple(STATIC_TABLE_SPECS)
150 self.checkStaticSchema(tables)
152 def testDeclareStaticTablesTwice(self):
153 """Tests for `Database.declareStaticSchema` being called twice.
154 """
155 # Create the static schema in a new, empty database.
156 newDatabase = self.makeEmptyDatabase()
157 with newDatabase.declareStaticTables(create=True) as context:
158 tables = context.addTableTuple(STATIC_TABLE_SPECS)
159 self.checkStaticSchema(tables)
160 # Second time it should raise
161 with self.assertRaises(SchemaAlreadyDefinedError):
162 with newDatabase.declareStaticTables(create=True) as context:
163 tables = context.addTableTuple(STATIC_TABLE_SPECS)
164 # Check schema, it should still contain all tables, and maybe some
165 # extra.
166 with newDatabase.declareStaticTables(create=False) as context:
167 self.assertLessEqual(frozenset(STATIC_TABLE_SPECS._fields), context._tableNames)
169 def testRepr(self):
170 """Test that repr does not return a generic thing."""
171 newDatabase = self.makeEmptyDatabase()
172 rep = repr(newDatabase)
173 # Check that stringification works and gives us something different
174 self.assertNotEqual(rep, str(newDatabase))
175 self.assertNotIn("object at 0x", rep, "Check default repr was not used")
176 self.assertIn("://", rep)
    def testDynamicTables(self):
        """Tests for `Database.ensureTableExists` and
        `Database.getExistingTable`.
        """
        # Need to start with the static schema.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            context.addTableTuple(STATIC_TABLE_SPECS)
        # Try to ensure the dynamic table exists in a read-only version of
        # that database, which should fail because we can't create it.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                context.addTableTuple(STATIC_TABLE_SPECS)
            with self.assertRaises(ReadOnlyDatabaseError):
                existingReadOnlyDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        # Just getting the dynamic table before it exists should return None.
        self.assertIsNone(newDatabase.getExistingTable("d", DYNAMIC_TABLE_SPEC))
        # Ensure the new table exists back in the original database, which
        # should create it.
        table = newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        self.checkTable(DYNAMIC_TABLE_SPEC, table)
        # Ensuring that it exists should just return the exact same table
        # instance again.
        self.assertIs(newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC), table)
        # Try again from the read-only database.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                context.addTableTuple(STATIC_TABLE_SPECS)
            # Just getting the dynamic table should now work...
            self.assertIsNotNone(existingReadOnlyDatabase.getExistingTable("d", DYNAMIC_TABLE_SPEC))
            # ...as should ensuring that it exists, since it now does.
            existingReadOnlyDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
            self.checkTable(DYNAMIC_TABLE_SPEC, table)
        # Trying to get the table with a different specification (at least
        # in terms of what columns are present) should raise.
        with self.assertRaises(DatabaseConflictError):
            newDatabase.ensureTableExists(
                "d",
                ddl.TableSpec(
                    fields=[ddl.FieldSpec("name", dtype=sqlalchemy.String, length=4, primaryKey=True)]
                )
            )
        # Calling ensureTableExists inside a transaction block is an error,
        # even if it would do nothing.
        with newDatabase.transaction():
            with self.assertRaises(AssertionError):
                newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
    def testTemporaryTables(self):
        """Tests for `Database.makeTemporaryTable`,
        `Database.dropTemporaryTable`, and `Database.insert` with
        the ``select`` argument.
        """
        # Need to start with the static schema; also insert some test data.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            static = context.addTableTuple(STATIC_TABLE_SPECS)
        newDatabase.insert(static.a,
                           {"name": "a1", "region": None},
                           {"name": "a2", "region": None})
        bIds = newDatabase.insert(static.b,
                                  {"name": "b1", "value": 11},
                                  {"name": "b2", "value": 12},
                                  {"name": "b3", "value": 13},
                                  returnIds=True)
        # Create the table.
        table1 = newDatabase.makeTemporaryTable(TEMPORARY_TABLE_SPEC, "e1")
        self.checkTable(TEMPORARY_TABLE_SPEC, table1)
        # Insert via an INSERT INTO ... SELECT query; column names are
        # supplied here by labeling the selected columns.
        newDatabase.insert(
            table1,
            select=sqlalchemy.sql.select(
                [static.a.columns.name.label("a_name"), static.b.columns.id.label("b_id")]
            ).select_from(
                # Cross join (ON TRUE) of a and b, filtered below.
                static.a.join(static.b, onclause=sqlalchemy.sql.literal(True))
            ).where(
                sqlalchemy.sql.and_(
                    static.a.columns.name == "a1",
                    static.b.columns.value <= 12,
                )
            )
        )
        # Check that the inserted rows are present.
        self.assertCountEqual(
            [{"a_name": "a1", "b_id": bId} for bId in bIds[:2]],
            [dict(row) for row in newDatabase.query(table1.select())]
        )
        # Create another one via a read-only connection to the database.
        # We _do_ allow temporary table modifications in read-only databases.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                context.addTableTuple(STATIC_TABLE_SPECS)
            table2 = existingReadOnlyDatabase.makeTemporaryTable(TEMPORARY_TABLE_SPEC)
            self.checkTable(TEMPORARY_TABLE_SPEC, table2)
            # Those tables should not be the same, despite having the same
            # ddl.
            self.assertIsNot(table1, table2)
            # Do a slightly different insert into this table, to check that
            # it works in a read-only database.  This time we pass column
            # names as a kwarg to insert instead of by labeling the columns
            # in the select.
            existingReadOnlyDatabase.insert(
                table2,
                select=sqlalchemy.sql.select(
                    [static.a.columns.name, static.b.columns.id]
                ).select_from(
                    static.a.join(static.b, onclause=sqlalchemy.sql.literal(True))
                ).where(
                    sqlalchemy.sql.and_(
                        static.a.columns.name == "a2",
                        static.b.columns.value >= 12,
                    )
                ),
                names=["a_name", "b_id"],
            )
            # Check that the inserted rows are present.
            self.assertCountEqual(
                [{"a_name": "a2", "b_id": bId} for bId in bIds[1:]],
                [dict(row) for row in existingReadOnlyDatabase.query(table2.select())]
            )
            # Drop the temporary table from the read-only DB.  It's
            # unspecified whether attempting to use it after this point is an
            # error or just never returns any results, so we can't test what
            # it does, only that it's not an error.
            existingReadOnlyDatabase.dropTemporaryTable(table2)
        # Drop the original temporary table.
        newDatabase.dropTemporaryTable(table1)
305 def testSchemaSeparation(self):
306 """Test that creating two different `Database` instances allows us
307 to create different tables with the same name in each.
308 """
309 db1 = self.makeEmptyDatabase(origin=1)
310 with db1.declareStaticTables(create=True) as context:
311 tables = context.addTableTuple(STATIC_TABLE_SPECS)
312 self.checkStaticSchema(tables)
314 db2 = self.makeEmptyDatabase(origin=2)
315 # Make the DDL here intentionally different so we'll definitely
316 # notice if db1 and db2 are pointing at the same schema.
317 spec = ddl.TableSpec(fields=[ddl.FieldSpec("id", dtype=sqlalchemy.Integer, primaryKey=True)])
318 with db2.declareStaticTables(create=True) as context:
319 # Make the DDL here intentionally different so we'll definitely
320 # notice if db1 and db2 are pointing at the same schema.
321 table = context.addTable("a", spec)
322 self.checkTable(spec, table)
    def testInsertQueryDelete(self):
        """Test the `Database.insert`, `Database.query`, and `Database.delete`
        methods, as well as the `Base64Region` type and the ``onDelete``
        argument to `ddl.ForeignKeySpec`.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert a single, non-autoincrement row that contains a region and
        # query to get it back.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row = {"name": "a1", "region": region}
        db.insert(tables.a, row)
        self.assertEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row])
        # Insert multiple autoincrement rows but do not try to get the IDs
        # back immediately.
        db.insert(tables.b, {"name": "b1", "value": 10}, {"name": "b2", "value": 20})
        results = [dict(r) for r in db.query(tables.b.select().order_by("id")).fetchall()]
        self.assertEqual(len(results), 2)
        for row in results:
            self.assertIn(row["name"], ("b1", "b2"))
            self.assertIsInstance(row["id"], int)
        self.assertGreater(results[1]["id"], results[0]["id"])
        # Insert multiple autoincrement rows and get the IDs back from insert.
        rows = [{"name": "b3", "value": 30}, {"name": "b4", "value": 40}]
        ids = db.insert(tables.b, *rows, returnIds=True)
        results = [
            dict(r) for r in db.query(
                tables.b.select().where(tables.b.columns.id > results[1]["id"])
            ).fetchall()
        ]
        expected = [dict(row, id=id) for row, id in zip(rows, ids)]
        self.assertCountEqual(results, expected)
        self.assertTrue(all(result["id"] is not None for result in results))
        # Insert multiple rows into a table with an autoincrement+origin
        # primary key, then use the returned IDs to insert into a dynamic
        # table.
        rows = [{"origin": db.origin, "b_id": results[0]["id"]},
                {"origin": db.origin, "b_id": None}]
        ids = db.insert(tables.c, *rows, returnIds=True)
        results = [dict(r) for r in db.query(tables.c.select()).fetchall()]
        expected = [dict(row, id=id) for row, id in zip(rows, ids)]
        self.assertCountEqual(results, expected)
        self.assertTrue(all(result["id"] is not None for result in results))
        # Add the dynamic table.
        d = db.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        # Insert into it.
        rows = [{"c_origin": db.origin, "c_id": id, "a_name": "a1"} for id in ids]
        db.insert(d, *rows)
        results = [dict(r) for r in db.query(d.select()).fetchall()]
        self.assertCountEqual(rows, results)
        # Insert multiple rows into a table with an autoincrement+origin
        # primary key (this is especially tricky for SQLite, but good to test
        # for all DBs), but pass in a value for the autoincrement key.
        # For extra complexity, we re-use the autoincrement value with a
        # different value for origin.
        rows2 = [{"id": 700, "origin": db.origin, "b_id": None},
                 {"id": 700, "origin": 60, "b_id": None},
                 {"id": 1, "origin": 60, "b_id": None}]
        db.insert(tables.c, *rows2)
        results = [dict(r) for r in db.query(tables.c.select()).fetchall()]
        self.assertCountEqual(results, expected + rows2)
        self.assertTrue(all(result["id"] is not None for result in results))
        # Define 'SELECT COUNT(*)' query for later use.
        count = sqlalchemy.sql.select([sqlalchemy.sql.func.count()])
        # Get the values we inserted into table b.
        bValues = [dict(r) for r in db.query(tables.b.select()).fetchall()]
        # Remove two rows from table b by ID.
        n = db.delete(tables.b, ["id"], {"id": bValues[0]["id"]}, {"id": bValues[1]["id"]})
        self.assertEqual(n, 2)
        # Remove the other two rows from table b by name.
        n = db.delete(tables.b, ["name"], {"name": bValues[2]["name"]}, {"name": bValues[3]["name"]})
        self.assertEqual(n, 2)
        # There should now be no rows in table b.
        self.assertEqual(
            db.query(count.select_from(tables.b)).scalar(),
            0
        )
        # All b_id values in table c should now be NULL, because there's an
        # onDelete='SET NULL' foreign key.
        self.assertEqual(
            db.query(count.select_from(tables.c).where(tables.c.columns.b_id != None)).scalar(),  # noqa:E711
            0
        )
        # Remove all rows in table a (there's only one); this should remove
        # all rows in d due to onDelete='CASCADE'.
        n = db.delete(tables.a, [])
        self.assertEqual(n, 1)
        self.assertEqual(db.query(count.select_from(tables.a)).scalar(), 0)
        self.assertEqual(db.query(count.select_from(d)).scalar(), 0)
416 def testUpdate(self):
417 """Tests for `Database.update`.
418 """
419 db = self.makeEmptyDatabase(origin=1)
420 with db.declareStaticTables(create=True) as context:
421 tables = context.addTableTuple(STATIC_TABLE_SPECS)
422 # Insert two rows into table a, both without regions.
423 db.insert(tables.a, {"name": "a1"}, {"name": "a2"})
424 # Update one of the rows with a region.
425 region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
426 n = db.update(tables.a, {"name": "k"}, {"k": "a2", "region": region})
427 self.assertEqual(n, 1)
428 sql = sqlalchemy.sql.select([tables.a.columns.name, tables.a.columns.region]).select_from(tables.a)
429 self.assertCountEqual(
430 [dict(r) for r in db.query(sql).fetchall()],
431 [{"name": "a1", "region": None}, {"name": "a2", "region": region}]
432 )
    def testSync(self):
        """Tests for `Database.sync`.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert a row with sync, because it doesn't exist yet.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10}, returning=["id"])
        self.assertTrue(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [dict(r) for r in db.query(tables.b.select()).fetchall()])
        # Repeat that operation, which should do nothing but return the
        # requested values.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10}, returning=["id"])
        self.assertFalse(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [dict(r) for r in db.query(tables.b.select()).fetchall()])
        # Repeat the operation without the 'extra' arg, which should also just
        # return the existing row.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, returning=["id"])
        self.assertFalse(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [dict(r) for r in db.query(tables.b.select()).fetchall()])
        # Repeat the operation with a different value in 'extra'.  That still
        # shouldn't be an error, because 'extra' is only used if we really do
        # insert.  Also drop the 'returning' argument.
        _, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 20})
        self.assertFalse(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [dict(r) for r in db.query(tables.b.select()).fetchall()])
        # Repeat the operation with the correct value in 'compared' instead of
        # 'extra'.
        _, inserted = db.sync(tables.b, keys={"name": "b1"}, compared={"value": 10})
        self.assertFalse(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [dict(r) for r in db.query(tables.b.select()).fetchall()])
        # Repeat the operation with an incorrect value in 'compared'; this
        # should raise.
        with self.assertRaises(DatabaseConflictError):
            db.sync(tables.b, keys={"name": "b1"}, compared={"value": 20})
        # Try to sync in a read-only database.  This should work if and only
        # if the matching row already exists.
        with self.asReadOnly(db) as rodb:
            with rodb.declareStaticTables(create=False) as context:
                tables = context.addTableTuple(STATIC_TABLE_SPECS)
            _, inserted = rodb.sync(tables.b, keys={"name": "b1"})
            self.assertFalse(inserted)
            self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                             [dict(r) for r in rodb.query(tables.b.select()).fetchall()])
            # A sync that would have to insert must fail on a read-only
            # connection.
            with self.assertRaises(ReadOnlyDatabaseError):
                rodb.sync(tables.b, keys={"name": "b2"}, extra={"value": 20})
486 def testReplace(self):
487 """Tests for `Database.replace`.
488 """
489 db = self.makeEmptyDatabase(origin=1)
490 with db.declareStaticTables(create=True) as context:
491 tables = context.addTableTuple(STATIC_TABLE_SPECS)
492 # Use 'replace' to insert a single row that contains a region and
493 # query to get it back.
494 region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
495 row1 = {"name": "a1", "region": region}
496 db.replace(tables.a, row1)
497 self.assertEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1])
498 # Insert another row without a region.
499 row2 = {"name": "a2", "region": None}
500 db.replace(tables.a, row2)
501 self.assertCountEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1, row2])
502 # Use replace to re-insert both of those rows again, which should do
503 # nothing.
504 db.replace(tables.a, row1, row2)
505 self.assertCountEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1, row2])
506 # Replace row1 with a row with no region, while reinserting row2.
507 row1a = {"name": "a1", "region": None}
508 db.replace(tables.a, row1a, row2)
509 self.assertCountEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1a, row2])
510 # Replace both rows, returning row1 to its original state, while adding
511 # a new one. Pass them in in a different order.
512 row2a = {"name": "a2", "region": region}
513 row3 = {"name": "a3", "region": None}
514 db.replace(tables.a, row3, row2a, row1)
515 self.assertCountEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1, row2a, row3])
517 def testEnsure(self):
518 """Tests for `Database.ensure`.
519 """
520 db = self.makeEmptyDatabase(origin=1)
521 with db.declareStaticTables(create=True) as context:
522 tables = context.addTableTuple(STATIC_TABLE_SPECS)
523 # Use 'ensure' to insert a single row that contains a region and
524 # query to get it back.
525 region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
526 row1 = {"name": "a1", "region": region}
527 self.assertEqual(db.ensure(tables.a, row1), 1)
528 self.assertEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1])
529 # Insert another row without a region.
530 row2 = {"name": "a2", "region": None}
531 self.assertEqual(db.ensure(tables.a, row2), 1)
532 self.assertCountEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1, row2])
533 # Use ensure to re-insert both of those rows again, which should do
534 # nothing.
535 self.assertEqual(db.ensure(tables.a, row1, row2), 0)
536 self.assertCountEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1, row2])
537 # Attempt to insert row1's key with no region, while
538 # reinserting row2. This should also do nothing.
539 row1a = {"name": "a1", "region": None}
540 self.assertEqual(db.ensure(tables.a, row1a, row2), 0)
541 self.assertCountEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1, row2])
542 # Attempt to insert new rows for both existing keys, this time also
543 # adding a new row. Pass them in in a different order. Only the new
544 # row should be added.
545 row2a = {"name": "a2", "region": region}
546 row3 = {"name": "a3", "region": None}
547 self.assertEqual(db.ensure(tables.a, row3, row2a, row1a), 1)
548 self.assertCountEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1, row2, row3])
    def testTransactionNesting(self):
        """Test that transactions can be nested with the behavior in the
        presence of exceptions working as documented.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert one row so we can trigger integrity errors by trying to
        # insert a duplicate of it below.
        db.insert(tables.a, {"name": "a1"})
        # First test: error recovery via explicit savepoint=True in the inner
        # transaction.
        with db.transaction():
            # This insert should succeed, and should not be rolled back
            # because the assertRaises context should catch any exception
            # before it propagates up to the outer transaction.
            db.insert(tables.a, {"name": "a2"})
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                with db.transaction(savepoint=True):
                    # This insert should succeed, but should be rolled back.
                    db.insert(tables.a, {"name": "a4"})
                    # This insert should fail (duplicate primary key), raising
                    # an exception.
                    db.insert(tables.a, {"name": "a1"})
        self.assertCountEqual(
            [dict(r) for r in db.query(tables.a.select()).fetchall()],
            [{"name": "a1", "region": None}, {"name": "a2", "region": None}],
        )
        # Second test: error recovery via implicit savepoint=True, when the
        # innermost transaction is inside a savepoint=True transaction.
        with db.transaction():
            # This insert should succeed, and should not be rolled back
            # because the assertRaises context should catch any
            # exception before it propagates up to the outer
            # transaction.
            db.insert(tables.a, {"name": "a3"})
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                with db.transaction(savepoint=True):
                    # This insert should succeed, but should be rolled back.
                    db.insert(tables.a, {"name": "a4"})
                    with db.transaction():
                        # This insert should succeed, but should be rolled
                        # back.
                        db.insert(tables.a, {"name": "a5"})
                        # This insert should fail (duplicate primary key),
                        # raising an exception.
                        db.insert(tables.a, {"name": "a1"})
        self.assertCountEqual(
            [dict(r) for r in db.query(tables.a.select()).fetchall()],
            [{"name": "a1", "region": None}, {"name": "a2", "region": None}, {"name": "a3", "region": None}],
        )
    def testTransactionLocking(self):
        """Test that `Database.transaction` can be used to acquire a lock
        that prohibits concurrent writes.
        """
        db1 = self.makeEmptyDatabase(origin=1)
        with db1.declareStaticTables(create=True) as context:
            tables1 = context.addTableTuple(STATIC_TABLE_SPECS)

        async def side1(lock: Iterable[str] = ()) -> Tuple[Set[str], Set[str]]:
            """One side of the concurrent locking test.

            This optionally locks the table (and maybe the whole database),
            does a select for its contents, inserts a new row, and then
            selects again, with some waiting in between to make sure the
            other side has a chance to _attempt_ to insert in between.  If
            the locking is enabled and works, the difference between the
            selects should just be the insert done on this thread.
            """
            # Give Side2 a chance to create a connection
            await asyncio.sleep(1.0)
            with db1.transaction(lock=lock):
                names1 = {row["name"] for row in db1.query(tables1.a.select()).fetchall()}
                # Give Side2 a chance to insert (which will be blocked if
                # we've acquired a lock).
                await asyncio.sleep(2.0)
                db1.insert(tables1.a, {"name": "a1"})
                names2 = {row["name"] for row in db1.query(tables1.a.select()).fetchall()}
            return names1, names2

        async def side2() -> None:
            """The other side of the concurrent locking test.

            This side just waits a bit and then tries to insert a row into
            the table that the other side is trying to lock.  Hopefully that
            waiting is enough to give the other side a chance to acquire the
            lock and thus make this side block until the lock is released.
            If this side manages to do the insert before side1 acquires the
            lock, we'll just warn about not succeeding at testing the
            locking, because we can only make that unlikely, not impossible.
            """
            def toRunInThread():
                """SQLite locking isn't asyncio-friendly unless we actually
                run it in another thread.  And SQLite gets very unhappy if
                we try to use a connection from multiple threads, so we have
                to create the new connection here instead of out in the main
                body of the test function.
                """
                db2 = self.getNewConnection(db1, writeable=True)
                with db2.declareStaticTables(create=False) as context:
                    tables2 = context.addTableTuple(STATIC_TABLE_SPECS)
                with db2.transaction():
                    db2.insert(tables2.a, {"name": "a2"})

            await asyncio.sleep(2.0)
            loop = asyncio.get_running_loop()
            with ThreadPoolExecutor() as pool:
                await loop.run_in_executor(pool, toRunInThread)

        async def testProblemsWithNoLocking() -> None:
            """Run side1 and side2 with no locking, attempting to demonstrate
            the problem that locking is supposed to solve.  If we get unlucky
            with scheduling, side2 will just happen to insert after side1 is
            done, and we won't have anything definitive.  We just warn in
            that case because we really don't want spurious test failures.
            """
            task1 = asyncio.create_task(side1())
            task2 = asyncio.create_task(side2())
            names1, names2 = await task1
            await task2
            if "a2" in names1:
                warnings.warn("Unlucky scheduling in no-locking test: concurrent INSERT "
                              "happened before first SELECT.")
                self.assertEqual(names1, {"a2"})
                self.assertEqual(names2, {"a1", "a2"})
            elif "a2" not in names2:
                warnings.warn("Unlucky scheduling in no-locking test: concurrent INSERT "
                              "happened after second SELECT even without locking.")
                self.assertEqual(names1, set())
                self.assertEqual(names2, {"a1"})
            else:
                # This is the expected case: both INSERTS happen between the
                # two SELECTS.  If we don't get this almost all of the time
                # we should adjust the sleep amounts.
                self.assertEqual(names1, set())
                self.assertEqual(names2, {"a1", "a2"})

        asyncio.run(testProblemsWithNoLocking())
        # Clean up after first test.
        db1.delete(tables1.a, ["name"], {"name": "a1"}, {"name": "a2"})

        async def testSolutionWithLocking() -> None:
            """Run side1 and side2 with locking, which should make side2
            block its insert until side1 releases its lock.
            """
            task1 = asyncio.create_task(side1(lock=[tables1.a]))
            task2 = asyncio.create_task(side2())
            names1, names2 = await task1
            await task2
            if "a2" in names1:
                warnings.warn("Unlucky scheduling in locking test: concurrent INSERT "
                              "happened before first SELECT.")
                self.assertEqual(names1, {"a2"})
                self.assertEqual(names2, {"a1", "a2"})
            else:
                # This is the expected case: the side2 INSERT happens after
                # the last SELECT on side1.  This can also happen due to
                # unlucky scheduling, and we have no way to detect that here,
                # but the similar "no-locking" test has at least some chance
                # of being affected by the same problem and warning about it.
                self.assertEqual(names1, set())
                self.assertEqual(names2, {"a1"})

        asyncio.run(testSolutionWithLocking())
    def testTimespanDatabaseRepresentation(self) -> None:
        """Tests for `TimespanDatabaseRepresentation` and the `Database`
        methods that interact with it.

        Exercises round-tripping (nullable and NOT NULL timespans, including
        server-side defaults), exclusion constraints where supported, NULL
        checks in SELECT, and overlap expressions (literal-vs-column and
        column-vs-column), checking each against pure-Python `Timespan` logic.
        """
        # Make some test timespans to play with, with the full suite of
        # topological relationships.
        start = astropy.time.Time('2020-01-01T00:00:00', format="isot", scale="tai")
        offset = astropy.time.TimeDelta(60, format="sec")
        timestamps = [start + offset*n for n in range(3)]
        # aTimespans: fully unbounded, half-bounded on each side at every
        # timestamp, and every bounded pair (t1 < t2) of timestamps.
        aTimespans = [Timespan(begin=None, end=None)]
        aTimespans.extend(Timespan(begin=None, end=t) for t in timestamps)
        aTimespans.extend(Timespan(begin=t, end=None) for t in timestamps)
        aTimespans.extend(Timespan(begin=t1, end=t2) for t1, t2 in itertools.combinations(timestamps, 2))
        # Make another list of timespans that span the full range but don't
        # overlap.  This is a subset of the previous list.
        bTimespans = [Timespan(begin=None, end=timestamps[0])]
        bTimespans.extend(Timespan(begin=t1, end=t2) for t1, t2 in zip(timestamps[:-1], timestamps[1:]))
        bTimespans.append(Timespan(begin=timestamps[-1], end=None))
        # Make a database and create a table with that database's timespan
        # representation.  This one will have no exclusion constraint and
        # a nullable timespan.
        db = self.makeEmptyDatabase(origin=1)
        tsRepr = db.getTimespanRepresentation()
        aSpec = ddl.TableSpec(
            fields=[
                ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
            ],
        )
        # The representation may expand to one or more actual columns; add
        # whatever field specs it says it needs.
        for fieldSpec in tsRepr.makeFieldSpecs(nullable=True):
            aSpec.fields.add(fieldSpec)
        with db.declareStaticTables(create=True) as context:
            aTable = context.addTable("a", aSpec)

        def convertRowForInsert(row: dict) -> dict:
            """Convert a row containing a Timespan instance into one suitable
            for insertion into the database.
            """
            result = row.copy()
            # Pop the in-memory Timespan and let the representation write its
            # database form (possibly several columns) into the row dict.
            ts = result.pop(tsRepr.NAME)
            return tsRepr.update(ts, result=result)

        def convertRowFromSelect(row: dict) -> dict:
            """Convert a row from the database into one containing a Timespan.
            """
            result = row.copy()
            timespan = tsRepr.extract(result)
            # Remove the representation's raw column(s) and replace them with
            # the reconstructed Timespan object (inverse of the above).
            for name in tsRepr.getFieldNames():
                del result[name]
            result[tsRepr.NAME] = timespan
            return result

        # Insert rows into table A, in chunks just to make things interesting.
        # Include one with a NULL timespan.
        aRows = [{"id": n, tsRepr.NAME: t} for n, t in enumerate(aTimespans)]
        aRows.append({"id": len(aRows), tsRepr.NAME: None})
        db.insert(aTable, convertRowForInsert(aRows[0]))
        db.insert(aTable, *[convertRowForInsert(r) for r in aRows[1:3]])
        db.insert(aTable, *[convertRowForInsert(r) for r in aRows[3:]])
        # Add another one with a NULL timespan, but this time by invoking
        # the server-side default (note: no convertRowForInsert here, so the
        # timespan columns are simply omitted from the INSERT).
        aRows.append({"id": len(aRows)})
        db.insert(aTable, aRows[-1])
        aRows[-1][tsRepr.NAME] = None
        # Test basic round-trip through database.
        # NOTE(review): mapping-style row access and select([...]) below are
        # SQLAlchemy 1.3-era APIs; presumably pinned by the project.
        self.assertEqual(
            aRows,
            [convertRowFromSelect(dict(row))
             for row in db.query(aTable.select().order_by(aTable.columns.id)).fetchall()]
        )
        # Create another table B with a not-null timespan and (if the database
        # supports it), an exclusion constraint.  Use ensureTableExists this
        # time to check that mode of table creation vs. timespans.
        bSpec = ddl.TableSpec(
            fields=[
                ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
                ddl.FieldSpec(name="key", dtype=sqlalchemy.Integer, nullable=False),
            ],
        )
        for fieldSpec in tsRepr.makeFieldSpecs(nullable=False):
            bSpec.fields.add(fieldSpec)
        if tsRepr.hasExclusionConstraint():
            # Exclusion constraint: no two rows with the same 'key' may have
            # overlapping timespans.
            bSpec.exclusion.add(("key", tsRepr))
        bTable = db.ensureTableExists("b", bSpec)
        # Insert rows into table B, again in chunks.  Each Timespan appears
        # twice, but with different values for the 'key' field (which should
        # still be okay for any exclusion constraint we may have defined).
        bRows = [{"id": n, "key": 1, tsRepr.NAME: t} for n, t in enumerate(bTimespans)]
        offset = len(bRows)
        bRows.extend({"id": n + offset, "key": 2, tsRepr.NAME: t} for n, t in enumerate(bTimespans))
        db.insert(bTable, *[convertRowForInsert(r) for r in bRows[:2]])
        db.insert(bTable, convertRowForInsert(bRows[2]))
        db.insert(bTable, *[convertRowForInsert(r) for r in bRows[3:]])
        # Insert a row with no timespan into table B.  This should invoke the
        # server-side default, which is a timespan over (-∞, ∞).  We set
        # key=3 to avoid upsetting an exclusion constraint that might exist.
        bRows.append({"id": len(bRows), "key": 3})
        db.insert(bTable, bRows[-1])
        bRows[-1][tsRepr.NAME] = Timespan(None, None)
        # Test basic round-trip through database.
        self.assertEqual(
            bRows,
            [convertRowFromSelect(dict(row))
             for row in db.query(bTable.select().order_by(bTable.columns.id)).fetchall()]
        )
        # Test that we can't insert timespan=None into this table (the
        # timespan columns were declared NOT NULL above).
        with self.assertRaises(sqlalchemy.exc.IntegrityError):
            db.insert(bTable, convertRowForInsert({"id": len(bRows), "key": 4, tsRepr.NAME: None}))
        # IFF this database supports exclusion constraints, test that they
        # also prevent inserts.  All three candidate rows reuse key=1, whose
        # existing timespans tile the full range, so any overlap must fail.
        if tsRepr.hasExclusionConstraint():
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                db.insert(bTable, convertRowForInsert({"id": len(bRows), "key": 1,
                                                       tsRepr.NAME: Timespan(None, timestamps[1])}))
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                db.insert(bTable, convertRowForInsert({"id": len(bRows), "key": 1,
                                                       tsRepr.NAME: Timespan(timestamps[0], timestamps[2])}))
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                db.insert(bTable, convertRowForInsert({"id": len(bRows), "key": 1,
                                                       tsRepr.NAME: Timespan(timestamps[2], None)}))
        # Test NULL checks in SELECT queries, on both tables.
        aRepr = tsRepr.fromSelectable(aTable)
        self.assertEqual(
            [row[tsRepr.NAME] is None for row in aRows],
            [
                row["f"] for row in db.query(
                    sqlalchemy.sql.select(
                        [aRepr.isNull().label("f")]
                    ).order_by(
                        aTable.columns.id
                    )
                ).fetchall()
            ]
        )
        # Table B's timespan is NOT NULL, so isNull() must be False for every
        # row (including the server-side-default (-∞, ∞) one).
        bRepr = tsRepr.fromSelectable(bTable)
        self.assertEqual(
            [False for row in bRows],
            [
                row["f"] for row in db.query(
                    sqlalchemy.sql.select(
                        [bRepr.isNull().label("f")]
                    ).order_by(
                        bTable.columns.id
                    )
                ).fetchall()
            ]
        )
        # Test overlap expressions that relate in-database A timespans to
        # Python-literal B timespans; check that this is consistent with
        # Python-only overlap tests.  A NULL timespan yields a NULL (None)
        # overlap result rather than False.
        for bRow in bRows:
            with self.subTest(bRow=bRow):
                expected = {}
                for aRow in aRows:
                    if aRow[tsRepr.NAME] is None:
                        expected[aRow["id"]] = None
                    else:
                        expected[aRow["id"]] = aRow[tsRepr.NAME].overlaps(bRow[tsRepr.NAME])
                sql = sqlalchemy.sql.select(
                    [aTable.columns.id.label("a"), aRepr.overlaps(bRow[tsRepr.NAME]).label("f")]
                ).select_from(aTable)
                queried = {row["a"]: row["f"] for row in db.query(sql).fetchall()}
                self.assertEqual(expected, queried)
        # Test overlap expressions that relate in-database A timespans to
        # in-database B timespans; check that this is consistent with
        # Python-only overlap tests.  The join uses literal(True) as its ON
        # clause to get a full cross product of A and B rows.
        expected = {
            (aRow["id"], bRow["id"]): (aRow[tsRepr.NAME].overlaps(bRow[tsRepr.NAME])
                                       if aRow[tsRepr.NAME] is not None else None)
            for aRow, bRow in itertools.product(aRows, bRows)
        }
        sql = sqlalchemy.sql.select(
            [
                aTable.columns.id.label("a"),
                bTable.columns.id.label("b"),
                aRepr.overlaps(bRepr).label("f")
            ]
        ).select_from(aTable.join(bTable, onclause=sqlalchemy.sql.literal(True)))
        queried = {(row["a"], row["b"]): row["f"] for row in db.query(sql).fetchall()}
        self.assertEqual(expected, queried)