Coverage for python/lsst/daf/butler/registry/tests/_database.py : 7%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ["DatabaseTests"]
25from abc import ABC, abstractmethod
26import asyncio
27from collections import namedtuple
28from concurrent.futures import ThreadPoolExecutor
29import itertools
30from typing import ContextManager, Iterable, Set, Tuple
31import warnings
33import astropy.time
34import sqlalchemy
36from lsst.sphgeom import ConvexPolygon, UnitVector3d
37from ..interfaces import (
38 Database,
39 ReadOnlyDatabaseError,
40 DatabaseConflictError,
41 SchemaAlreadyDefinedError
42)
43from ...core import ddl, Timespan
# Container type used both for the static table specifications below and for
# the SQLAlchemy tables produced when those specifications are instantiated
# via ``context.addTableTuple``.
StaticTablesTuple = namedtuple("StaticTablesTuple", ["a", "b", "c"])

# Static test schema shared by most tests in this module:
#   a: string primary key plus a (nullable) Base64-encoded region blob.
#   b: autoincrement primary key, non-null name with a unique constraint,
#      and an optional small-integer value.
#   c: compound (autoincrement id, origin) primary key with a nullable
#      foreign key into b (ON DELETE SET NULL).
STATIC_TABLE_SPECS = StaticTablesTuple(
    a=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("name", dtype=sqlalchemy.String, length=16, primaryKey=True),
            ddl.FieldSpec("region", dtype=ddl.Base64Region, nbytes=128),
        ]
    ),
    b=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("id", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True),
            ddl.FieldSpec("name", dtype=sqlalchemy.String, length=16, nullable=False),
            ddl.FieldSpec("value", dtype=sqlalchemy.SmallInteger, nullable=True),
        ],
        unique=[("name",)],
    ),
    c=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("id", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True),
            ddl.FieldSpec("origin", dtype=sqlalchemy.BigInteger, primaryKey=True),
            ddl.FieldSpec("b_id", dtype=sqlalchemy.BigInteger, nullable=True),
        ],
        foreignKeys=[
            ddl.ForeignKeySpec("b", source=("b_id",), target=("id",), onDelete="SET NULL"),
        ]
    ),
)

# Table created at runtime via Database.ensureTableExists in the tests below;
# its rows are removed automatically (ON DELETE CASCADE) when the referenced
# rows in c or a are deleted.
DYNAMIC_TABLE_SPEC = ddl.TableSpec(
    fields=[
        ddl.FieldSpec("c_id", dtype=sqlalchemy.BigInteger, primaryKey=True),
        ddl.FieldSpec("c_origin", dtype=sqlalchemy.BigInteger, primaryKey=True),
        ddl.FieldSpec("a_name", dtype=sqlalchemy.String, length=16, nullable=False),
    ],
    foreignKeys=[
        ddl.ForeignKeySpec("c", source=("c_id", "c_origin"), target=("id", "origin"), onDelete="CASCADE"),
        ddl.ForeignKeySpec("a", source=("a_name",), target=("name",), onDelete="CASCADE"),
    ]
)

# Specification for the temporary tables created in testTemporaryTables; it
# deliberately has no foreign keys or non-key columns.
TEMPORARY_TABLE_SPEC = ddl.TableSpec(
    fields=[
        ddl.FieldSpec("a_name", dtype=sqlalchemy.String, length=16, primaryKey=True),
        ddl.FieldSpec("b_id", dtype=sqlalchemy.BigInteger, primaryKey=True),
    ],
)
94class DatabaseTests(ABC):
95 """Generic tests for the `Database` interface that can be subclassed to
96 generate tests for concrete implementations.
97 """
    @abstractmethod
    def makeEmptyDatabase(self, origin: int = 0) -> Database:
        """Return an empty `Database` with the given origin, or an
        automatically-generated one if ``origin`` is `None`.

        Parameters
        ----------
        origin : `int`, optional
            Origin identifier for the new database.
            NOTE(review): the default here is ``0``, while the summary above
            mentions a `None` origin — confirm which contract implementations
            should follow.

        Returns
        -------
        database : `Database`
            A writeable database with no tables defined.
        """
        raise NotImplementedError()
    @abstractmethod
    def asReadOnly(self, database: Database) -> ContextManager[Database]:
        """Return a context manager for a read-only connection into the given
        database.

        The original database should be considered unusable within the context
        but safe to use again afterwards (this allows the context manager to
        block write access by temporarily changing user permissions to really
        guarantee that write operations are not performed).

        Parameters
        ----------
        database : `Database`
            The (writeable) database to provide read-only access to.

        Returns
        -------
        context : `ContextManager` [ `Database` ]
            Context manager whose value is a read-only view of ``database``.
        """
        raise NotImplementedError()
    @abstractmethod
    def getNewConnection(self, database: Database, *, writeable: bool) -> Database:
        """Return a new `Database` instance that points to the same underlying
        storage as the given one.

        Parameters
        ----------
        database : `Database`
            Existing database whose storage the new connection should share.
        writeable : `bool`
            Whether the new connection should permit write operations.

        Returns
        -------
        database : `Database`
            A new, independent connection to the same storage.
        """
        raise NotImplementedError()
