Coverage for python/lsst/daf/butler/registry/tests/_database.py: 7%
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ["DatabaseTests"]
25import asyncio
26import itertools
27import warnings
28from abc import ABC, abstractmethod
29from collections import namedtuple
30from concurrent.futures import ThreadPoolExecutor
31from typing import ContextManager, Iterable, Set, Tuple
33import astropy.time
34import sqlalchemy
35from lsst.sphgeom import ConvexPolygon, UnitVector3d
37from ...core import Timespan, ddl
38from ..interfaces import Database, DatabaseConflictError, ReadOnlyDatabaseError, SchemaAlreadyDefinedError
# Lightweight container for the three static-schema table specs (and, after
# declaration, the corresponding SQLAlchemy tables); the field names
# ("a", "b", "c") double as the table names in the database.
StaticTablesTuple = namedtuple("StaticTablesTuple", ["a", "b", "c"])

STATIC_TABLE_SPECS = StaticTablesTuple(
    # Table "a": string primary key plus a nullable base64-encoded region.
    a=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("name", dtype=sqlalchemy.String, length=16, primaryKey=True),
            ddl.FieldSpec("region", dtype=ddl.Base64Region, nbytes=128),
        ]
    ),
    # Table "b": autoincrement primary key, with "name" additionally unique.
    b=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("id", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True),
            ddl.FieldSpec("name", dtype=sqlalchemy.String, length=16, nullable=False),
            ddl.FieldSpec("value", dtype=sqlalchemy.SmallInteger, nullable=True),
        ],
        unique=[("name",)],
    ),
    # Table "c": compound (autoincrement "id", "origin") primary key, and a
    # nullable FK into "b" that is set to NULL when the "b" row is deleted.
    c=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("id", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True),
            ddl.FieldSpec("origin", dtype=sqlalchemy.BigInteger, primaryKey=True),
            ddl.FieldSpec("b_id", dtype=sqlalchemy.BigInteger, nullable=True),
        ],
        foreignKeys=[
            ddl.ForeignKeySpec("b", source=("b_id",), target=("id",), onDelete="SET NULL"),
        ],
    ),
)
# Spec for the table created at runtime (via `Database.ensureTableExists`)
# in several tests; it references both static tables "c" (compound key) and
# "a", with CASCADE deletion so removing parents empties this table too.
DYNAMIC_TABLE_SPEC = ddl.TableSpec(
    fields=[
        ddl.FieldSpec("c_id", dtype=sqlalchemy.BigInteger, primaryKey=True),
        ddl.FieldSpec("c_origin", dtype=sqlalchemy.BigInteger, primaryKey=True),
        ddl.FieldSpec("a_name", dtype=sqlalchemy.String, length=16, nullable=False),
    ],
    foreignKeys=[
        ddl.ForeignKeySpec("c", source=("c_id", "c_origin"), target=("id", "origin"), onDelete="CASCADE"),
        ddl.ForeignKeySpec("a", source=("a_name",), target=("name",), onDelete="CASCADE"),
    ],
)
# Spec used by the temporary-table tests; deliberately has no foreign keys
# so it can be created even in a read-only database session.
TEMPORARY_TABLE_SPEC = ddl.TableSpec(
    fields=[
        ddl.FieldSpec("a_name", dtype=sqlalchemy.String, length=16, primaryKey=True),
        ddl.FieldSpec("b_id", dtype=sqlalchemy.BigInteger, primaryKey=True),
    ],
)
89class DatabaseTests(ABC):
90 """Generic tests for the `Database` interface that can be subclassed to
91 generate tests for concrete implementations.
92 """
    @abstractmethod
    def makeEmptyDatabase(self, origin: int = 0) -> Database:
        """Return an empty `Database` with the given origin, or an
        automatically-generated one if ``origin`` is `None`.

        NOTE(review): the signature declares ``origin: int = 0`` while the
        docstring mentions `None`; confirm which contract subclasses are
        expected to honor.
        """
        raise NotImplementedError()
    @abstractmethod
    def asReadOnly(self, database: Database) -> ContextManager[Database]:
        """Return a context manager for a read-only connection into the given
        database.

        The original database should be considered unusable within the context
        but safe to use again afterwards (this allows the context manager to
        block write access by temporarily changing user permissions to really
        guarantee that write operations are not performed).
        """
        raise NotImplementedError()
    @abstractmethod
    def getNewConnection(self, database: Database, *, writeable: bool) -> Database:
        """Return a new `Database` instance that points to the same underlying
        storage as the given one.

        Used by the concurrency tests, which need two independent connections
        to the same database.
        """
        raise NotImplementedError()
    def checkTable(self, spec: ddl.TableSpec, table: sqlalchemy.schema.Table):
        """Check that a SQLAlchemy table has the same columns (by name) as
        the spec it was declared from.
        """
        self.assertCountEqual(spec.fields.names, table.columns.keys())
        # Checking more than this currently seems fragile, as it might restrict
        # what Database implementations do; we don't care if the spec is
        # actually preserved in terms of types and constraints as long as we
        # can use the returned table as if it was.
