Coverage for python/lsst/daf/butler/registry/tests/_database.py : 7%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ["DatabaseTests"]
25from abc import ABC, abstractmethod
26import asyncio
27from collections import namedtuple
28from concurrent.futures import ThreadPoolExecutor
29import itertools
30from typing import ContextManager, Iterable, Set, Tuple
31import warnings
33import astropy.time
34import sqlalchemy
36from lsst.sphgeom import ConvexPolygon, UnitVector3d
37from ..interfaces import (
38 Database,
39 ReadOnlyDatabaseError,
40 DatabaseConflictError,
41 SchemaAlreadyDefinedError
42)
43from ...core import ddl, Timespan
# Named tuple type used to bundle the three static test tables so tests can
# refer to them as ``tables.a``, ``tables.b``, ``tables.c``.
StaticTablesTuple = namedtuple("StaticTablesTuple", ["a", "b", "c"])

# Specifications for the static (declared-up-front) test tables:
#  - "a": string primary key plus a nullable region column (exercises the
#    ddl.Base64Region type).
#  - "b": autoincrement primary key, with a unique constraint on "name".
#  - "c": compound autoincrement+origin primary key, and a foreign key to
#    "b" with onDelete="SET NULL" (exercised by testInsertQueryDelete).
STATIC_TABLE_SPECS = StaticTablesTuple(
    a=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("name", dtype=sqlalchemy.String, length=16, primaryKey=True),
            ddl.FieldSpec("region", dtype=ddl.Base64Region, nbytes=128),
        ]
    ),
    b=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("id", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True),
            ddl.FieldSpec("name", dtype=sqlalchemy.String, length=16, nullable=False),
            ddl.FieldSpec("value", dtype=sqlalchemy.SmallInteger, nullable=True),
        ],
        unique=[("name",)],
    ),
    c=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("id", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True),
            ddl.FieldSpec("origin", dtype=sqlalchemy.BigInteger, primaryKey=True),
            ddl.FieldSpec("b_id", dtype=sqlalchemy.BigInteger, nullable=True),
        ],
        foreignKeys=[
            ddl.ForeignKeySpec("b", source=("b_id",), target=("id",), onDelete="SET NULL"),
        ]
    ),
)
# Specification for a table created on demand via Database.ensureTableExists
# ("d" in the tests).  It references both "c" (compound key) and "a", each
# with onDelete="CASCADE" (exercised by testInsertQueryDelete).
DYNAMIC_TABLE_SPEC = ddl.TableSpec(
    fields=[
        ddl.FieldSpec("c_id", dtype=sqlalchemy.BigInteger, primaryKey=True),
        ddl.FieldSpec("c_origin", dtype=sqlalchemy.BigInteger, primaryKey=True),
        ddl.FieldSpec("a_name", dtype=sqlalchemy.String, length=16, nullable=False),
    ],
    foreignKeys=[
        ddl.ForeignKeySpec("c", source=("c_id", "c_origin"), target=("id", "origin"), onDelete="CASCADE"),
        ddl.ForeignKeySpec("a", source=("a_name",), target=("name",), onDelete="CASCADE"),
    ]
)
# Specification for temporary tables created via session.makeTemporaryTable
# in testTemporaryTables; the columns mirror the SELECT used to fill them.
TEMPORARY_TABLE_SPEC = ddl.TableSpec(
    fields=[
        ddl.FieldSpec("a_name", dtype=sqlalchemy.String, length=16, primaryKey=True),
        ddl.FieldSpec("b_id", dtype=sqlalchemy.BigInteger, primaryKey=True),
    ],
)
94class DatabaseTests(ABC):
95 """Generic tests for the `Database` interface that can be subclassed to
96 generate tests for concrete implementations.
97 """
    @abstractmethod
    def makeEmptyDatabase(self, origin: int = 0) -> Database:
        """Return an empty `Database` with the given origin.

        NOTE(review): the default is ``0``, not `None`; the original text
        about an "automatically-generated" origin for `None` appears stale —
        implementers should confirm whether `None` is actually supported.
        """
        raise NotImplementedError()
    @abstractmethod
    def asReadOnly(self, database: Database) -> ContextManager[Database]:
        """Return a context manager for a read-only connection into the given
        database.

        The original database should be considered unusable within the context
        but safe to use again afterwards (this allows the context manager to
        block write access by temporarily changing user permissions to really
        guarantee that write operations are not performed).
        """
        raise NotImplementedError()
    @abstractmethod
    def getNewConnection(self, database: Database, *, writeable: bool) -> Database:
        """Return a new `Database` instance that points to the same underlying
        storage as the given one.

        Used by the locking tests, which need a second, independent connection
        (possibly from another thread) to the same database.
        """
        raise NotImplementedError()