125 def checkTable(self, spec: ddl.TableSpec, table: sqlalchemy.schema.Table):
126 self.assertCountEqual(spec.fields.names, table.columns.keys())
127 # Checking more than this currently seems fragile, as it might restrict
128 # what Database implementations do; we don't care if the spec is
129 # actually preserved in terms of types and constraints as long as we
130 # can use the returned table as if it was.
132 def checkStaticSchema(self, tables: StaticTablesTuple):
133 self.checkTable(STATIC_TABLE_SPECS.a, tables.a)
134 self.checkTable(STATIC_TABLE_SPECS.b, tables.b)
135 self.checkTable(STATIC_TABLE_SPECS.c, tables.c)
    def testDeclareStaticTables(self):
        """Tests for `Database.declareStaticSchema` and the methods it
        delegates to.
        """
        # Create the static schema in a new, empty database.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        self.checkStaticSchema(tables)
        # Check that we can load that schema even from a read-only connection
        # (create=False means the tables must already exist).
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                tables = context.addTableTuple(STATIC_TABLE_SPECS)
            self.checkStaticSchema(tables)
    def testDeclareStaticTablesTwice(self):
        """Tests for `Database.declareStaticSchema` being called twice.
        """
        # Create the static schema in a new, empty database.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        self.checkStaticSchema(tables)
        # Second time it should raise
        with self.assertRaises(SchemaAlreadyDefinedError):
            with newDatabase.declareStaticTables(create=True) as context:
                tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Check schema, it should still contain all tables, and maybe some
        # extra.  This peeks at the private ``_tableNames`` attribute of the
        # declaration context, so it is coupled to that implementation detail.
        with newDatabase.declareStaticTables(create=False) as context:
            self.assertLessEqual(frozenset(STATIC_TABLE_SPECS._fields), context._tableNames)