127 def checkStaticSchema(self, tables: StaticTablesTuple):
128 self.checkTable(STATIC_TABLE_SPECS.a, tables.a)
129 self.checkTable(STATIC_TABLE_SPECS.b, tables.b)
130 self.checkTable(STATIC_TABLE_SPECS.c, tables.c)
    def testDeclareStaticTables(self):
        """Tests for `Database.declareStaticSchema` and the methods it
        delegates to.
        """
        # Create the static schema in a new, empty database.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
            self.checkStaticSchema(tables)
        # Check that we can load that schema even from a read-only connection.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                tables = context.addTableTuple(STATIC_TABLE_SPECS)
                self.checkStaticSchema(tables)
    def testDeclareStaticTablesTwice(self):
        """Tests for `Database.declareStaticSchema` being called twice."""
        # Create the static schema in a new, empty database.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
            self.checkStaticSchema(tables)
        # Second time it should raise
        with self.assertRaises(SchemaAlreadyDefinedError):
            with newDatabase.declareStaticTables(create=True) as context:
                tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Check schema, it should still contain all tables, and maybe some
        # extra.  NOTE: this peeks at the private ``_tableNames`` attribute
        # of the declaration context.
        with newDatabase.declareStaticTables(create=False) as context:
            self.assertLessEqual(frozenset(STATIC_TABLE_SPECS._fields), context._tableNames)