125 def checkTable(self, spec: ddl.TableSpec, table: sqlalchemy.schema.Table):
126 self.assertCountEqual(spec.fields.names, table.columns.keys())
127 # Checking more than this currently seems fragile, as it might restrict
128 # what Database implementations do; we don't care if the spec is
129 # actually preserved in terms of types and constraints as long as we
130 # can use the returned table as if it was.
132 def checkStaticSchema(self, tables: StaticTablesTuple):
133 self.checkTable(STATIC_TABLE_SPECS.a, tables.a)
134 self.checkTable(STATIC_TABLE_SPECS.b, tables.b)
135 self.checkTable(STATIC_TABLE_SPECS.c, tables.c)
137 def testDeclareStaticTables(self):
138 """Tests for `Database.declareStaticSchema` and the methods it
139 delegates to.
140 """
141 # Create the static schema in a new, empty database.
142 newDatabase = self.makeEmptyDatabase()
143 with newDatabase.declareStaticTables(create=True) as context:
144 tables = context.addTableTuple(STATIC_TABLE_SPECS)
145 self.checkStaticSchema(tables)
146 # Check that we can load that schema even from a read-only connection.
147 with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
148 with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
149 tables = context.addTableTuple(STATIC_TABLE_SPECS)
150 self.checkStaticSchema(tables)
152 def testDeclareStaticTablesTwice(self):
153 """Tests for `Database.declareStaticSchema` being called twice.
154 """
155 # Create the static schema in a new, empty database.
156 newDatabase = self.makeEmptyDatabase()
157 with newDatabase.declareStaticTables(create=True) as context:
158 tables = context.addTableTuple(STATIC_TABLE_SPECS)
159 self.checkStaticSchema(tables)
160 # Second time it should raise
161 with self.assertRaises(SchemaAlreadyDefinedError):
162 with newDatabase.declareStaticTables(create=True) as context:
163 tables = context.addTableTuple(STATIC_TABLE_SPECS)
164 # Check schema, it should still contain all tables, and maybe some
165 # extra.
166 with newDatabase.declareStaticTables(create=False) as context:
167 self.assertLessEqual(frozenset(STATIC_TABLE_SPECS._fields), context._tableNames)
169 def testRepr(self):
170 """Test that repr does not return a generic thing."""
171 newDatabase = self.makeEmptyDatabase()
172 rep = repr(newDatabase)
173 # Check that stringification works and gives us something different
174 self.assertNotEqual(rep, str(newDatabase))
175 self.assertNotIn("object at 0x", rep, "Check default repr was not used")
176 self.assertIn("://", rep)
    def testDynamicTables(self):
        """Tests for `Database.ensureTableExists` and
        `Database.getExistingTable`.
        """
        # Need to start with the static schema.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            context.addTableTuple(STATIC_TABLE_SPECS)
        # Try to ensure the dynamic table exists in a read-only version of
        # that database, which should fail because we can't create it.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                context.addTableTuple(STATIC_TABLE_SPECS)
            with self.assertRaises(ReadOnlyDatabaseError):
                existingReadOnlyDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        # Just getting the dynamic table before it exists should return None.
        self.assertIsNone(newDatabase.getExistingTable("d", DYNAMIC_TABLE_SPEC))
        # Ensure the new table exists back in the original database, which
        # should create it.
        table = newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        self.checkTable(DYNAMIC_TABLE_SPEC, table)
        # Ensuring that it exists should just return the exact same table
        # instance again.
        self.assertIs(newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC), table)
        # Try again from the read-only database.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                context.addTableTuple(STATIC_TABLE_SPECS)
            # Just getting the dynamic table should now work...
            self.assertIsNotNone(existingReadOnlyDatabase.getExistingTable("d", DYNAMIC_TABLE_SPEC))
            # ...as should ensuring that it exists, since it now does.
            existingReadOnlyDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
            self.checkTable(DYNAMIC_TABLE_SPEC, table)
        # Trying to get the table with a different specification (at least
        # in terms of what columns are present) should raise.
        with self.assertRaises(DatabaseConflictError):
            newDatabase.ensureTableExists(
                "d",
                ddl.TableSpec(
                    fields=[ddl.FieldSpec("name", dtype=sqlalchemy.String, length=4, primaryKey=True)]
                )
            )
        # Calling ensureTableExists inside a transaction block is an error,
        # even if it would do nothing.
        with newDatabase.transaction():
            with self.assertRaises(AssertionError):
                newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
    def testTemporaryTables(self):
        """Tests for `Database.makeTemporaryTable`,
        `Database.dropTemporaryTable`, and `Database.insert` with
        the ``select`` argument.
        """
        # Need to start with the static schema; also insert some test data.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            static = context.addTableTuple(STATIC_TABLE_SPECS)
        newDatabase.insert(static.a,
                           {"name": "a1", "region": None},
                           {"name": "a2", "region": None})
        bIds = newDatabase.insert(static.b,
                                  {"name": "b1", "value": 11},
                                  {"name": "b2", "value": 12},
                                  {"name": "b3", "value": 13},
                                  returnIds=True)
        # Create the table.
        with newDatabase.session() as session:
            table1 = session.makeTemporaryTable(TEMPORARY_TABLE_SPEC, "e1")
            self.checkTable(TEMPORARY_TABLE_SPEC, table1)
            # Insert via an INSERT INTO ... SELECT query; the cross join
            # (onclause=literal(True)) pairs every "a" row with every "b" row
            # before the WHERE clause filters them.
            newDatabase.insert(
                table1,
                select=sqlalchemy.sql.select(
                    [static.a.columns.name.label("a_name"), static.b.columns.id.label("b_id")]
                ).select_from(
                    static.a.join(static.b, onclause=sqlalchemy.sql.literal(True))
                ).where(
                    sqlalchemy.sql.and_(
                        static.a.columns.name == "a1",
                        static.b.columns.value <= 12,
                    )
                )
            )
            # Check that the inserted rows are present.
            self.assertCountEqual(
                [{"a_name": "a1", "b_id": bId} for bId in bIds[:2]],
                [dict(row) for row in newDatabase.query(table1.select())]
            )
            # Create another one via a read-only connection to the database.
            # We _do_ allow temporary table modifications in read-only
            # databases.
            with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
                with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                    context.addTableTuple(STATIC_TABLE_SPECS)
                with existingReadOnlyDatabase.session() as session2:
                    table2 = session2.makeTemporaryTable(TEMPORARY_TABLE_SPEC)
                    self.checkTable(TEMPORARY_TABLE_SPEC, table2)
                    # Those tables should not be the same, despite having the
                    # same ddl.
                    self.assertIsNot(table1, table2)
                    # Do a slightly different insert into this table, to check
                    # that it works in a read-only database.  This time we
                    # pass column names as a kwarg to insert instead of by
                    # labeling the columns in the select.
                    existingReadOnlyDatabase.insert(
                        table2,
                        select=sqlalchemy.sql.select(
                            [static.a.columns.name, static.b.columns.id]
                        ).select_from(
                            static.a.join(static.b, onclause=sqlalchemy.sql.literal(True))
                        ).where(
                            sqlalchemy.sql.and_(
                                static.a.columns.name == "a2",
                                static.b.columns.value >= 12,
                            )
                        ),
                        names=["a_name", "b_id"],
                    )
                    # Check that the inserted rows are present.
                    self.assertCountEqual(
                        [{"a_name": "a2", "b_id": bId} for bId in bIds[1:]],
                        [dict(row) for row in existingReadOnlyDatabase.query(table2.select())]
                    )
                    # Drop the temporary table from the read-only DB.  It's
                    # unspecified whether attempting to use it after this
                    # point is an error or just never returns any results, so
                    # we can't test what it does, only that it's not an error.
                    session2.dropTemporaryTable(table2)
            # Drop the original temporary table.
            session.dropTemporaryTable(table1)