169 def testRepr(self):
170 """Test that repr does not return a generic thing."""
171 newDatabase = self.makeEmptyDatabase()
172 rep = repr(newDatabase)
173 # Check that stringification works and gives us something different
174 self.assertNotEqual(rep, str(newDatabase))
175 self.assertNotIn("object at 0x", rep, "Check default repr was not used")
176 self.assertIn("://", rep)
    def testDynamicTables(self):
        """Tests for `Database.ensureTableExists` and
        `Database.getExistingTable`.
        """
        # Need to start with the static schema.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            context.addTableTuple(STATIC_TABLE_SPECS)
        # Try to ensure the dynamic table exists in a read-only version of
        # that database, which should fail because we can't create it.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                context.addTableTuple(STATIC_TABLE_SPECS)
            with self.assertRaises(ReadOnlyDatabaseError):
                existingReadOnlyDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        # Just getting the dynamic table before it exists should return None.
        self.assertIsNone(newDatabase.getExistingTable("d", DYNAMIC_TABLE_SPEC))
        # Ensure the new table exists back in the original database, which
        # should create it.
        table = newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        self.checkTable(DYNAMIC_TABLE_SPEC, table)
        # Ensuring that it exists should just return the exact same table
        # instance again.
        self.assertIs(newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC), table)
        # Try again from the read-only database.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                context.addTableTuple(STATIC_TABLE_SPECS)
            # Just getting the dynamic table should now work...
            self.assertIsNotNone(existingReadOnlyDatabase.getExistingTable("d", DYNAMIC_TABLE_SPEC))
            # ...as should ensuring that it exists, since it now does.
            existingReadOnlyDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
            self.checkTable(DYNAMIC_TABLE_SPEC, table)
        # Trying to get the table with a different specification (at least
        # in terms of what columns are present) should raise.
        with self.assertRaises(DatabaseConflictError):
            newDatabase.ensureTableExists(
                "d",
                ddl.TableSpec(
                    fields=[ddl.FieldSpec("name", dtype=sqlalchemy.String, length=4, primaryKey=True)]
                )
            )
        # Calling ensureTableExists inside a transaction block is an error,
        # even if it would do nothing.
        with newDatabase.transaction():
            with self.assertRaises(AssertionError):
                newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
    def testTemporaryTables(self):
        """Tests for `Database.makeTemporaryTable`,
        `Database.dropTemporaryTable`, and `Database.insert` with
        the ``select`` argument.
        """
        # Need to start with the static schema; also insert some test data.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            static = context.addTableTuple(STATIC_TABLE_SPECS)
        newDatabase.insert(static.a,
                           {"name": "a1", "region": None},
                           {"name": "a2", "region": None})
        bIds = newDatabase.insert(static.b,
                                  {"name": "b1", "value": 11},
                                  {"name": "b2", "value": 12},
                                  {"name": "b3", "value": 13},
                                  returnIds=True)
        # Create the table.
        table1 = newDatabase.makeTemporaryTable(TEMPORARY_TABLE_SPEC, "e1")
        self.checkTable(TEMPORARY_TABLE_SPEC, table1)
        # Insert via an INSERT INTO ... SELECT query: cross-join a and b
        # (the literal-True onclause makes it a cross join), then restrict
        # the result before inserting.
        newDatabase.insert(
            table1,
            select=sqlalchemy.sql.select(
                [static.a.columns.name.label("a_name"), static.b.columns.id.label("b_id")]
            ).select_from(
                static.a.join(static.b, onclause=sqlalchemy.sql.literal(True))
            ).where(
                sqlalchemy.sql.and_(
                    static.a.columns.name == "a1",
                    static.b.columns.value <= 12,
                )
            )
        )
        # Check that the inserted rows are present.
        self.assertCountEqual(
            [{"a_name": "a1", "b_id": bId} for bId in bIds[:2]],
            [dict(row) for row in newDatabase.query(table1.select())]
        )
        # Create another one via a read-only connection to the database.
        # We _do_ allow temporary table modifications in read-only databases.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                context.addTableTuple(STATIC_TABLE_SPECS)
            table2 = existingReadOnlyDatabase.makeTemporaryTable(TEMPORARY_TABLE_SPEC)
            self.checkTable(TEMPORARY_TABLE_SPEC, table2)
            # Those tables should not be the same, despite having the same
            # ddl.
            self.assertIsNot(table1, table2)
            # Do a slightly different insert into this table, to check that
            # it works in a read-only database. This time we pass column
            # names as a kwarg to insert instead of by labeling the columns in
            # the select.
            existingReadOnlyDatabase.insert(
                table2,
                select=sqlalchemy.sql.select(
                    [static.a.columns.name, static.b.columns.id]
                ).select_from(
                    static.a.join(static.b, onclause=sqlalchemy.sql.literal(True))
                ).where(
                    sqlalchemy.sql.and_(
                        static.a.columns.name == "a2",
                        static.b.columns.value >= 12,
                    )
                ),
                names=["a_name", "b_id"],
            )
            # Check that the inserted rows are present.
            self.assertCountEqual(
                [{"a_name": "a2", "b_id": bId} for bId in bIds[1:]],
                [dict(row) for row in existingReadOnlyDatabase.query(table2.select())]
            )
            # Drop the temporary table from the read-only DB. It's unspecified
            # whether attempting to use it after this point is an error or
            # just never returns any results, so we can't test what it does,
            # only that it's not an error.
            existingReadOnlyDatabase.dropTemporaryTable(table2)
        # Drop the original temporary table.
        newDatabase.dropTemporaryTable(table1)
    def testSchemaSeparation(self):
        """Test that creating two different `Database` instances allows us
        to create different tables with the same name in each.
        """
        db1 = self.makeEmptyDatabase(origin=1)
        with db1.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        self.checkStaticSchema(tables)
        db2 = self.makeEmptyDatabase(origin=2)
        # Make the DDL here intentionally different so we'll definitely
        # notice if db1 and db2 are pointing at the same schema.
        spec = ddl.TableSpec(fields=[ddl.FieldSpec("id", dtype=sqlalchemy.Integer, primaryKey=True)])
        with db2.declareStaticTables(create=True) as context:
            table = context.addTable("a", spec)
        self.checkTable(spec, table)
    def testInsertQueryDelete(self):
        """Test the `Database.insert`, `Database.query`, and `Database.delete`
        methods, as well as the `Base64Region` type and the ``onDelete``
        argument to `ddl.ForeignKeySpec`.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert a single, non-autoincrement row that contains a region and
        # query to get it back.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row = {"name": "a1", "region": region}
        db.insert(tables.a, row)
        self.assertEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row])
        # Insert multiple autoincrement rows but do not try to get the IDs
        # back immediately.
        db.insert(tables.b, {"name": "b1", "value": 10}, {"name": "b2", "value": 20})
        results = [dict(r) for r in db.query(tables.b.select().order_by("id")).fetchall()]
        self.assertEqual(len(results), 2)
        for row in results:
            self.assertIn(row["name"], ("b1", "b2"))
            self.assertIsInstance(row["id"], int)
        self.assertGreater(results[1]["id"], results[0]["id"])
        # Insert multiple autoincrement rows and get the IDs back from insert.
        rows = [{"name": "b3", "value": 30}, {"name": "b4", "value": 40}]
        ids = db.insert(tables.b, *rows, returnIds=True)
        # Select only the rows just added, by filtering on IDs larger than
        # the last one from the previous batch.
        results = [
            dict(r) for r in db.query(
                tables.b.select().where(tables.b.columns.id > results[1]["id"])
            ).fetchall()
        ]
        expected = [dict(row, id=id) for row, id in zip(rows, ids)]
        self.assertCountEqual(results, expected)
        self.assertTrue(all(result["id"] is not None for result in results))
        # Insert multiple rows into a table with an autoincrement+origin
        # primary key, then use the returned IDs to insert into a dynamic
        # table.
        rows = [{"origin": db.origin, "b_id": results[0]["id"]},
                {"origin": db.origin, "b_id": None}]
        ids = db.insert(tables.c, *rows, returnIds=True)
        results = [dict(r) for r in db.query(tables.c.select()).fetchall()]
        expected = [dict(row, id=id) for row, id in zip(rows, ids)]
        self.assertCountEqual(results, expected)
        self.assertTrue(all(result["id"] is not None for result in results))
        # Add the dynamic table.
        d = db.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        # Insert into it.
        rows = [{"c_origin": db.origin, "c_id": id, "a_name": "a1"} for id in ids]
        db.insert(d, *rows)
        results = [dict(r) for r in db.query(d.select()).fetchall()]
        self.assertCountEqual(rows, results)
        # Insert multiple rows into a table with an autoincrement+origin
        # primary key (this is especially tricky for SQLite, but good to test
        # for all DBs), but pass in a value for the autoincrement key.
        # For extra complexity, we re-use the autoincrement value with a
        # different value for origin.
        rows2 = [{"id": 700, "origin": db.origin, "b_id": None},
                 {"id": 700, "origin": 60, "b_id": None},
                 {"id": 1, "origin": 60, "b_id": None}]
        db.insert(tables.c, *rows2)
        results = [dict(r) for r in db.query(tables.c.select()).fetchall()]
        self.assertCountEqual(results, expected + rows2)
        self.assertTrue(all(result["id"] is not None for result in results))
        # Define 'SELECT COUNT(*)' query for later use.
        count = sqlalchemy.sql.select([sqlalchemy.sql.func.count()])
        # Get the values we inserted into table b.
        bValues = [dict(r) for r in db.query(tables.b.select()).fetchall()]
        # Remove two rows from table b by ID.
        n = db.delete(tables.b, ["id"], {"id": bValues[0]["id"]}, {"id": bValues[1]["id"]})
        self.assertEqual(n, 2)
        # Remove the other two rows from table b by name.
        n = db.delete(tables.b, ["name"], {"name": bValues[2]["name"]}, {"name": bValues[3]["name"]})
        self.assertEqual(n, 2)
        # There should now be no rows in table b.
        self.assertEqual(
            db.query(count.select_from(tables.b)).scalar(),
            0
        )
        # All b_id values in table c should now be NULL, because there's an
        # onDelete='SET NULL' foreign key.
        self.assertEqual(
            db.query(count.select_from(tables.c).where(tables.c.columns.b_id != None)).scalar(),  # noqa:E711
            0
        )
        # Remove all rows in table a (there's only one); this should remove
        # all rows in d due to onDelete='CASCADE'.
        n = db.delete(tables.a, [])
        self.assertEqual(n, 1)
        self.assertEqual(db.query(count.select_from(tables.a)).scalar(), 0)
        self.assertEqual(db.query(count.select_from(d)).scalar(), 0)
    def testUpdate(self):
        """Tests for `Database.update`.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert two rows into table a, both without regions.
        db.insert(tables.a, {"name": "a1"}, {"name": "a2"})
        # Update one of the rows with a region.  The first mapping ties the
        # "name" column to the placeholder "k" used in the values dict (see
        # `Database.update` for the exact contract).
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        n = db.update(tables.a, {"name": "k"}, {"k": "a2", "region": region})
        # Exactly one row should have matched.
        self.assertEqual(n, 1)
        sql = sqlalchemy.sql.select([tables.a.columns.name, tables.a.columns.region]).select_from(tables.a)
        self.assertCountEqual(
            [dict(r) for r in db.query(sql).fetchall()],
            [{"name": "a1", "region": None}, {"name": "a2", "region": region}]
        )
    def testSync(self):
        """Tests for `Database.sync`.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert a row with sync, because it doesn't exist yet.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10}, returning=["id"])
        self.assertTrue(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [dict(r) for r in db.query(tables.b.select()).fetchall()])
        # Repeat that operation, which should do nothing but return the
        # requested values.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10}, returning=["id"])
        self.assertFalse(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [dict(r) for r in db.query(tables.b.select()).fetchall()])
        # Repeat the operation without the 'extra' arg, which should also just
        # return the existing row.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, returning=["id"])
        self.assertFalse(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [dict(r) for r in db.query(tables.b.select()).fetchall()])
        # Repeat the operation with a different value in 'extra'. That still
        # shouldn't be an error, because 'extra' is only used if we really do
        # insert. Also drop the 'returning' argument.
        _, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 20})
        self.assertFalse(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [dict(r) for r in db.query(tables.b.select()).fetchall()])
        # Repeat the operation with the correct value in 'compared' instead of
        # 'extra'.
        _, inserted = db.sync(tables.b, keys={"name": "b1"}, compared={"value": 10})
        self.assertFalse(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [dict(r) for r in db.query(tables.b.select()).fetchall()])
        # Repeat the operation with an incorrect value in 'compared'; this
        # should raise.
        with self.assertRaises(DatabaseConflictError):
            db.sync(tables.b, keys={"name": "b1"}, compared={"value": 20})
        # Try to sync in a read-only database. This should work if and only
        # if the matching row already exists.
        with self.asReadOnly(db) as rodb:
            with rodb.declareStaticTables(create=False) as context:
                tables = context.addTableTuple(STATIC_TABLE_SPECS)
            _, inserted = rodb.sync(tables.b, keys={"name": "b1"})
            self.assertFalse(inserted)
            self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                             [dict(r) for r in rodb.query(tables.b.select()).fetchall()])
            # Syncing a row that does not exist would require a write.
            with self.assertRaises(ReadOnlyDatabaseError):
                rodb.sync(tables.b, keys={"name": "b2"}, extra={"value": 20})
    def testReplace(self):
        """Tests for `Database.replace`.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Use 'replace' to insert a single row that contains a region and
        # query to get it back.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row1 = {"name": "a1", "region": region}
        db.replace(tables.a, row1)
        self.assertEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1])
        # Insert another row without a region.
        row2 = {"name": "a2", "region": None}
        db.replace(tables.a, row2)
        self.assertCountEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1, row2])
        # Use replace to re-insert both of those rows again, which should do
        # nothing.
        db.replace(tables.a, row1, row2)
        self.assertCountEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1, row2])
        # Replace row1 with a row with no region, while reinserting row2.
        row1a = {"name": "a1", "region": None}
        db.replace(tables.a, row1a, row2)
        self.assertCountEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1a, row2])
        # Replace both rows, returning row1 to its original state, while
        # adding a new one. Pass them in in a different order.
        row2a = {"name": "a2", "region": region}
        row3 = {"name": "a3", "region": None}
        db.replace(tables.a, row3, row2a, row1)
        self.assertCountEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1, row2a, row3])
    def testEnsure(self):
        """Tests for `Database.ensure`.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Use 'ensure' to insert a single row that contains a region and
        # query to get it back.  The return value is the number of rows
        # actually inserted.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row1 = {"name": "a1", "region": region}
        self.assertEqual(db.ensure(tables.a, row1), 1)
        self.assertEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1])
        # Insert another row without a region.
        row2 = {"name": "a2", "region": None}
        self.assertEqual(db.ensure(tables.a, row2), 1)
        self.assertCountEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1, row2])
        # Use ensure to re-insert both of those rows again, which should do
        # nothing.
        self.assertEqual(db.ensure(tables.a, row1, row2), 0)
        self.assertCountEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1, row2])
        # Attempt to insert row1's key with no region, while
        # reinserting row2. This should also do nothing (unlike 'replace',
        # 'ensure' never overwrites existing rows).
        row1a = {"name": "a1", "region": None}
        self.assertEqual(db.ensure(tables.a, row1a, row2), 0)
        self.assertCountEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1, row2])
        # Attempt to insert new rows for both existing keys, this time also
        # adding a new row. Pass them in in a different order. Only the new
        # row should be added.
        row2a = {"name": "a2", "region": region}
        row3 = {"name": "a3", "region": None}
        self.assertEqual(db.ensure(tables.a, row3, row2a, row1a), 1)
        self.assertCountEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1, row2, row3])
    def testTransactionNesting(self):
        """Test that transactions can be nested with the behavior in the
        presence of exceptions working as documented.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert one row so we can trigger integrity errors by trying to
        # insert a duplicate of it below.
        db.insert(tables.a, {"name": "a1"})
        # First test: error recovery via explicit savepoint=True in the inner
        # transaction.
        with db.transaction():
            # This insert should succeed, and should not be rolled back
            # because the assertRaises context should catch any exception
            # before it propagates up to the outer transaction.
            db.insert(tables.a, {"name": "a2"})
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                with db.transaction(savepoint=True):
                    # This insert should succeed, but should be rolled back.
                    db.insert(tables.a, {"name": "a4"})
                    # This insert should fail (duplicate primary key), raising
                    # an exception.
                    db.insert(tables.a, {"name": "a1"})
        # Only a1 (pre-existing) and a2 (outer transaction) should survive.
        self.assertCountEqual(
            [dict(r) for r in db.query(tables.a.select()).fetchall()],
            [{"name": "a1", "region": None}, {"name": "a2", "region": None}],
        )
        # Second test: error recovery via implicit savepoint=True, when the
        # innermost transaction is inside a savepoint=True transaction.
        with db.transaction():
            # This insert should succeed, and should not be rolled back
            # because the assertRaises context should catch any
            # exception before it propagates up to the outer
            # transaction.
            db.insert(tables.a, {"name": "a3"})
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                with db.transaction(savepoint=True):
                    # This insert should succeed, but should be rolled back.
                    db.insert(tables.a, {"name": "a4"})
                    with db.transaction():
                        # This insert should succeed, but should be rolled
                        # back.
                        db.insert(tables.a, {"name": "a5"})
                        # This insert should fail (duplicate primary key),
                        # raising an exception.
                        db.insert(tables.a, {"name": "a1"})
        self.assertCountEqual(
            [dict(r) for r in db.query(tables.a.select()).fetchall()],
            [{"name": "a1", "region": None}, {"name": "a2", "region": None}, {"name": "a3", "region": None}],
        )
    def testTransactionLocking(self):
        """Test that `Database.transaction` can be used to acquire a lock
        that prohibits concurrent writes.
        """
        db1 = self.makeEmptyDatabase(origin=1)
        with db1.declareStaticTables(create=True) as context:
            tables1 = context.addTableTuple(STATIC_TABLE_SPECS)

        # NOTE(review): the annotation says Iterable[str], but the locking
        # run below passes table objects ([tables1.a]) — confirm the accepted
        # element types against Database.transaction.
        async def side1(lock: Iterable[str] = ()) -> Tuple[Set[str], Set[str]]:
            """One side of the concurrent locking test.

            This optionally locks the table (and maybe the whole database),
            does a select for its contents, inserts a new row, and then
            selects again, with some waiting in between to make sure the
            other side has a chance to _attempt_ to insert in between. If the
            locking is enabled and works, the difference between the selects
            should just be the insert done on this thread.
            """
            # Give Side2 a chance to create a connection
            await asyncio.sleep(1.0)
            with db1.transaction(lock=lock):
                names1 = {row["name"] for row in db1.query(tables1.a.select()).fetchall()}
                # Give Side2 a chance to insert (which will be blocked if
                # we've acquired a lock).
                await asyncio.sleep(2.0)
                db1.insert(tables1.a, {"name": "a1"})
                names2 = {row["name"] for row in db1.query(tables1.a.select()).fetchall()}
                return names1, names2

        async def side2() -> None:
            """The other side of the concurrent locking test.

            This side just waits a bit and then tries to insert a row into
            the table that the other side is trying to lock. Hopefully that
            waiting is enough to give the other side a chance to acquire the
            lock and thus make this side block until the lock is released.
            If this side manages to do the insert before side1 acquires the
            lock, we'll just warn about not succeeding at testing the
            locking, because we can only make that unlikely, not impossible.
            """
            def toRunInThread():
                """SQLite locking isn't asyncio-friendly unless we actually
                run it in another thread. And SQLite gets very unhappy if
                we try to use a connection from multiple threads, so we have
                to create the new connection here instead of out in the main
                body of the test function.
                """
                db2 = self.getNewConnection(db1, writeable=True)
                with db2.declareStaticTables(create=False) as context:
                    tables2 = context.addTableTuple(STATIC_TABLE_SPECS)
                with db2.transaction():
                    db2.insert(tables2.a, {"name": "a2"})

            await asyncio.sleep(2.0)
            loop = asyncio.get_running_loop()
            with ThreadPoolExecutor() as pool:
                await loop.run_in_executor(pool, toRunInThread)

        async def testProblemsWithNoLocking() -> None:
            """Run side1 and side2 with no locking, attempting to demonstrate
            the problem that locking is supposed to solve. If we get unlucky
            with scheduling, side2 will just happen to insert after side1 is
            done, and we won't have anything definitive. We just warn in
            that case because we really don't want spurious test failures.
            """
            task1 = asyncio.create_task(side1())
            task2 = asyncio.create_task(side2())
            names1, names2 = await task1
            await task2
            if "a2" in names1:
                # side2 inserted before side1's first SELECT.
                warnings.warn("Unlucky scheduling in no-locking test: concurrent INSERT "
                              "happened before first SELECT.")
                self.assertEqual(names1, {"a2"})
                self.assertEqual(names2, {"a1", "a2"})
            elif "a2" not in names2:
                # side2 inserted after side1's second SELECT.
                warnings.warn("Unlucky scheduling in no-locking test: concurrent INSERT "
                              "happened after second SELECT even without locking.")
                self.assertEqual(names1, set())
                self.assertEqual(names2, {"a1"})
            else:
                # This is the expected case: both INSERTS happen between the
                # two SELECTS. If we don't get this almost all of the time we
                # should adjust the sleep amounts.
                self.assertEqual(names1, set())
                self.assertEqual(names2, {"a1", "a2"})

        asyncio.run(testProblemsWithNoLocking())
        # Clean up after first test.
        db1.delete(tables1.a, ["name"], {"name": "a1"}, {"name": "a2"})

        async def testSolutionWithLocking() -> None:
            """Run side1 and side2 with locking, which should make side2
            block its insert until side1 releases its lock.
            """
            task1 = asyncio.create_task(side1(lock=[tables1.a]))
            task2 = asyncio.create_task(side2())
            names1, names2 = await task1
            await task2
            if "a2" in names1:
                warnings.warn("Unlucky scheduling in locking test: concurrent INSERT "
                              "happened before first SELECT.")
                self.assertEqual(names1, {"a2"})
                self.assertEqual(names2, {"a1", "a2"})
            else:
                # This is the expected case: the side2 INSERT happens after
                # the last SELECT on side1. This can also happen due to
                # unlucky scheduling, and we have no way to detect that here,
                # but the similar "no-locking" test has at least some chance
                # of being affected by the same problem and warning about it.
                self.assertEqual(names1, set())
                self.assertEqual(names2, {"a1"})

        asyncio.run(testSolutionWithLocking())
    def testTimespanDatabaseRepresentation(self):
        """Tests for `TimespanDatabaseRepresentation` and the `Database`
        methods that interact with it.

        Exercises round-tripping of timespans through two tables (one with a
        nullable timespan column set, one non-nullable and possibly covered
        by an exclusion constraint), and checks that SQL-side relationship
        operators agree with the pure-Python `Timespan` operators.
        """
        # Make some test timespans to play with, with the full suite of
        # topological relationships.
        start = astropy.time.Time('2020-01-01T00:00:00', format="isot", scale="tai")
        offset = astropy.time.TimeDelta(60, format="sec")
        timestamps = [start + offset*n for n in range(3)]
        # 'a' timespans: unbounded, half-bounded, instantaneous, empty, and
        # all bounded combinations of the three timestamps.
        aTimespans = [Timespan(begin=None, end=None)]
        aTimespans.extend(Timespan(begin=None, end=t) for t in timestamps)
        aTimespans.extend(Timespan(begin=t, end=None) for t in timestamps)
        aTimespans.extend(Timespan.fromInstant(t) for t in timestamps)
        aTimespans.append(Timespan.makeEmpty())
        aTimespans.extend(Timespan(begin=t1, end=t2) for t1, t2 in itertools.combinations(timestamps, 2))
        # Make another list of timespans that span the full range but don't
        # overlap.  This is a subset of the previous list.
        bTimespans = [Timespan(begin=None, end=timestamps[0])]
        bTimespans.extend(Timespan(begin=t1, end=t2) for t1, t2 in zip(timestamps[:-1], timestamps[1:]))
        bTimespans.append(Timespan(begin=timestamps[-1], end=None))
        # Make a database and create a table with that database's timespan
        # representation.  This one will have no exclusion constraint and
        # a nullable timespan.
        db = self.makeEmptyDatabase(origin=1)
        TimespanReprClass = db.getTimespanRepresentation()
        aSpec = ddl.TableSpec(
            fields=[
                ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
            ],
        )
        # The representation may use one or more actual columns; add whatever
        # it needs to the spec.
        for fieldSpec in TimespanReprClass.makeFieldSpecs(nullable=True):
            aSpec.fields.add(fieldSpec)
        with db.declareStaticTables(create=True) as context:
            aTable = context.addTable("a", aSpec)
        self.maxDiff = None

        def convertRowForInsert(row: dict) -> dict:
            """Convert a row containing a Timespan instance into one suitable
            for insertion into the database.
            """
            result = row.copy()
            ts = result.pop(TimespanReprClass.NAME)
            # update() writes the representation's column values for ``ts``
            # into ``result`` and returns it.
            return TimespanReprClass.update(ts, result=result)

        def convertRowFromSelect(row: dict) -> dict:
            """Convert a row from the database into one containing a Timespan.
            """
            result = row.copy()
            timespan = TimespanReprClass.extract(result)
            # Strip the representation's raw columns, leaving a single
            # Timespan entry in their place.
            for name in TimespanReprClass.getFieldNames():
                del result[name]
            result[TimespanReprClass.NAME] = timespan
            return result

        # Insert rows into table A, in chunks just to make things interesting.
        # Include one with a NULL timespan.
        aRows = [{"id": n, TimespanReprClass.NAME: t} for n, t in enumerate(aTimespans)]
        aRows.append({"id": len(aRows), TimespanReprClass.NAME: None})
        db.insert(aTable, convertRowForInsert(aRows[0]))
        db.insert(aTable, *[convertRowForInsert(r) for r in aRows[1:3]])
        db.insert(aTable, *[convertRowForInsert(r) for r in aRows[3:]])
        # Add another one with a NULL timespan, but this time by invoking
        # the server-side default (note: the row is deliberately NOT passed
        # through convertRowForInsert here).
        aRows.append({"id": len(aRows)})
        db.insert(aTable, aRows[-1])
        aRows[-1][TimespanReprClass.NAME] = None
        # Test basic round-trip through database.
        self.assertEqual(
            aRows,
            [convertRowFromSelect(dict(row))
             for row in db.query(aTable.select().order_by(aTable.columns.id)).fetchall()]
        )
        # Create another table B with a not-null timespan and (if the database
        # supports it), an exclusion constraint.  Use ensureTableExists this
        # time to check that mode of table creation vs. timespans.
        bSpec = ddl.TableSpec(
            fields=[
                ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
                ddl.FieldSpec(name="key", dtype=sqlalchemy.Integer, nullable=False),
            ],
        )
        for fieldSpec in TimespanReprClass.makeFieldSpecs(nullable=False):
            bSpec.fields.add(fieldSpec)
        if TimespanReprClass.hasExclusionConstraint():
            # No two rows with the same 'key' may have overlapping timespans.
            bSpec.exclusion.add(("key", TimespanReprClass))
        bTable = db.ensureTableExists("b", bSpec)
        # Insert rows into table B, again in chunks.  Each Timespan appears
        # twice, but with different values for the 'key' field (which should
        # still be okay for any exclusion constraint we may have defined).
        bRows = [{"id": n, "key": 1, TimespanReprClass.NAME: t} for n, t in enumerate(bTimespans)]
        offset = len(bRows)
        bRows.extend({"id": n + offset, "key": 2, TimespanReprClass.NAME: t}
                     for n, t in enumerate(bTimespans))
        db.insert(bTable, *[convertRowForInsert(r) for r in bRows[:2]])
        db.insert(bTable, convertRowForInsert(bRows[2]))
        db.insert(bTable, *[convertRowForInsert(r) for r in bRows[3:]])
        # Insert a row with no timespan into table B.  This should invoke the
        # server-side default, which is a timespan over (-∞, ∞).  We set
        # key=3 to avoid upsetting an exclusion constraint that might exist.
        bRows.append({"id": len(bRows), "key": 3})
        db.insert(bTable, bRows[-1])
        bRows[-1][TimespanReprClass.NAME] = Timespan(None, None)
        # Test basic round-trip through database.
        self.assertEqual(
            bRows,
            [convertRowFromSelect(dict(row))
             for row in db.query(bTable.select().order_by(bTable.columns.id)).fetchall()]
        )
        # Test that we can't insert timespan=None into this table (the
        # timespan columns are declared NOT NULL).
        with self.assertRaises(sqlalchemy.exc.IntegrityError):
            db.insert(
                bTable,
                convertRowForInsert({"id": len(bRows), "key": 4, TimespanReprClass.NAME: None})
            )
        # IFF this database supports exclusion constraints, test that they
        # also prevent inserts.  Each candidate row overlaps existing key=1
        # rows in a different way (leading, interior, trailing).
        if TimespanReprClass.hasExclusionConstraint():
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                db.insert(
                    bTable,
                    convertRowForInsert({
                        "id": len(bRows), "key": 1,
                        TimespanReprClass.NAME: Timespan(None, timestamps[1])
                    })
                )
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                db.insert(
                    bTable,
                    convertRowForInsert({
                        "id": len(bRows), "key": 1,
                        TimespanReprClass.NAME: Timespan(timestamps[0], timestamps[2])
                    })
                )
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                db.insert(
                    bTable,
                    convertRowForInsert({
                        "id": len(bRows), "key": 1,
                        TimespanReprClass.NAME: Timespan(timestamps[2], None)
                    })
                )
        # Test NULL checks in SELECT queries, on both tables.  Table A has
        # exactly two NULL-timespan rows; table B should have none.
        aRepr = TimespanReprClass.fromSelectable(aTable)
        self.assertEqual(
            [row[TimespanReprClass.NAME] is None for row in aRows],
            [
                row["f"] for row in db.query(
                    sqlalchemy.sql.select(
                        [aRepr.isNull().label("f")]
                    ).order_by(
                        aTable.columns.id
                    )
                ).fetchall()
            ]
        )
        bRepr = TimespanReprClass.fromSelectable(bTable)
        self.assertEqual(
            [False for row in bRows],
            [
                row["f"] for row in db.query(
                    sqlalchemy.sql.select(
                        [bRepr.isNull().label("f")]
                    ).order_by(
                        bTable.columns.id
                    )
                ).fetchall()
            ]
        )
        # Test relationships expressions that relate in-database timespans to
        # Python-literal timespans, all from the more complete 'a' set; check
        # that this is consistent with Python-only relationship tests.
        for rhsRow in aRows:
            if rhsRow[TimespanReprClass.NAME] is None:
                continue
            with self.subTest(rhsRow=rhsRow):
                expected = {}
                for lhsRow in aRows:
                    if lhsRow[TimespanReprClass.NAME] is None:
                        # SQL comparisons against NULL yield NULL (None).
                        expected[lhsRow["id"]] = (None, None, None, None)
                    else:
                        expected[lhsRow["id"]] = (
                            lhsRow[TimespanReprClass.NAME].overlaps(rhsRow[TimespanReprClass.NAME]),
                            lhsRow[TimespanReprClass.NAME].contains(rhsRow[TimespanReprClass.NAME]),
                            lhsRow[TimespanReprClass.NAME] < rhsRow[TimespanReprClass.NAME],
                            lhsRow[TimespanReprClass.NAME] > rhsRow[TimespanReprClass.NAME],
                        )
                rhsRepr = TimespanReprClass.fromLiteral(rhsRow[TimespanReprClass.NAME])
                sql = sqlalchemy.sql.select([
                    aTable.columns.id.label("lhs"),
                    aRepr.overlaps(rhsRepr).label("overlaps"),
                    aRepr.contains(rhsRepr).label("contains"),
                    (aRepr < rhsRepr).label("less_than"),
                    (aRepr > rhsRepr).label("greater_than"),
                ]).select_from(aTable)
                queried = {
                    row["lhs"]: (row["overlaps"], row["contains"], row["less_than"], row["greater_than"])
                    for row in db.query(sql).fetchall()
                }
                self.assertEqual(expected, queried)
        # Test relationship expressions that relate in-database timespans to
        # each other, all from the more complete 'a' set; check that this is
        # consistent with Python-only relationship tests.
        expected = {}
        for lhs, rhs in itertools.product(aRows, aRows):
            lhsT = lhs[TimespanReprClass.NAME]
            rhsT = rhs[TimespanReprClass.NAME]
            if lhsT is not None and rhsT is not None:
                expected[lhs["id"], rhs["id"]] = (
                    lhsT.overlaps(rhsT),
                    lhsT.contains(rhsT),
                    lhsT < rhsT,
                    lhsT > rhsT
                )
            else:
                expected[lhs["id"], rhs["id"]] = (None, None, None, None)
        # Full cross join of the table with itself via an always-true
        # ON clause.
        lhsSubquery = aTable.alias("lhs")
        rhsSubquery = aTable.alias("rhs")
        lhsRepr = TimespanReprClass.fromSelectable(lhsSubquery)
        rhsRepr = TimespanReprClass.fromSelectable(rhsSubquery)
        sql = sqlalchemy.sql.select(
            [
                lhsSubquery.columns.id.label("lhs"),
                rhsSubquery.columns.id.label("rhs"),
                lhsRepr.overlaps(rhsRepr).label("overlaps"),
                lhsRepr.contains(rhsRepr).label("contains"),
                (lhsRepr < rhsRepr).label("less_than"),
                (lhsRepr > rhsRepr).label("greater_than"),
            ]
        ).select_from(
            lhsSubquery.join(rhsSubquery, onclause=sqlalchemy.sql.literal(True))
        )
        queried = {
            (row["lhs"], row["rhs"]): (row["overlaps"], row["contains"],
                                       row["less_than"], row["greater_than"])
            for row in db.query(sql).fetchall()}
        self.assertEqual(expected, queried)
        # Test relationship expressions between in-database timespans and
        # Python-literal instantaneous times.
        for t in timestamps:
            with self.subTest(t=t):
                expected = {}
                for lhsRow in aRows:
                    if lhsRow[TimespanReprClass.NAME] is None:
                        expected[lhsRow["id"]] = (None, None, None)
                    else:
                        expected[lhsRow["id"]] = (
                            lhsRow[TimespanReprClass.NAME].contains(t),
                            lhsRow[TimespanReprClass.NAME] < t,
                            lhsRow[TimespanReprClass.NAME] > t,
                        )
                # Bind the astropy time as a literal with the appropriate
                # database column type.
                rhs = sqlalchemy.sql.literal(t, type_=ddl.AstropyTimeNsecTai)
                sql = sqlalchemy.sql.select([
                    aTable.columns.id.label("lhs"),
                    aRepr.contains(rhs).label("contains"),
                    (aRepr < rhs).label("less_than"),
                    (aRepr > rhs).label("greater_than"),
                ]).select_from(aTable)
                queried = {
                    row["lhs"]: (row["contains"], row["less_than"], row["greater_than"])
                    for row in db.query(sql).fetchall()
                }
                self.assertEqual(expected, queried)