163 def testRepr(self):
164 """Test that repr does not return a generic thing."""
165 newDatabase = self.makeEmptyDatabase()
166 rep = repr(newDatabase)
167 # Check that stringification works and gives us something different
168 self.assertNotEqual(rep, str(newDatabase))
169 self.assertNotIn("object at 0x", rep, "Check default repr was not used")
170 self.assertIn("://", rep)
    def testDynamicTables(self):
        """Tests for `Database.ensureTableExists` and
        `Database.getExistingTable`.
        """
        # Need to start with the static schema.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            context.addTableTuple(STATIC_TABLE_SPECS)
        # Try to ensure the dynamic table exists in a read-only version of that
        # database, which should fail because we can't create it.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                context.addTableTuple(STATIC_TABLE_SPECS)
            with self.assertRaises(ReadOnlyDatabaseError):
                existingReadOnlyDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        # Just getting the dynamic table before it exists should return None.
        self.assertIsNone(newDatabase.getExistingTable("d", DYNAMIC_TABLE_SPEC))
        # Ensure the new table exists back in the original database, which
        # should create it.
        table = newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        self.checkTable(DYNAMIC_TABLE_SPEC, table)
        # Ensuring that it exists should just return the exact same table
        # instance again.
        self.assertIs(newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC), table)
        # Try again from the read-only database.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                context.addTableTuple(STATIC_TABLE_SPECS)
            # Just getting the dynamic table should now work...
            self.assertIsNotNone(existingReadOnlyDatabase.getExistingTable("d", DYNAMIC_TABLE_SPEC))
            # ...as should ensuring that it exists, since it now does.
            existingReadOnlyDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
            self.checkTable(DYNAMIC_TABLE_SPEC, table)
        # Trying to get the table with a different specification (at least
        # in terms of what columns are present) should raise.
        with self.assertRaises(DatabaseConflictError):
            newDatabase.ensureTableExists(
                "d",
                ddl.TableSpec(
                    fields=[ddl.FieldSpec("name", dtype=sqlalchemy.String, length=4, primaryKey=True)]
                ),
            )
        # Calling ensureTableExists inside a transaction block is an error,
        # even if it would do nothing.
        with newDatabase.transaction():
            with self.assertRaises(AssertionError):
                newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
    def testTemporaryTables(self):
        """Tests for `Database.makeTemporaryTable`,
        `Database.dropTemporaryTable`, and `Database.insert` with
        the ``select`` argument.
        """
        # Need to start with the static schema; also insert some test data.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            static = context.addTableTuple(STATIC_TABLE_SPECS)
        newDatabase.insert(static.a, {"name": "a1", "region": None}, {"name": "a2", "region": None})
        bIds = newDatabase.insert(
            static.b,
            {"name": "b1", "value": 11},
            {"name": "b2", "value": 12},
            {"name": "b3", "value": 13},
            returnIds=True,
        )
        # Create the table.
        with newDatabase.session() as session:
            table1 = session.makeTemporaryTable(TEMPORARY_TABLE_SPEC, "e1")
            self.checkTable(TEMPORARY_TABLE_SPEC, table1)
            # Insert via a INSERT INTO ... SELECT query.  The cross join
            # (ON TRUE) pairs every "a" row with every "b" row before the
            # WHERE clause filters to (a1, b1) and (a1, b2).
            newDatabase.insert(
                table1,
                select=sqlalchemy.sql.select(
                    static.a.columns.name.label("a_name"), static.b.columns.id.label("b_id")
                )
                .select_from(static.a.join(static.b, onclause=sqlalchemy.sql.literal(True)))
                .where(
                    sqlalchemy.sql.and_(
                        static.a.columns.name == "a1",
                        static.b.columns.value <= 12,
                    )
                ),
            )
            # Check that the inserted rows are present.
            self.assertCountEqual(
                [{"a_name": "a1", "b_id": bId} for bId in bIds[:2]],
                [row._asdict() for row in newDatabase.query(table1.select())],
            )
            # Create another one via a read-only connection to the database.
            # We _do_ allow temporary table modifications in read-only
            # databases.
            with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
                with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                    context.addTableTuple(STATIC_TABLE_SPECS)
                with existingReadOnlyDatabase.session() as session2:
                    table2 = session2.makeTemporaryTable(TEMPORARY_TABLE_SPEC)
                    self.checkTable(TEMPORARY_TABLE_SPEC, table2)
                    # Those tables should not be the same, despite having the
                    # same ddl.
                    self.assertIsNot(table1, table2)
                    # Do a slightly different insert into this table, to check
                    # that it works in a read-only database.  This time we
                    # pass column names as a kwarg to insert instead of by
                    # labeling the columns in the select.
                    existingReadOnlyDatabase.insert(
                        table2,
                        select=sqlalchemy.sql.select(static.a.columns.name, static.b.columns.id)
                        .select_from(static.a.join(static.b, onclause=sqlalchemy.sql.literal(True)))
                        .where(
                            sqlalchemy.sql.and_(
                                static.a.columns.name == "a2",
                                static.b.columns.value >= 12,
                            )
                        ),
                        names=["a_name", "b_id"],
                    )
                    # Check that the inserted rows are present.
                    self.assertCountEqual(
                        [{"a_name": "a2", "b_id": bId} for bId in bIds[1:]],
                        [row._asdict() for row in existingReadOnlyDatabase.query(table2.select())],
                    )
                    # Drop the temporary table from the read-only DB.  It's
                    # unspecified whether attempting to use it after this
                    # point is an error or just never returns any results, so
                    # we can't test what it does, only that it's not an error.
                    session2.dropTemporaryTable(table2)
            # Drop the original temporary table.
            session.dropTemporaryTable(table1)
    def testSchemaSeparation(self):
        """Test that creating two different `Database` instances allows us
        to create different tables with the same name in each.
        """
        db1 = self.makeEmptyDatabase(origin=1)
        with db1.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
            self.checkStaticSchema(tables)

        db2 = self.makeEmptyDatabase(origin=2)
        # Make the DDL here intentionally different so we'll definitely
        # notice if db1 and db2 are pointing at the same schema.
        spec = ddl.TableSpec(fields=[ddl.FieldSpec("id", dtype=sqlalchemy.Integer, primaryKey=True)])
        with db2.declareStaticTables(create=True) as context:
            table = context.addTable("a", spec)
        self.checkTable(spec, table)
    def testInsertQueryDelete(self):
        """Test the `Database.insert`, `Database.query`, and `Database.delete`
        methods, as well as the `Base64Region` type and the ``onDelete``
        argument to `ddl.ForeignKeySpec`.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert a single, non-autoincrement row that contains a region and
        # query to get it back.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row = {"name": "a1", "region": region}
        db.insert(tables.a, row)
        self.assertEqual([r._asdict() for r in db.query(tables.a.select())], [row])
        # Insert multiple autoincrement rows but do not try to get the IDs
        # back immediately.
        db.insert(tables.b, {"name": "b1", "value": 10}, {"name": "b2", "value": 20})
        results = [r._asdict() for r in db.query(tables.b.select().order_by("id"))]
        self.assertEqual(len(results), 2)
        for row in results:
            self.assertIn(row["name"], ("b1", "b2"))
            self.assertIsInstance(row["id"], int)
        self.assertGreater(results[1]["id"], results[0]["id"])
        # Insert multiple autoincrement rows and get the IDs back from insert.
        rows = [{"name": "b3", "value": 30}, {"name": "b4", "value": 40}]
        ids = db.insert(tables.b, *rows, returnIds=True)
        # Select only the rows just added (ids greater than the previous max).
        results = [
            r._asdict() for r in db.query(tables.b.select().where(tables.b.columns.id > results[1]["id"]))
        ]
        expected = [dict(row, id=id) for row, id in zip(rows, ids)]
        self.assertCountEqual(results, expected)
        self.assertTrue(all(result["id"] is not None for result in results))
        # Insert multiple rows into a table with an autoincrement+origin
        # primary key, then use the returned IDs to insert into a dynamic
        # table.
        rows = [{"origin": db.origin, "b_id": results[0]["id"]}, {"origin": db.origin, "b_id": None}]
        ids = db.insert(tables.c, *rows, returnIds=True)
        results = [r._asdict() for r in db.query(tables.c.select())]
        expected = [dict(row, id=id) for row, id in zip(rows, ids)]
        self.assertCountEqual(results, expected)
        self.assertTrue(all(result["id"] is not None for result in results))
        # Add the dynamic table.
        d = db.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        # Insert into it.
        rows = [{"c_origin": db.origin, "c_id": id, "a_name": "a1"} for id in ids]
        db.insert(d, *rows)
        results = [r._asdict() for r in db.query(d.select())]
        self.assertCountEqual(rows, results)
        # Insert multiple rows into a table with an autoincrement+origin
        # primary key (this is especially tricky for SQLite, but good to test
        # for all DBs), but pass in a value for the autoincrement key.
        # For extra complexity, we re-use the autoincrement value with a
        # different value for origin.
        rows2 = [
            {"id": 700, "origin": db.origin, "b_id": None},
            {"id": 700, "origin": 60, "b_id": None},
            {"id": 1, "origin": 60, "b_id": None},
        ]
        db.insert(tables.c, *rows2)
        results = [r._asdict() for r in db.query(tables.c.select())]
        self.assertCountEqual(results, expected + rows2)
        self.assertTrue(all(result["id"] is not None for result in results))

        # Define 'SELECT COUNT(*)' query for later use.
        count = sqlalchemy.sql.select(sqlalchemy.sql.func.count())
        # Get the values we inserted into table b.
        bValues = [r._asdict() for r in db.query(tables.b.select())]
        # Remove two rows from table b by ID.
        n = db.delete(tables.b, ["id"], {"id": bValues[0]["id"]}, {"id": bValues[1]["id"]})
        self.assertEqual(n, 2)
        # Remove the other two rows from table b by name.
        n = db.delete(tables.b, ["name"], {"name": bValues[2]["name"]}, {"name": bValues[3]["name"]})
        self.assertEqual(n, 2)
        # There should now be no rows in table b.
        self.assertEqual(db.query(count.select_from(tables.b)).scalar(), 0)
        # All b_id values in table c should now be NULL, because there's an
        # onDelete='SET NULL' foreign key.
        self.assertEqual(
            db.query(count.select_from(tables.c).where(tables.c.columns.b_id != None)).scalar(),  # noqa:E711
            0,
        )
        # Remove all rows in table a (there's only one); this should remove all
        # rows in d due to onDelete='CASCADE'.
        n = db.delete(tables.a, [])
        self.assertEqual(n, 1)
        self.assertEqual(db.query(count.select_from(tables.a)).scalar(), 0)
        self.assertEqual(db.query(count.select_from(d)).scalar(), 0)
    def testDeleteWhere(self):
        """Tests for `Database.deleteWhere`."""
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Seed ten rows with explicit ids 0..9.
        db.insert(tables.b, *[{"id": i, "name": f"b{i}"} for i in range(10)])
        count = sqlalchemy.sql.select(sqlalchemy.sql.func.count())

        # Delete by an IN clause over a literal list.
        n = db.deleteWhere(tables.b, tables.b.columns.id.in_([0, 1, 2]))
        self.assertEqual(n, 3)
        self.assertEqual(db.query(count.select_from(tables.b)).scalar(), 7)

        # Delete by an IN clause over a subquery (ids 6..9).
        n = db.deleteWhere(
            tables.b,
            tables.b.columns.id.in_(
                sqlalchemy.sql.select(tables.b.columns.id).where(tables.b.columns.id > 5)
            ),
        )
        self.assertEqual(n, 4)
        self.assertEqual(db.query(count.select_from(tables.b)).scalar(), 3)

        # Delete by an equality condition on a non-key column.
        n = db.deleteWhere(tables.b, tables.b.columns.name == "b5")
        self.assertEqual(n, 1)
        self.assertEqual(db.query(count.select_from(tables.b)).scalar(), 2)

        # A literal TRUE condition deletes everything that remains.
        n = db.deleteWhere(tables.b, sqlalchemy.sql.literal(True))
        self.assertEqual(n, 2)
        self.assertEqual(db.query(count.select_from(tables.b)).scalar(), 0)
437 def testUpdate(self):
438 """Tests for `Database.update`."""
439 db = self.makeEmptyDatabase(origin=1)
440 with db.declareStaticTables(create=True) as context:
441 tables = context.addTableTuple(STATIC_TABLE_SPECS)
442 # Insert two rows into table a, both without regions.
443 db.insert(tables.a, {"name": "a1"}, {"name": "a2"})
444 # Update one of the rows with a region.
445 region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
446 n = db.update(tables.a, {"name": "k"}, {"k": "a2", "region": region})
447 self.assertEqual(n, 1)
448 sql = sqlalchemy.sql.select(tables.a.columns.name, tables.a.columns.region).select_from(tables.a)
449 self.assertCountEqual(
450 [r._asdict() for r in db.query(sql)],
451 [{"name": "a1", "region": None}, {"name": "a2", "region": region}],
452 )
    def testSync(self):
        """Tests for `Database.sync`."""
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert a row with sync, because it doesn't exist yet.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10}, returning=["id"])
        self.assertTrue(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in db.query(tables.b.select())],
        )
        # Repeat that operation, which should do nothing but return the
        # requested values.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10}, returning=["id"])
        self.assertFalse(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in db.query(tables.b.select())],
        )
        # Repeat the operation without the 'extra' arg, which should also just
        # return the existing row.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, returning=["id"])
        self.assertFalse(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in db.query(tables.b.select())],
        )
        # Repeat the operation with a different value in 'extra'.  That still
        # shouldn't be an error, because 'extra' is only used if we really do
        # insert.  Also drop the 'returning' argument.
        _, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 20})
        self.assertFalse(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in db.query(tables.b.select())],
        )
        # Repeat the operation with the correct value in 'compared' instead of
        # 'extra'.
        _, inserted = db.sync(tables.b, keys={"name": "b1"}, compared={"value": 10})
        self.assertFalse(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in db.query(tables.b.select())],
        )
        # Repeat the operation with an incorrect value in 'compared'; this
        # should raise.
        with self.assertRaises(DatabaseConflictError):
            db.sync(tables.b, keys={"name": "b1"}, compared={"value": 20})
        # Try to sync in a read-only database.  This should work if and only
        # if the matching row already exists.
        with self.asReadOnly(db) as rodb:
            with rodb.declareStaticTables(create=False) as context:
                tables = context.addTableTuple(STATIC_TABLE_SPECS)
            _, inserted = rodb.sync(tables.b, keys={"name": "b1"})
            self.assertFalse(inserted)
            self.assertEqual(
                [{"id": values["id"], "name": "b1", "value": 10}],
                [r._asdict() for r in rodb.query(tables.b.select())],
            )
            with self.assertRaises(ReadOnlyDatabaseError):
                rodb.sync(tables.b, keys={"name": "b2"}, extra={"value": 20})
        # Repeat the operation with a different value in 'compared' and ask to
        # update; the returned dict holds the *previous* value of each
        # updated column.
        _, updated = db.sync(tables.b, keys={"name": "b1"}, compared={"value": 20}, update=True)
        self.assertEqual(updated, {"value": 10})
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 20}],
            [r._asdict() for r in db.query(tables.b.select())],
        )
    def testReplace(self):
        """Tests for `Database.replace`."""
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Use 'replace' to insert a single row that contains a region and
        # query to get it back.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row1 = {"name": "a1", "region": region}
        db.replace(tables.a, row1)
        self.assertEqual([r._asdict() for r in db.query(tables.a.select())], [row1])
        # Insert another row without a region.
        row2 = {"name": "a2", "region": None}
        db.replace(tables.a, row2)
        self.assertCountEqual([r._asdict() for r in db.query(tables.a.select())], [row1, row2])
        # Use replace to re-insert both of those rows again, which should do
        # nothing.
        db.replace(tables.a, row1, row2)
        self.assertCountEqual([r._asdict() for r in db.query(tables.a.select())], [row1, row2])
        # Replace row1 with a row with no region, while reinserting row2.
        row1a = {"name": "a1", "region": None}
        db.replace(tables.a, row1a, row2)
        self.assertCountEqual([r._asdict() for r in db.query(tables.a.select())], [row1a, row2])
        # Replace both rows, returning row1 to its original state, while adding
        # a new one.  Pass them in in a different order.
        row2a = {"name": "a2", "region": region}
        row3 = {"name": "a3", "region": None}
        db.replace(tables.a, row3, row2a, row1)
        self.assertCountEqual([r._asdict() for r in db.query(tables.a.select())], [row1, row2a, row3])
    def testEnsure(self):
        """Tests for `Database.ensure` (insert-if-not-exists; returns the
        number of rows actually inserted)."""
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Use 'ensure' to insert a single row that contains a region and
        # query to get it back.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row1 = {"name": "a1", "region": region}
        self.assertEqual(db.ensure(tables.a, row1), 1)
        self.assertEqual([r._asdict() for r in db.query(tables.a.select())], [row1])
        # Insert another row without a region.
        row2 = {"name": "a2", "region": None}
        self.assertEqual(db.ensure(tables.a, row2), 1)
        self.assertCountEqual([r._asdict() for r in db.query(tables.a.select())], [row1, row2])
        # Use ensure to re-insert both of those rows again, which should do
        # nothing.
        self.assertEqual(db.ensure(tables.a, row1, row2), 0)
        self.assertCountEqual([r._asdict() for r in db.query(tables.a.select())], [row1, row2])
        # Attempt to insert row1's key with no region, while
        # reinserting row2.  This should also do nothing (unlike 'replace',
        # existing rows are never modified).
        row1a = {"name": "a1", "region": None}
        self.assertEqual(db.ensure(tables.a, row1a, row2), 0)
        self.assertCountEqual([r._asdict() for r in db.query(tables.a.select())], [row1, row2])
        # Attempt to insert new rows for both existing keys, this time also
        # adding a new row.  Pass them in in a different order.  Only the new
        # row should be added.
        row2a = {"name": "a2", "region": region}
        row3 = {"name": "a3", "region": None}
        self.assertEqual(db.ensure(tables.a, row3, row2a, row1a), 1)
        self.assertCountEqual([r._asdict() for r in db.query(tables.a.select())], [row1, row2, row3])
    def testTransactionNesting(self):
        """Test that transactions can be nested with the behavior in the
        presence of exceptions working as documented.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert one row so we can trigger integrity errors by trying to insert
        # a duplicate of it below.
        db.insert(tables.a, {"name": "a1"})
        # First test: error recovery via explicit savepoint=True in the inner
        # transaction.
        with db.transaction():
            # This insert should succeed, and should not be rolled back because
            # the assertRaises context should catch any exception before it
            # propagates up to the outer transaction.
            db.insert(tables.a, {"name": "a2"})
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                with db.transaction(savepoint=True):
                    # This insert should succeed, but should be rolled back.
                    db.insert(tables.a, {"name": "a4"})
                    # This insert should fail (duplicate primary key), raising
                    # an exception.
                    db.insert(tables.a, {"name": "a1"})
        self.assertCountEqual(
            [r._asdict() for r in db.query(tables.a.select())],
            [{"name": "a1", "region": None}, {"name": "a2", "region": None}],
        )
        # Second test: error recovery via implicit savepoint=True, when the
        # innermost transaction is inside a savepoint=True transaction.
        with db.transaction():
            # This insert should succeed, and should not be rolled back
            # because the assertRaises context should catch any
            # exception before it propagates up to the outer
            # transaction.
            db.insert(tables.a, {"name": "a3"})
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                with db.transaction(savepoint=True):
                    # This insert should succeed, but should be rolled back.
                    db.insert(tables.a, {"name": "a4"})
                    with db.transaction():
                        # This insert should succeed, but should be rolled
                        # back.
                        db.insert(tables.a, {"name": "a5"})
                        # This insert should fail (duplicate primary key),
                        # raising an exception.
                        db.insert(tables.a, {"name": "a1"})
        self.assertCountEqual(
            [r._asdict() for r in db.query(tables.a.select())],
            [{"name": "a1", "region": None}, {"name": "a2", "region": None}, {"name": "a3", "region": None}],
        )