309 def testSchemaSeparation(self):
310 """Test that creating two different `Database` instances allows us
311 to create different tables with the same name in each.
312 """
313 db1 = self.makeEmptyDatabase(origin=1)
314 with db1.declareStaticTables(create=True) as context:
315 tables = context.addTableTuple(STATIC_TABLE_SPECS)
316 self.checkStaticSchema(tables)
318 db2 = self.makeEmptyDatabase(origin=2)
319 # Make the DDL here intentionally different so we'll definitely
320 # notice if db1 and db2 are pointing at the same schema.
321 spec = ddl.TableSpec(fields=[ddl.FieldSpec("id", dtype=sqlalchemy.Integer, primaryKey=True)])
322 with db2.declareStaticTables(create=True) as context:
323 # Make the DDL here intentionally different so we'll definitely
324 # notice if db1 and db2 are pointing at the same schema.
325 table = context.addTable("a", spec)
326 self.checkTable(spec, table)
    def testInsertQueryDelete(self):
        """Test the `Database.insert`, `Database.query`, and `Database.delete`
        methods, as well as the `Base64Region` type and the ``onDelete``
        argument to `ddl.ForeignKeySpec`.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert a single, non-autoincrement row that contains a region and
        # query to get it back.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row = {"name": "a1", "region": region}
        db.insert(tables.a, row)
        self.assertEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row])
        # Insert multiple autoincrement rows but do not try to get the IDs
        # back immediately.
        db.insert(tables.b, {"name": "b1", "value": 10}, {"name": "b2", "value": 20})
        results = [dict(r) for r in db.query(tables.b.select().order_by("id")).fetchall()]
        self.assertEqual(len(results), 2)
        for row in results:
            self.assertIn(row["name"], ("b1", "b2"))
            self.assertIsInstance(row["id"], int)
        self.assertGreater(results[1]["id"], results[0]["id"])
        # Insert multiple autoincrement rows and get the IDs back from insert.
        rows = [{"name": "b3", "value": 30}, {"name": "b4", "value": 40}]
        ids = db.insert(tables.b, *rows, returnIds=True)
        results = [
            dict(r) for r in db.query(
                tables.b.select().where(tables.b.columns.id > results[1]["id"])
            ).fetchall()
        ]
        expected = [dict(row, id=id) for row, id in zip(rows, ids)]
        self.assertCountEqual(results, expected)
        self.assertTrue(all(result["id"] is not None for result in results))
        # Insert multiple rows into a table with an autoincrement+origin
        # primary key, then use the returned IDs to insert into a dynamic
        # table.
        rows = [{"origin": db.origin, "b_id": results[0]["id"]},
                {"origin": db.origin, "b_id": None}]
        ids = db.insert(tables.c, *rows, returnIds=True)
        results = [dict(r) for r in db.query(tables.c.select()).fetchall()]
        expected = [dict(row, id=id) for row, id in zip(rows, ids)]
        self.assertCountEqual(results, expected)
        self.assertTrue(all(result["id"] is not None for result in results))
        # Add the dynamic table.
        d = db.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        # Insert into it.
        rows = [{"c_origin": db.origin, "c_id": id, "a_name": "a1"} for id in ids]
        db.insert(d, *rows)
        results = [dict(r) for r in db.query(d.select()).fetchall()]
        self.assertCountEqual(rows, results)
        # Insert multiple rows into a table with an autoincrement+origin
        # primary key (this is especially tricky for SQLite, but good to test
        # for all DBs), but pass in a value for the autoincrement key.
        # For extra complexity, we re-use the autoincrement value with a
        # different value for origin.
        rows2 = [{"id": 700, "origin": db.origin, "b_id": None},
                 {"id": 700, "origin": 60, "b_id": None},
                 {"id": 1, "origin": 60, "b_id": None}]
        db.insert(tables.c, *rows2)
        results = [dict(r) for r in db.query(tables.c.select()).fetchall()]
        self.assertCountEqual(results, expected + rows2)
        self.assertTrue(all(result["id"] is not None for result in results))

        # Define 'SELECT COUNT(*)' query for later use.
        count = sqlalchemy.sql.select([sqlalchemy.sql.func.count()])
        # Get the values we inserted into table b.
        bValues = [dict(r) for r in db.query(tables.b.select()).fetchall()]
        # Remove two rows from table b by ID.
        n = db.delete(tables.b, ["id"], {"id": bValues[0]["id"]}, {"id": bValues[1]["id"]})
        self.assertEqual(n, 2)
        # Remove the other two rows from table b by name.
        n = db.delete(tables.b, ["name"], {"name": bValues[2]["name"]}, {"name": bValues[3]["name"]})
        self.assertEqual(n, 2)
        # There should now be no rows in table b.
        self.assertEqual(
            db.query(count.select_from(tables.b)).scalar(),
            0
        )
        # All b_id values in table c should now be NULL, because there's an
        # onDelete='SET NULL' foreign key.
        self.assertEqual(
            db.query(count.select_from(tables.c).where(tables.c.columns.b_id != None)).scalar(),  # noqa:E711
            0
        )
        # Remove all rows in table a (there's only one); this should remove all
        # rows in d due to onDelete='CASCADE'.
        n = db.delete(tables.a, [])
        self.assertEqual(n, 1)
        self.assertEqual(db.query(count.select_from(tables.a)).scalar(), 0)
        self.assertEqual(db.query(count.select_from(d)).scalar(), 0)
420 def testUpdate(self):
421 """Tests for `Database.update`.
422 """
423 db = self.makeEmptyDatabase(origin=1)
424 with db.declareStaticTables(create=True) as context:
425 tables = context.addTableTuple(STATIC_TABLE_SPECS)
426 # Insert two rows into table a, both without regions.
427 db.insert(tables.a, {"name": "a1"}, {"name": "a2"})
428 # Update one of the rows with a region.
429 region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
430 n = db.update(tables.a, {"name": "k"}, {"k": "a2", "region": region})
431 self.assertEqual(n, 1)
432 sql = sqlalchemy.sql.select([tables.a.columns.name, tables.a.columns.region]).select_from(tables.a)
433 self.assertCountEqual(
434 [dict(r) for r in db.query(sql).fetchall()],
435 [{"name": "a1", "region": None}, {"name": "a2", "region": region}]
436 )
    def testSync(self):
        """Tests for `Database.sync`.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert a row with sync, because it doesn't exist yet.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10}, returning=["id"])
        self.assertTrue(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [dict(r) for r in db.query(tables.b.select()).fetchall()])
        # Repeat that operation, which should do nothing but return the
        # requested values.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10}, returning=["id"])
        self.assertFalse(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [dict(r) for r in db.query(tables.b.select()).fetchall()])
        # Repeat the operation without the 'extra' arg, which should also just
        # return the existing row.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, returning=["id"])
        self.assertFalse(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [dict(r) for r in db.query(tables.b.select()).fetchall()])
        # Repeat the operation with a different value in 'extra'.  That still
        # shouldn't be an error, because 'extra' is only used if we really do
        # insert.  Also drop the 'returning' argument.
        _, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 20})
        self.assertFalse(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [dict(r) for r in db.query(tables.b.select()).fetchall()])
        # Repeat the operation with the correct value in 'compared' instead of
        # 'extra'.
        _, inserted = db.sync(tables.b, keys={"name": "b1"}, compared={"value": 10})
        self.assertFalse(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [dict(r) for r in db.query(tables.b.select()).fetchall()])
        # Repeat the operation with an incorrect value in 'compared'; this
        # should raise.
        with self.assertRaises(DatabaseConflictError):
            db.sync(tables.b, keys={"name": "b1"}, compared={"value": 20})
        # Try to sync in a read-only database.  This should work if and only
        # if the matching row already exists.
        with self.asReadOnly(db) as rodb:
            with rodb.declareStaticTables(create=False) as context:
                tables = context.addTableTuple(STATIC_TABLE_SPECS)
            _, inserted = rodb.sync(tables.b, keys={"name": "b1"})
            self.assertFalse(inserted)
            self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                             [dict(r) for r in rodb.query(tables.b.select()).fetchall()])
            # A sync that would have to insert must fail in a read-only DB.
            with self.assertRaises(ReadOnlyDatabaseError):
                rodb.sync(tables.b, keys={"name": "b2"}, extra={"value": 20})
    def testReplace(self):
        """Tests for `Database.replace` (insert-or-overwrite semantics).
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Use 'replace' to insert a single row that contains a region and
        # query to get it back.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row1 = {"name": "a1", "region": region}
        db.replace(tables.a, row1)
        self.assertEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1])
        # Insert another row without a region.
        row2 = {"name": "a2", "region": None}
        db.replace(tables.a, row2)
        self.assertCountEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1, row2])
        # Use replace to re-insert both of those rows again, which should do
        # nothing.
        db.replace(tables.a, row1, row2)
        self.assertCountEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1, row2])
        # Replace row1 with a row with no region, while reinserting row2.
        row1a = {"name": "a1", "region": None}
        db.replace(tables.a, row1a, row2)
        self.assertCountEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1a, row2])
        # Replace both rows, returning row1 to its original state, while adding
        # a new one.  Pass them in in a different order.
        row2a = {"name": "a2", "region": region}
        row3 = {"name": "a3", "region": None}
        db.replace(tables.a, row3, row2a, row1)
        self.assertCountEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1, row2a, row3])
    def testEnsure(self):
        """Tests for `Database.ensure` (insert-if-absent semantics; unlike
        `replace`, existing rows are never modified).
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Use 'ensure' to insert a single row that contains a region and
        # query to get it back.  The return value is the number of rows
        # actually inserted.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row1 = {"name": "a1", "region": region}
        self.assertEqual(db.ensure(tables.a, row1), 1)
        self.assertEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1])
        # Insert another row without a region.
        row2 = {"name": "a2", "region": None}
        self.assertEqual(db.ensure(tables.a, row2), 1)
        self.assertCountEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1, row2])
        # Use ensure to re-insert both of those rows again, which should do
        # nothing.
        self.assertEqual(db.ensure(tables.a, row1, row2), 0)
        self.assertCountEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1, row2])
        # Attempt to insert row1's key with no region, while
        # reinserting row2.  This should also do nothing.
        row1a = {"name": "a1", "region": None}
        self.assertEqual(db.ensure(tables.a, row1a, row2), 0)
        self.assertCountEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1, row2])
        # Attempt to insert new rows for both existing keys, this time also
        # adding a new row.  Pass them in in a different order.  Only the new
        # row should be added.
        row2a = {"name": "a2", "region": region}
        row3 = {"name": "a3", "region": None}
        self.assertEqual(db.ensure(tables.a, row3, row2a, row1a), 1)
        self.assertCountEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1, row2, row3])
    def testTransactionNesting(self):
        """Test that transactions can be nested with the behavior in the
        presence of exceptions working as documented.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert one row so we can trigger integrity errors by trying to insert
        # a duplicate of it below.
        db.insert(tables.a, {"name": "a1"})
        # First test: error recovery via explicit savepoint=True in the inner
        # transaction.
        with db.transaction():
            # This insert should succeed, and should not be rolled back because
            # the assertRaises context should catch any exception before it
            # propagates up to the outer transaction.
            db.insert(tables.a, {"name": "a2"})
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                with db.transaction(savepoint=True):
                    # This insert should succeed, but should be rolled back.
                    db.insert(tables.a, {"name": "a4"})
                    # This insert should fail (duplicate primary key), raising
                    # an exception.
                    db.insert(tables.a, {"name": "a1"})
        # Only a1 (pre-existing) and a2 (outer transaction) survive.
        self.assertCountEqual(
            [dict(r) for r in db.query(tables.a.select()).fetchall()],
            [{"name": "a1", "region": None}, {"name": "a2", "region": None}],
        )
        # Second test: error recovery via implicit savepoint=True, when the
        # innermost transaction is inside a savepoint=True transaction.
        with db.transaction():
            # This insert should succeed, and should not be rolled back
            # because the assertRaises context should catch any
            # exception before it propagates up to the outer
            # transaction.
            db.insert(tables.a, {"name": "a3"})
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                with db.transaction(savepoint=True):
                    # This insert should succeed, but should be rolled back.
                    db.insert(tables.a, {"name": "a4"})
                    with db.transaction():
                        # This insert should succeed, but should be rolled
                        # back.
                        db.insert(tables.a, {"name": "a5"})
                        # This insert should fail (duplicate primary key),
                        # raising an exception.
                        db.insert(tables.a, {"name": "a1"})
        # Only a3 is added relative to the previous check.
        self.assertCountEqual(
            [dict(r) for r in db.query(tables.a.select()).fetchall()],
            [{"name": "a1", "region": None}, {"name": "a2", "region": None}, {"name": "a3", "region": None}],
        )
    def testTransactionLocking(self):
        """Test that `Database.transaction` can be used to acquire a lock
        that prohibits concurrent writes.
        """
        db1 = self.makeEmptyDatabase(origin=1)
        with db1.declareStaticTables(create=True) as context:
            tables1 = context.addTableTuple(STATIC_TABLE_SPECS)

        async def side1(lock: Iterable[str] = ()) -> Tuple[Set[str], Set[str]]:
            """One side of the concurrent locking test.

            This optionally locks the table (and maybe the whole database),
            does a select for its contents, inserts a new row, and then selects
            again, with some waiting in between to make sure the other side has
            a chance to _attempt_ to insert in between.  If the locking is
            enabled and works, the difference between the selects should just
            be the insert done on this thread.
            """
            # Give Side2 a chance to create a connection
            await asyncio.sleep(1.0)
            with db1.transaction(lock=lock):
                names1 = {row["name"] for row in db1.query(tables1.a.select()).fetchall()}
                # Give Side2 a chance to insert (which will be blocked if
                # we've acquired a lock).
                await asyncio.sleep(2.0)
                db1.insert(tables1.a, {"name": "a1"})
                names2 = {row["name"] for row in db1.query(tables1.a.select()).fetchall()}
            return names1, names2

        async def side2() -> None:
            """The other side of the concurrent locking test.

            This side just waits a bit and then tries to insert a row into the
            table that the other side is trying to lock.  Hopefully that
            waiting is enough to give the other side a chance to acquire the
            lock and thus make this side block until the lock is released.  If
            this side manages to do the insert before side1 acquires the lock,
            we'll just warn about not succeeding at testing the locking,
            because we can only make that unlikely, not impossible.
            """
            def toRunInThread():
                """SQLite locking isn't asyncio-friendly unless we actually
                run it in another thread.  And SQLite gets very unhappy if
                we try to use a connection from multiple threads, so we have
                to create the new connection here instead of out in the main
                body of the test function.
                """
                db2 = self.getNewConnection(db1, writeable=True)
                with db2.declareStaticTables(create=False) as context:
                    tables2 = context.addTableTuple(STATIC_TABLE_SPECS)
                with db2.transaction():
                    db2.insert(tables2.a, {"name": "a2"})

            await asyncio.sleep(2.0)
            loop = asyncio.get_running_loop()
            with ThreadPoolExecutor() as pool:
                await loop.run_in_executor(pool, toRunInThread)

        async def testProblemsWithNoLocking() -> None:
            """Run side1 and side2 with no locking, attempting to demonstrate
            the problem that locking is supposed to solve.  If we get unlucky
            with scheduling, side2 will just happen to insert after side1 is
            done, and we won't have anything definitive.  We just warn in that
            case because we really don't want spurious test failures.
            """
            task1 = asyncio.create_task(side1())
            task2 = asyncio.create_task(side2())
            names1, names2 = await task1
            await task2
            if "a2" in names1:
                warnings.warn("Unlucky scheduling in no-locking test: concurrent INSERT "
                              "happened before first SELECT.")
                self.assertEqual(names1, {"a2"})
                self.assertEqual(names2, {"a1", "a2"})
            elif "a2" not in names2:
                warnings.warn("Unlucky scheduling in no-locking test: concurrent INSERT "
                              "happened after second SELECT even without locking.")
                self.assertEqual(names1, set())
                self.assertEqual(names2, {"a1"})
            else:
                # This is the expected case: both INSERTS happen between the
                # two SELECTS.  If we don't get this almost all of the time we
                # should adjust the sleep amounts.
                self.assertEqual(names1, set())
                self.assertEqual(names2, {"a1", "a2"})

        asyncio.run(testProblemsWithNoLocking())
        # Clean up after first test.
        db1.delete(tables1.a, ["name"], {"name": "a1"}, {"name": "a2"})

        async def testSolutionWithLocking() -> None:
            """Run side1 and side2 with locking, which should make side2 block
            its insert until side1 releases its lock.
            """
            task1 = asyncio.create_task(side1(lock=[tables1.a]))
            task2 = asyncio.create_task(side2())
            names1, names2 = await task1
            await task2
            if "a2" in names1:
                warnings.warn("Unlucky scheduling in locking test: concurrent INSERT "
                              "happened before first SELECT.")
                self.assertEqual(names1, {"a2"})
                self.assertEqual(names2, {"a1", "a2"})
            else:
                # This is the expected case: the side2 INSERT happens after the
                # last SELECT on side1.  This can also happen due to unlucky
                # scheduling, and we have no way to detect that here, but the
                # similar "no-locking" test has at least some chance of being
                # affected by the same problem and warning about it.
                self.assertEqual(names1, set())
                self.assertEqual(names2, {"a1"})

        asyncio.run(testSolutionWithLocking())
    def testTimespanDatabaseRepresentation(self) -> None:
        """Tests for `TimespanDatabaseRepresentation` and the `Database`
        methods that interact with it.

        Covers round-tripping timespans (including NULL values and server-side
        defaults) through tables, exclusion-constraint enforcement where the
        backend supports it, and the SQL translation of the topological
        relationship operators (overlaps/contains/</>) against both literal
        and in-database operands.
        """
        # Make some test timespans to play with, with the full suite of
        # topological relationships.
        start = astropy.time.Time('2020-01-01T00:00:00', format="isot", scale="tai")
        offset = astropy.time.TimeDelta(60, format="sec")
        timestamps = [start + offset*n for n in range(3)]
        # The 'a' set: fully unbounded, half-bounded, instantaneous, empty,
        # and every bounded combination of the three timestamps.
        aTimespans = [Timespan(begin=None, end=None)]
        aTimespans.extend(Timespan(begin=None, end=t) for t in timestamps)
        aTimespans.extend(Timespan(begin=t, end=None) for t in timestamps)
        aTimespans.extend(Timespan.fromInstant(t) for t in timestamps)
        aTimespans.append(Timespan.makeEmpty())
        aTimespans.extend(Timespan(begin=t1, end=t2) for t1, t2 in itertools.combinations(timestamps, 2))
        # Make another list of timespans that span the full range but don't
        # overlap.  This is a subset of the previous list.
        bTimespans = [Timespan(begin=None, end=timestamps[0])]
        bTimespans.extend(Timespan(begin=t1, end=t2) for t1, t2 in zip(timestamps[:-1], timestamps[1:]))
        bTimespans.append(Timespan(begin=timestamps[-1], end=None))
        # Make a database and create a table with that database's timespan
        # representation.  This one will have no exclusion constraint and
        # a nullable timespan.
        db = self.makeEmptyDatabase(origin=1)
        TimespanReprClass = db.getTimespanRepresentation()
        aSpec = ddl.TableSpec(
            fields=[
                ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
            ],
        )
        # A timespan is stored as one or more backend-specific columns; the
        # representation class owns their names and field specs.
        for fieldSpec in TimespanReprClass.makeFieldSpecs(nullable=True):
            aSpec.fields.add(fieldSpec)
        with db.declareStaticTables(create=True) as context:
            aTable = context.addTable("a", aSpec)
        # Show full diffs on assertion failure; the row lists below are long.
        self.maxDiff = None

        def convertRowForInsert(row: dict) -> dict:
            """Convert a row containing a Timespan instance into one suitable
            for insertion into the database.

            Expands the Timespan stored under the ``TimespanReprClass.NAME``
            key into the representation's column value(s); all other keys
            pass through unchanged.  The input dict is not modified.
            """
            result = row.copy()
            ts = result.pop(TimespanReprClass.NAME)
            return TimespanReprClass.update(ts, result=result)

        def convertRowFromSelect(row: dict) -> dict:
            """Convert a row from the database into one containing a Timespan.

            Inverse of `convertRowForInsert`: collapses the representation's
            column value(s) back into a single Timespan entry.
            """
            result = row.copy()
            timespan = TimespanReprClass.extract(result)
            for name in TimespanReprClass.getFieldNames():
                del result[name]
            result[TimespanReprClass.NAME] = timespan
            return result

        # Insert rows into table A, in chunks just to make things interesting.
        # Include one with a NULL timespan.
        aRows = [{"id": n, TimespanReprClass.NAME: t} for n, t in enumerate(aTimespans)]
        aRows.append({"id": len(aRows), TimespanReprClass.NAME: None})
        db.insert(aTable, convertRowForInsert(aRows[0]))
        db.insert(aTable, *[convertRowForInsert(r) for r in aRows[1:3]])
        db.insert(aTable, *[convertRowForInsert(r) for r in aRows[3:]])
        # Add another one with a NULL timespan, but this time by invoking
        # the server-side default.
        aRows.append({"id": len(aRows)})
        db.insert(aTable, aRows[-1])
        # Patch the local copy to match what the database should now hold.
        aRows[-1][TimespanReprClass.NAME] = None
        # Test basic round-trip through database.
        self.assertEqual(
            aRows,
            [convertRowFromSelect(dict(row))
             for row in db.query(aTable.select().order_by(aTable.columns.id)).fetchall()]
        )
        # Create another table B with a not-null timespan and (if the database
        # supports it), an exclusion constraint.  Use ensureTableExists this
        # time to check that mode of table creation vs. timespans.
        bSpec = ddl.TableSpec(
            fields=[
                ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
                ddl.FieldSpec(name="key", dtype=sqlalchemy.Integer, nullable=False),
            ],
        )
        for fieldSpec in TimespanReprClass.makeFieldSpecs(nullable=False):
            bSpec.fields.add(fieldSpec)
        if TimespanReprClass.hasExclusionConstraint():
            # Forbid overlapping timespans among rows with the same 'key'.
            bSpec.exclusion.add(("key", TimespanReprClass))
        bTable = db.ensureTableExists("b", bSpec)
        # Insert rows into table B, again in chunks.  Each Timespan appears
        # twice, but with different values for the 'key' field (which should
        # still be okay for any exclusion constraint we may have defined).
        bRows = [{"id": n, "key": 1, TimespanReprClass.NAME: t} for n, t in enumerate(bTimespans)]
        # NOTE: 'offset' is reused here (it was a TimeDelta above); from this
        # point on it is the integer id offset for the second batch of rows.
        offset = len(bRows)
        bRows.extend({"id": n + offset, "key": 2, TimespanReprClass.NAME: t}
                     for n, t in enumerate(bTimespans))
        db.insert(bTable, *[convertRowForInsert(r) for r in bRows[:2]])
        db.insert(bTable, convertRowForInsert(bRows[2]))
        db.insert(bTable, *[convertRowForInsert(r) for r in bRows[3:]])
        # Insert a row with no timespan into table B.  This should invoke the
        # server-side default, which is a timespan over (-∞, ∞).  We set
        # key=3 to avoid upsetting an exclusion constraint that might exist.
        bRows.append({"id": len(bRows), "key": 3})
        db.insert(bTable, bRows[-1])
        bRows[-1][TimespanReprClass.NAME] = Timespan(None, None)
        # Test basic round-trip through database.
        self.assertEqual(
            bRows,
            [convertRowFromSelect(dict(row))
             for row in db.query(bTable.select().order_by(bTable.columns.id)).fetchall()]
        )
        # Test that we can't insert timespan=None into this table.
        with self.assertRaises(sqlalchemy.exc.IntegrityError):
            db.insert(
                bTable,
                convertRowForInsert({"id": len(bRows), "key": 4, TimespanReprClass.NAME: None})
            )
        # IFF this database supports exclusion constraints, test that they
        # also prevent inserts.
        if TimespanReprClass.hasExclusionConstraint():
            # Each of these literals overlaps at least one existing key=1 row
            # (bTimespans tile the whole timeline for key=1), so every insert
            # must violate the constraint.
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                db.insert(
                    bTable,
                    convertRowForInsert({
                        "id": len(bRows), "key": 1,
                        TimespanReprClass.NAME: Timespan(None, timestamps[1])
                    })
                )
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                db.insert(
                    bTable,
                    convertRowForInsert({
                        "id": len(bRows), "key": 1,
                        TimespanReprClass.NAME: Timespan(timestamps[0], timestamps[2])
                    })
                )
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                db.insert(
                    bTable,
                    convertRowForInsert({
                        "id": len(bRows), "key": 1,
                        TimespanReprClass.NAME: Timespan(timestamps[2], None)
                    })
                )
        # Test NULL checks in SELECT queries, on both tables.
        # NOTE(review): the list-argument form of sqlalchemy.sql.select() and
        # string-keyed row access used below are SQLAlchemy 1.x idioms; they
        # would need updating for SQLAlchemy 2.0.
        aRepr = TimespanReprClass.fromSelectable(aTable)
        self.assertEqual(
            [row[TimespanReprClass.NAME] is None for row in aRows],
            [
                row["f"] for row in db.query(
                    sqlalchemy.sql.select(
                        [aRepr.isNull().label("f")]
                    ).order_by(
                        aTable.columns.id
                    )
                ).fetchall()
            ]
        )
        # Table B's timespan columns are NOT NULL, so isNull() must be False
        # for every row.
        bRepr = TimespanReprClass.fromSelectable(bTable)
        self.assertEqual(
            [False for row in bRows],
            [
                row["f"] for row in db.query(
                    sqlalchemy.sql.select(
                        [bRepr.isNull().label("f")]
                    ).order_by(
                        bTable.columns.id
                    )
                ).fetchall()
            ]
        )
        # Test relationships expressions that relate in-database timespans to
        # Python-literal timespans, all from the more complete 'a' set; check
        # that this is consistent with Python-only relationship tests.
        for rhsRow in aRows:
            if rhsRow[TimespanReprClass.NAME] is None:
                continue
            with self.subTest(rhsRow=rhsRow):
                # Compute the expected answers in pure Python first; a NULL
                # timespan on the left is expected to yield SQL NULL (None)
                # for all four operators.
                expected = {}
                for lhsRow in aRows:
                    if lhsRow[TimespanReprClass.NAME] is None:
                        expected[lhsRow["id"]] = (None, None, None, None)
                    else:
                        expected[lhsRow["id"]] = (
                            lhsRow[TimespanReprClass.NAME].overlaps(rhsRow[TimespanReprClass.NAME]),
                            lhsRow[TimespanReprClass.NAME].contains(rhsRow[TimespanReprClass.NAME]),
                            lhsRow[TimespanReprClass.NAME] < rhsRow[TimespanReprClass.NAME],
                            lhsRow[TimespanReprClass.NAME] > rhsRow[TimespanReprClass.NAME],
                        )
                rhsRepr = TimespanReprClass.fromLiteral(rhsRow[TimespanReprClass.NAME])
                sql = sqlalchemy.sql.select([
                    aTable.columns.id.label("lhs"),
                    aRepr.overlaps(rhsRepr).label("overlaps"),
                    aRepr.contains(rhsRepr).label("contains"),
                    (aRepr < rhsRepr).label("less_than"),
                    (aRepr > rhsRepr).label("greater_than"),
                ]).select_from(aTable)
                queried = {
                    row["lhs"]: (row["overlaps"], row["contains"], row["less_than"], row["greater_than"])
                    for row in db.query(sql).fetchall()
                }
                self.assertEqual(expected, queried)
        # Test relationship expressions that relate in-database timespans to
        # each other, all from the more complete 'a' set; check that this is
        # consistent with Python-only relationship tests.
        expected = {}
        for lhs, rhs in itertools.product(aRows, aRows):
            lhsT = lhs[TimespanReprClass.NAME]
            rhsT = rhs[TimespanReprClass.NAME]
            if lhsT is not None and rhsT is not None:
                expected[lhs["id"], rhs["id"]] = (
                    lhsT.overlaps(rhsT),
                    lhsT.contains(rhsT),
                    lhsT < rhsT,
                    lhsT > rhsT
                )
            else:
                expected[lhs["id"], rhs["id"]] = (None, None, None, None)
        # Self-join table A on a trivially-true ON clause to produce the full
        # cross product of rows.
        lhsSubquery = aTable.alias("lhs")
        rhsSubquery = aTable.alias("rhs")
        lhsRepr = TimespanReprClass.fromSelectable(lhsSubquery)
        rhsRepr = TimespanReprClass.fromSelectable(rhsSubquery)
        sql = sqlalchemy.sql.select(
            [
                lhsSubquery.columns.id.label("lhs"),
                rhsSubquery.columns.id.label("rhs"),
                lhsRepr.overlaps(rhsRepr).label("overlaps"),
                lhsRepr.contains(rhsRepr).label("contains"),
                (lhsRepr < rhsRepr).label("less_than"),
                (lhsRepr > rhsRepr).label("greater_than"),
            ]
        ).select_from(
            lhsSubquery.join(rhsSubquery, onclause=sqlalchemy.sql.literal(True))
        )
        queried = {
            (row["lhs"], row["rhs"]): (row["overlaps"], row["contains"],
                                       row["less_than"], row["greater_than"])
            for row in db.query(sql).fetchall()}
        self.assertEqual(expected, queried)
        # Test relationship expressions between in-database timespans and
        # Python-literal instantaneous times.
        for t in timestamps:
            with self.subTest(t=t):
                # Only three operators here: overlaps() is not defined for a
                # bare time on the right-hand side.
                expected = {}
                for lhsRow in aRows:
                    if lhsRow[TimespanReprClass.NAME] is None:
                        expected[lhsRow["id"]] = (None, None, None)
                    else:
                        expected[lhsRow["id"]] = (
                            lhsRow[TimespanReprClass.NAME].contains(t),
                            lhsRow[TimespanReprClass.NAME] < t,
                            lhsRow[TimespanReprClass.NAME] > t,
                        )
                # Bind the astropy time as a SQL literal using the same column
                # type the database uses to store time values.
                rhs = sqlalchemy.sql.literal(t, type_=ddl.AstropyTimeNsecTai)
                sql = sqlalchemy.sql.select([
                    aTable.columns.id.label("lhs"),
                    aRepr.contains(rhs).label("contains"),
                    (aRepr < rhs).label("less_than"),
                    (aRepr > rhs).label("greater_than"),
                ]).select_from(aTable)
                queried = {
                    row["lhs"]: (row["contains"], row["less_than"], row["greater_than"])
                    for row in db.query(sql).fetchall()
                }
                self.assertEqual(expected, queried)