639 def testTransactionLocking(self):
640 """Test that `Database.transaction` can be used to acquire a lock
641 that prohibits concurrent writes.
642 """
643 db1 = self.makeEmptyDatabase(origin=1)
644 with db1.declareStaticTables(create=True) as context:
645 tables1 = context.addTableTuple(STATIC_TABLE_SPECS)
647 async def side1(lock: Iterable[str] = ()) -> Tuple[Set[str], Set[str]]:
648 """One side of the concurrent locking test.
650 This optionally locks the table (and maybe the whole database),
651 does a select for its contents, inserts a new row, and then selects
652 again, with some waiting in between to make sure the other side has
653 a chance to _attempt_ to insert in between. If the locking is
654 enabled and works, the difference between the selects should just
655 be the insert done on this thread.
656 """
657 # Give Side2 a chance to create a connection
658 await asyncio.sleep(1.0)
659 with db1.transaction(lock=lock):
660 names1 = {row.name for row in db1.query(tables1.a.select())}
661 # Give Side2 a chance to insert (which will be blocked if
662 # we've acquired a lock).
663 await asyncio.sleep(2.0)
664 db1.insert(tables1.a, {"name": "a1"})
665 names2 = {row.name for row in db1.query(tables1.a.select())}
666 return names1, names2
668 async def side2() -> None:
669 """The other side of the concurrent locking test.
671 This side just waits a bit and then tries to insert a row into the
672 table that the other side is trying to lock. Hopefully that
673 waiting is enough to give the other side a chance to acquire the
674 lock and thus make this side block until the lock is released. If
675 this side manages to do the insert before side1 acquires the lock,
676 we'll just warn about not succeeding at testing the locking,
677 because we can only make that unlikely, not impossible.
678 """
680 def toRunInThread():
681 """SQLite locking isn't asyncio-friendly unless we actually
682 run it in another thread. And SQLite gets very unhappy if
683 we try to use a connection from multiple threads, so we have
684 to create the new connection here instead of out in the main
685 body of the test function.
686 """
687 db2 = self.getNewConnection(db1, writeable=True)
688 with db2.declareStaticTables(create=False) as context:
689 tables2 = context.addTableTuple(STATIC_TABLE_SPECS)
690 with db2.transaction():
691 db2.insert(tables2.a, {"name": "a2"})
693 await asyncio.sleep(2.0)
694 loop = asyncio.get_running_loop()
695 with ThreadPoolExecutor() as pool:
696 await loop.run_in_executor(pool, toRunInThread)
        async def testProblemsWithNoLocking() -> None:
            """Run side1 and side2 with no locking, attempting to demonstrate
            the problem that locking is supposed to solve.  If we get unlucky
            with scheduling, side2 will just happen to insert after side1 is
            done, and we won't have anything definitive.  We just warn in that
            case because we really don't want spurious test failures.
            """
            task1 = asyncio.create_task(side1())
            task2 = asyncio.create_task(side2())
            # side1 returns the set of names it saw before and after its own
            # insert; side2 returns nothing (it only inserts "a2").
            names1, names2 = await task1
            await task2
            if "a2" in names1:
                # side2's INSERT ran before side1's first SELECT: scheduling
                # was too unlucky to demonstrate anything, so only warn.
                warnings.warn(
                    "Unlucky scheduling in no-locking test: concurrent INSERT "
                    "happened before first SELECT."
                )
                self.assertEqual(names1, {"a2"})
                self.assertEqual(names2, {"a1", "a2"})
            elif "a2" not in names2:
                # side2's INSERT ran after side1's second SELECT even though
                # nothing blocked it; again inconclusive, so only warn.
                warnings.warn(
                    "Unlucky scheduling in no-locking test: concurrent INSERT "
                    "happened after second SELECT even without locking."
                )
                self.assertEqual(names1, set())
                self.assertEqual(names2, {"a1"})
            else:
                # This is the expected case: both INSERTS happen between the
                # two SELECTS.  If we don't get this almost all of the time we
                # should adjust the sleep amounts.
                self.assertEqual(names1, set())
                self.assertEqual(names2, {"a1", "a2"})

        asyncio.run(testProblemsWithNoLocking())

        # Clean up after first test: remove both rows so the locking test
        # below starts from an empty table.
        db1.delete(tables1.a, ["name"], {"name": "a1"}, {"name": "a2"})
        async def testSolutionWithLocking() -> None:
            """Run side1 and side2 with locking, which should make side2 block
            its insert until side1 releases its lock.
            """
            # Same as the no-locking run, but side1 now locks table 'a' for
            # the duration of its transaction.
            task1 = asyncio.create_task(side1(lock=[tables1.a]))
            task2 = asyncio.create_task(side2())
            names1, names2 = await task1
            await task2
            if "a2" in names1:
                # side2's INSERT somehow ran before side1 even acquired the
                # lock; inconclusive, so only warn.
                warnings.warn(
                    "Unlucky scheduling in locking test: concurrent INSERT happened before first SELECT."
                )
                self.assertEqual(names1, {"a2"})
                self.assertEqual(names2, {"a1", "a2"})
            else:
                # This is the expected case: the side2 INSERT happens after the
                # last SELECT on side1.  This can also happen due to unlucky
                # scheduling, and we have no way to detect that here, but the
                # similar "no-locking" test has at least some chance of being
                # affected by the same problem and warning about it.
                self.assertEqual(names1, set())
                self.assertEqual(names2, {"a1"})

        asyncio.run(testSolutionWithLocking())
    def testTimespanDatabaseRepresentation(self):
        """Tests for `TimespanDatabaseRepresentation` and the `Database`
        methods that interact with it.

        Exercises round-tripping timespans through tables (nullable and
        not-null, with server-side defaults and optional exclusion
        constraints) and the SQL translations of the overlap/containment/
        ordering operators, checking each against the pure-Python `Timespan`
        implementations of the same relationships.
        """
        # Make some test timespans to play with, with the full suite of
        # topological relationships.
        start = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
        offset = astropy.time.TimeDelta(60, format="sec")
        timestamps = [start + offset * n for n in range(3)]
        # 'a' timespans: unbounded, half-open, instantaneous, empty, and all
        # finite [t1, t2) combinations of the three timestamps.
        aTimespans = [Timespan(begin=None, end=None)]
        aTimespans.extend(Timespan(begin=None, end=t) for t in timestamps)
        aTimespans.extend(Timespan(begin=t, end=None) for t in timestamps)
        aTimespans.extend(Timespan.fromInstant(t) for t in timestamps)
        aTimespans.append(Timespan.makeEmpty())
        aTimespans.extend(Timespan(begin=t1, end=t2) for t1, t2 in itertools.combinations(timestamps, 2))
        # Make another list of timespans that span the full range but don't
        # overlap.  This is a subset of the previous list.
        bTimespans = [Timespan(begin=None, end=timestamps[0])]
        bTimespans.extend(Timespan(begin=t1, end=t2) for t1, t2 in zip(timestamps[:-1], timestamps[1:]))
        bTimespans.append(Timespan(begin=timestamps[-1], end=None))
        # Make a database and create a table with that database's timespan
        # representation.  This one will have no exclusion constraint and
        # a nullable timespan.
        db = self.makeEmptyDatabase(origin=1)
        TimespanReprClass = db.getTimespanRepresentation()
        aSpec = ddl.TableSpec(
            fields=[
                ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
            ],
        )
        for fieldSpec in TimespanReprClass.makeFieldSpecs(nullable=True):
            aSpec.fields.add(fieldSpec)
        with db.declareStaticTables(create=True) as context:
            aTable = context.addTable("a", aSpec)
        # Show full diffs on failure; the row-list comparisons below are long.
        self.maxDiff = None

        def convertRowForInsert(row: dict) -> dict:
            """Convert a row containing a Timespan instance into one suitable
            for insertion into the database.

            Parameters
            ----------
            row : `dict`
                Row with a `Timespan` (or `None`) stored under the key
                ``TimespanReprClass.NAME``.

            Returns
            -------
            row : `dict`
                Copy of ``row`` with the timespan replaced by the
                representation-specific database column value(s).
            """
            result = row.copy()
            ts = result.pop(TimespanReprClass.NAME)
            return TimespanReprClass.update(ts, result=result)

        def convertRowFromSelect(row: dict) -> dict:
            """Convert a row from the database into one containing a Timespan.

            Parameters
            ----------
            row : `dict`
                Original row.

            Returns
            -------
            row : `dict`
                The updated row.
            """
            result = row.copy()
            timespan = TimespanReprClass.extract(result)
            for name in TimespanReprClass.getFieldNames():
                del result[name]
            result[TimespanReprClass.NAME] = timespan
            return result

        # Insert rows into table A, in chunks just to make things interesting.
        # Include one with a NULL timespan.
        aRows = [{"id": n, TimespanReprClass.NAME: t} for n, t in enumerate(aTimespans)]
        aRows.append({"id": len(aRows), TimespanReprClass.NAME: None})
        db.insert(aTable, convertRowForInsert(aRows[0]))
        db.insert(aTable, *[convertRowForInsert(r) for r in aRows[1:3]])
        db.insert(aTable, *[convertRowForInsert(r) for r in aRows[3:]])
        # Add another one with a NULL timespan, but this time by invoking
        # the server-side default.
        aRows.append({"id": len(aRows)})
        db.insert(aTable, aRows[-1])
        # Patch the expected row to match what the default should produce.
        aRows[-1][TimespanReprClass.NAME] = None
        # Test basic round-trip through database.
        self.assertEqual(
            aRows,
            [
                convertRowFromSelect(row._asdict())
                for row in db.query(aTable.select().order_by(aTable.columns.id))
            ],
        )
        # Create another table B with a not-null timespan and (if the database
        # supports it), an exclusion constraint.  Use ensureTableExists this
        # time to check that mode of table creation vs. timespans.
        bSpec = ddl.TableSpec(
            fields=[
                ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
                ddl.FieldSpec(name="key", dtype=sqlalchemy.Integer, nullable=False),
            ],
        )
        for fieldSpec in TimespanReprClass.makeFieldSpecs(nullable=False):
            bSpec.fields.add(fieldSpec)
        if TimespanReprClass.hasExclusionConstraint():
            bSpec.exclusion.add(("key", TimespanReprClass))
        bTable = db.ensureTableExists("b", bSpec)
        # Insert rows into table B, again in chunks.  Each Timespan appears
        # twice, but with different values for the 'key' field (which should
        # still be okay for any exclusion constraint we may have defined).
        bRows = [{"id": n, "key": 1, TimespanReprClass.NAME: t} for n, t in enumerate(bTimespans)]
        offset = len(bRows)
        bRows.extend(
            {"id": n + offset, "key": 2, TimespanReprClass.NAME: t} for n, t in enumerate(bTimespans)
        )
        db.insert(bTable, *[convertRowForInsert(r) for r in bRows[:2]])
        db.insert(bTable, convertRowForInsert(bRows[2]))
        db.insert(bTable, *[convertRowForInsert(r) for r in bRows[3:]])
        # Insert a row with no timespan into table B.  This should invoke the
        # server-side default, which is a timespan over (-∞, ∞).  We set
        # key=3 to avoid upsetting an exclusion constraint that might exist.
        bRows.append({"id": len(bRows), "key": 3})
        db.insert(bTable, bRows[-1])
        bRows[-1][TimespanReprClass.NAME] = Timespan(None, None)
        # Test basic round-trip through database.
        self.assertEqual(
            bRows,
            [
                convertRowFromSelect(row._asdict())
                for row in db.query(bTable.select().order_by(bTable.columns.id))
            ],
        )
        # Test that we can't insert timespan=None into this table.
        with self.assertRaises(sqlalchemy.exc.IntegrityError):
            db.insert(bTable, convertRowForInsert({"id": len(bRows), "key": 4, TimespanReprClass.NAME: None}))
        # IFF this database supports exclusion constraints, test that they
        # also prevent inserts.
        if TimespanReprClass.hasExclusionConstraint():
            # Overlaps the existing key=1 half-open (-∞, t0) span.
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                db.insert(
                    bTable,
                    convertRowForInsert(
                        {"id": len(bRows), "key": 1, TimespanReprClass.NAME: Timespan(None, timestamps[1])}
                    ),
                )
            # Overlaps the existing key=1 [t0, t1) and [t1, t2) spans.
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                db.insert(
                    bTable,
                    convertRowForInsert(
                        {
                            "id": len(bRows),
                            "key": 1,
                            TimespanReprClass.NAME: Timespan(timestamps[0], timestamps[2]),
                        }
                    ),
                )
            # Overlaps the existing key=1 half-open [t2, ∞) span.
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                db.insert(
                    bTable,
                    convertRowForInsert(
                        {"id": len(bRows), "key": 1, TimespanReprClass.NAME: Timespan(timestamps[2], None)}
                    ),
                )
        # Test NULL checks in SELECT queries, on both tables.
        aRepr = TimespanReprClass.fromSelectable(aTable)
        self.assertEqual(
            [row[TimespanReprClass.NAME] is None for row in aRows],
            [
                row.f
                for row in db.query(
                    sqlalchemy.sql.select(aRepr.isNull().label("f")).order_by(aTable.columns.id)
                )
            ],
        )
        # Table B's timespan is NOT NULL, so isNull() must be False for all.
        bRepr = TimespanReprClass.fromSelectable(bTable)
        self.assertEqual(
            [False for row in bRows],
            [
                row.f
                for row in db.query(
                    sqlalchemy.sql.select(bRepr.isNull().label("f")).order_by(bTable.columns.id)
                )
            ],
        )
        # Test relationships expressions that relate in-database timespans to
        # Python-literal timespans, all from the more complete 'a' set; check
        # that this is consistent with Python-only relationship tests.
        for rhsRow in aRows:
            if rhsRow[TimespanReprClass.NAME] is None:
                continue
            with self.subTest(rhsRow=rhsRow):
                expected = {}
                for lhsRow in aRows:
                    if lhsRow[TimespanReprClass.NAME] is None:
                        # Comparisons involving a NULL timespan yield NULL.
                        expected[lhsRow["id"]] = (None, None, None, None)
                    else:
                        expected[lhsRow["id"]] = (
                            lhsRow[TimespanReprClass.NAME].overlaps(rhsRow[TimespanReprClass.NAME]),
                            lhsRow[TimespanReprClass.NAME].contains(rhsRow[TimespanReprClass.NAME]),
                            lhsRow[TimespanReprClass.NAME] < rhsRow[TimespanReprClass.NAME],
                            lhsRow[TimespanReprClass.NAME] > rhsRow[TimespanReprClass.NAME],
                        )
                rhsRepr = TimespanReprClass.fromLiteral(rhsRow[TimespanReprClass.NAME])
                sql = sqlalchemy.sql.select(
                    aTable.columns.id.label("lhs"),
                    aRepr.overlaps(rhsRepr).label("overlaps"),
                    aRepr.contains(rhsRepr).label("contains"),
                    (aRepr < rhsRepr).label("less_than"),
                    (aRepr > rhsRepr).label("greater_than"),
                ).select_from(aTable)
                queried = {
                    row.lhs: (row.overlaps, row.contains, row.less_than, row.greater_than)
                    for row in db.query(sql)
                }
                self.assertEqual(expected, queried)
        # Test relationship expressions that relate in-database timespans to
        # each other, all from the more complete 'a' set; check that this is
        # consistent with Python-only relationship tests.
        expected = {}
        for lhs, rhs in itertools.product(aRows, aRows):
            lhsT = lhs[TimespanReprClass.NAME]
            rhsT = rhs[TimespanReprClass.NAME]
            if lhsT is not None and rhsT is not None:
                expected[lhs["id"], rhs["id"]] = (
                    lhsT.overlaps(rhsT),
                    lhsT.contains(rhsT),
                    lhsT < rhsT,
                    lhsT > rhsT,
                )
            else:
                expected[lhs["id"], rhs["id"]] = (None, None, None, None)
        # Cross join the table with itself via a literal-TRUE onclause so
        # every (lhs, rhs) pair is compared.
        lhsSubquery = aTable.alias("lhs")
        rhsSubquery = aTable.alias("rhs")
        lhsRepr = TimespanReprClass.fromSelectable(lhsSubquery)
        rhsRepr = TimespanReprClass.fromSelectable(rhsSubquery)
        sql = sqlalchemy.sql.select(
            lhsSubquery.columns.id.label("lhs"),
            rhsSubquery.columns.id.label("rhs"),
            lhsRepr.overlaps(rhsRepr).label("overlaps"),
            lhsRepr.contains(rhsRepr).label("contains"),
            (lhsRepr < rhsRepr).label("less_than"),
            (lhsRepr > rhsRepr).label("greater_than"),
        ).select_from(lhsSubquery.join(rhsSubquery, onclause=sqlalchemy.sql.literal(True)))
        queried = {
            (row.lhs, row.rhs): (row.overlaps, row.contains, row.less_than, row.greater_than)
            for row in db.query(sql)
        }
        self.assertEqual(expected, queried)
        # Test relationship expressions between in-database timespans and
        # Python-literal instantaneous times.
        for t in timestamps:
            with self.subTest(t=t):
                expected = {}
                for lhsRow in aRows:
                    if lhsRow[TimespanReprClass.NAME] is None:
                        expected[lhsRow["id"]] = (None, None, None)
                    else:
                        expected[lhsRow["id"]] = (
                            lhsRow[TimespanReprClass.NAME].contains(t),
                            lhsRow[TimespanReprClass.NAME] < t,
                            lhsRow[TimespanReprClass.NAME] > t,
                        )
                # Bind the astropy Time as a literal with the ddl-defined
                # nanosecond-TAI column type so the comparison is well-typed.
                rhs = sqlalchemy.sql.literal(t, type_=ddl.AstropyTimeNsecTai)
                sql = sqlalchemy.sql.select(
                    aTable.columns.id.label("lhs"),
                    aRepr.contains(rhs).label("contains"),
                    (aRepr < rhs).label("less_than"),
                    (aRepr > rhs).label("greater_than"),
                ).select_from(aTable)
                queried = {row.lhs: (row.contains, row.less_than, row.greater_than) for row in db.query(sql)}
                self.assertEqual(expected, queried)