Coverage for python/lsst/daf/butler/registry/tests/_database.py: 7%
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ["DatabaseTests"]
25from abc import ABC, abstractmethod
26import asyncio
27from collections import namedtuple
28from concurrent.futures import ThreadPoolExecutor
29import itertools
30from typing import ContextManager, Iterable, Set, Tuple
31import warnings
33import astropy.time
34import sqlalchemy
36from lsst.sphgeom import ConvexPolygon, UnitVector3d
37from ..interfaces import (
38 Database,
39 ReadOnlyDatabaseError,
40 DatabaseConflictError,
41 SchemaAlreadyDefinedError
42)
43from ...core import ddl, Timespan
# Named container for the three static test tables so they can be declared
# together (via Database.declareStaticTables / addTableTuple) and passed
# around as a unit by the tests below.
StaticTablesTuple = namedtuple("StaticTablesTuple", ["a", "b", "c"])

# Static schema used by most tests:
#   a: string primary key plus a Base64Region column (exercises region
#      round-tripping).
#   b: autoincrement primary key with a unique "name" column.
#   c: compound autoincrement+origin primary key with a nullable foreign
#      key into b that is set NULL when the target row is deleted.
STATIC_TABLE_SPECS = StaticTablesTuple(
    a=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("name", dtype=sqlalchemy.String, length=16, primaryKey=True),
            ddl.FieldSpec("region", dtype=ddl.Base64Region, nbytes=128),
        ]
    ),
    b=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("id", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True),
            ddl.FieldSpec("name", dtype=sqlalchemy.String, length=16, nullable=False),
            ddl.FieldSpec("value", dtype=sqlalchemy.SmallInteger, nullable=True),
        ],
        unique=[("name",)],
    ),
    c=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("id", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True),
            ddl.FieldSpec("origin", dtype=sqlalchemy.BigInteger, primaryKey=True),
            ddl.FieldSpec("b_id", dtype=sqlalchemy.BigInteger, nullable=True),
        ],
        foreignKeys=[
            ddl.ForeignKeySpec("b", source=("b_id",), target=("id",), onDelete="SET NULL"),
        ]
    ),
)

# Spec for a table created after the static schema (via ensureTableExists);
# its CASCADE foreign keys into a and c are exercised by the delete tests.
DYNAMIC_TABLE_SPEC = ddl.TableSpec(
    fields=[
        ddl.FieldSpec("c_id", dtype=sqlalchemy.BigInteger, primaryKey=True),
        ddl.FieldSpec("c_origin", dtype=sqlalchemy.BigInteger, primaryKey=True),
        ddl.FieldSpec("a_name", dtype=sqlalchemy.String, length=16, nullable=False),
    ],
    foreignKeys=[
        ddl.ForeignKeySpec("c", source=("c_id", "c_origin"), target=("id", "origin"), onDelete="CASCADE"),
        ddl.ForeignKeySpec("a", source=("a_name",), target=("name",), onDelete="CASCADE"),
    ]
)

# Spec used by testTemporaryTables; intentionally has no foreign keys so it
# can also be created on a read-only connection.
TEMPORARY_TABLE_SPEC = ddl.TableSpec(
    fields=[
        ddl.FieldSpec("a_name", dtype=sqlalchemy.String, length=16, primaryKey=True),
        ddl.FieldSpec("b_id", dtype=sqlalchemy.BigInteger, primaryKey=True),
    ],
)
94class DatabaseTests(ABC):
95 """Generic tests for the `Database` interface that can be subclassed to
96 generate tests for concrete implementations.
97 """
    @abstractmethod
    def makeEmptyDatabase(self, origin: int = 0) -> Database:
        """Return an empty `Database` with the given origin, or an
        automatically-generated one if ``origin`` is `None`.

        Parameters
        ----------
        origin : `int`, optional
            Origin value for the new database.

        Notes
        -----
        NOTE(review): the summary mentions a `None` default, but the
        signature declares ``origin: int = 0`` — confirm which is intended.
        """
        raise NotImplementedError()
    @abstractmethod
    def asReadOnly(self, database: Database) -> ContextManager[Database]:
        """Return a context manager for a read-only connection into the given
        database.

        Parameters
        ----------
        database : `Database`
            Writeable database to expose through a read-only connection.

        Returns
        -------
        context : `ContextManager` [ `Database` ]
            Context manager yielding the read-only `Database`.

        Notes
        -----
        The original database should be considered unusable within the context
        but safe to use again afterwards (this allows the context manager to
        block write access by temporarily changing user permissions to really
        guarantee that write operations are not performed).
        """
        raise NotImplementedError()
    @abstractmethod
    def getNewConnection(self, database: Database, *, writeable: bool) -> Database:
        """Return a new `Database` instance that points to the same underlying
        storage as the given one.

        Parameters
        ----------
        database : `Database`
            Existing database whose storage the new connection should share.
        writeable : `bool`
            Whether the new connection may perform write operations.
        """
        raise NotImplementedError()
125 def checkTable(self, spec: ddl.TableSpec, table: sqlalchemy.schema.Table):
126 self.assertCountEqual(spec.fields.names, table.columns.keys())
127 # Checking more than this currently seems fragile, as it might restrict
128 # what Database implementations do; we don't care if the spec is
129 # actually preserved in terms of types and constraints as long as we
130 # can use the returned table as if it was.
132 def checkStaticSchema(self, tables: StaticTablesTuple):
133 self.checkTable(STATIC_TABLE_SPECS.a, tables.a)
134 self.checkTable(STATIC_TABLE_SPECS.b, tables.b)
135 self.checkTable(STATIC_TABLE_SPECS.c, tables.c)
137 def testDeclareStaticTables(self):
138 """Tests for `Database.declareStaticSchema` and the methods it
139 delegates to.
140 """
141 # Create the static schema in a new, empty database.
142 newDatabase = self.makeEmptyDatabase()
143 with newDatabase.declareStaticTables(create=True) as context:
144 tables = context.addTableTuple(STATIC_TABLE_SPECS)
145 self.checkStaticSchema(tables)
146 # Check that we can load that schema even from a read-only connection.
147 with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
148 with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
149 tables = context.addTableTuple(STATIC_TABLE_SPECS)
150 self.checkStaticSchema(tables)
152 def testDeclareStaticTablesTwice(self):
153 """Tests for `Database.declareStaticSchema` being called twice.
154 """
155 # Create the static schema in a new, empty database.
156 newDatabase = self.makeEmptyDatabase()
157 with newDatabase.declareStaticTables(create=True) as context:
158 tables = context.addTableTuple(STATIC_TABLE_SPECS)
159 self.checkStaticSchema(tables)
160 # Second time it should raise
161 with self.assertRaises(SchemaAlreadyDefinedError):
162 with newDatabase.declareStaticTables(create=True) as context:
163 tables = context.addTableTuple(STATIC_TABLE_SPECS)
164 # Check schema, it should still contain all tables, and maybe some
165 # extra.
166 with newDatabase.declareStaticTables(create=False) as context:
167 self.assertLessEqual(frozenset(STATIC_TABLE_SPECS._fields), context._tableNames)
169 def testRepr(self):
170 """Test that repr does not return a generic thing."""
171 newDatabase = self.makeEmptyDatabase()
172 rep = repr(newDatabase)
173 # Check that stringification works and gives us something different
174 self.assertNotEqual(rep, str(newDatabase))
175 self.assertNotIn("object at 0x", rep, "Check default repr was not used")
176 self.assertIn("://", rep)
    def testDynamicTables(self):
        """Tests for `Database.ensureTableExists` and
        `Database.getExistingTable`.
        """
        # Need to start with the static schema.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            context.addTableTuple(STATIC_TABLE_SPECS)
        # Try to ensure the dynamic table exists in a read-only version of that
        # database, which should fail because we can't create it.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                context.addTableTuple(STATIC_TABLE_SPECS)
            with self.assertRaises(ReadOnlyDatabaseError):
                existingReadOnlyDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        # Just getting the dynamic table before it exists should return None.
        self.assertIsNone(newDatabase.getExistingTable("d", DYNAMIC_TABLE_SPEC))
        # Ensure the new table exists back in the original database, which
        # should create it.
        table = newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        self.checkTable(DYNAMIC_TABLE_SPEC, table)
        # Ensuring that it exists should just return the exact same table
        # instance again (i.e. the result is cached, not re-reflected).
        self.assertIs(newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC), table)
        # Try again from the read-only database.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                context.addTableTuple(STATIC_TABLE_SPECS)
            # Just getting the dynamic table should now work...
            self.assertIsNotNone(existingReadOnlyDatabase.getExistingTable("d", DYNAMIC_TABLE_SPEC))
            # ...as should ensuring that it exists, since it now does.
            existingReadOnlyDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
            self.checkTable(DYNAMIC_TABLE_SPEC, table)
        # Trying to get the table with a different specification (at least
        # in terms of what columns are present) should raise.
        with self.assertRaises(DatabaseConflictError):
            newDatabase.ensureTableExists(
                "d",
                ddl.TableSpec(
                    fields=[ddl.FieldSpec("name", dtype=sqlalchemy.String, length=4, primaryKey=True)]
                )
            )
        # Calling ensureTableExists inside a transaction block is an error,
        # even if it would do nothing.
        with newDatabase.transaction():
            with self.assertRaises(AssertionError):
                newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
    def testTemporaryTables(self):
        """Tests for `Database.makeTemporaryTable`,
        `Database.dropTemporaryTable`, and `Database.insert` with
        the ``select`` argument.
        """
        # Need to start with the static schema; also insert some test data.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            static = context.addTableTuple(STATIC_TABLE_SPECS)
        newDatabase.insert(static.a,
                           {"name": "a1", "region": None},
                           {"name": "a2", "region": None})
        bIds = newDatabase.insert(static.b,
                                  {"name": "b1", "value": 11},
                                  {"name": "b2", "value": 12},
                                  {"name": "b3", "value": 13},
                                  returnIds=True)
        # Create the table.
        with newDatabase.session() as session:
            table1 = session.makeTemporaryTable(TEMPORARY_TABLE_SPEC, "e1")
            self.checkTable(TEMPORARY_TABLE_SPEC, table1)
            # Insert via a INSERT INTO ... SELECT query: a cross join of a and
            # b (literal-True ON clause), filtered to a1 rows paired with the
            # first two b rows.
            newDatabase.insert(
                table1,
                select=sqlalchemy.sql.select(
                    static.a.columns.name.label("a_name"), static.b.columns.id.label("b_id")
                ).select_from(
                    static.a.join(static.b, onclause=sqlalchemy.sql.literal(True))
                ).where(
                    sqlalchemy.sql.and_(
                        static.a.columns.name == "a1",
                        static.b.columns.value <= 12,
                    )
                )
            )
            # Check that the inserted rows are present.
            self.assertCountEqual(
                [{"a_name": "a1", "b_id": bId} for bId in bIds[:2]],
                [row._asdict() for row in newDatabase.query(table1.select())]
            )
            # Create another one via a read-only connection to the database.
            # We _do_ allow temporary table modifications in read-only
            # databases.
            with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
                with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                    context.addTableTuple(STATIC_TABLE_SPECS)
                with existingReadOnlyDatabase.session() as session2:
                    table2 = session2.makeTemporaryTable(TEMPORARY_TABLE_SPEC)
                    self.checkTable(TEMPORARY_TABLE_SPEC, table2)
                    # Those tables should not be the same, despite having the
                    # same ddl.
                    self.assertIsNot(table1, table2)
                    # Do a slightly different insert into this table, to check
                    # that it works in a read-only database.  This time we
                    # pass column names as a kwarg to insert instead of by
                    # labeling the columns in the select.
                    existingReadOnlyDatabase.insert(
                        table2,
                        select=sqlalchemy.sql.select(
                            static.a.columns.name, static.b.columns.id
                        ).select_from(
                            static.a.join(static.b, onclause=sqlalchemy.sql.literal(True))
                        ).where(
                            sqlalchemy.sql.and_(
                                static.a.columns.name == "a2",
                                static.b.columns.value >= 12,
                            )
                        ),
                        names=["a_name", "b_id"],
                    )
                    # Check that the inserted rows are present.
                    self.assertCountEqual(
                        [{"a_name": "a2", "b_id": bId} for bId in bIds[1:]],
                        [row._asdict() for row in existingReadOnlyDatabase.query(table2.select())]
                    )
                    # Drop the temporary table from the read-only DB.  It's
                    # unspecified whether attempting to use it after this
                    # point is an error or just never returns any results, so
                    # we can't test what it does, only that it's not an error.
                    session2.dropTemporaryTable(table2)
            # Drop the original temporary table.
            session.dropTemporaryTable(table1)
309 def testSchemaSeparation(self):
310 """Test that creating two different `Database` instances allows us
311 to create different tables with the same name in each.
312 """
313 db1 = self.makeEmptyDatabase(origin=1)
314 with db1.declareStaticTables(create=True) as context:
315 tables = context.addTableTuple(STATIC_TABLE_SPECS)
316 self.checkStaticSchema(tables)
318 db2 = self.makeEmptyDatabase(origin=2)
319 # Make the DDL here intentionally different so we'll definitely
320 # notice if db1 and db2 are pointing at the same schema.
321 spec = ddl.TableSpec(fields=[ddl.FieldSpec("id", dtype=sqlalchemy.Integer, primaryKey=True)])
322 with db2.declareStaticTables(create=True) as context:
323 # Make the DDL here intentionally different so we'll definitely
324 # notice if db1 and db2 are pointing at the same schema.
325 table = context.addTable("a", spec)
326 self.checkTable(spec, table)
    def testInsertQueryDelete(self):
        """Test the `Database.insert`, `Database.query`, and `Database.delete`
        methods, as well as the `Base64Region` type and the ``onDelete``
        argument to `ddl.ForeignKeySpec`.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert a single, non-autoincrement row that contains a region and
        # query to get it back (round-trips the Base64Region column type).
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row = {"name": "a1", "region": region}
        db.insert(tables.a, row)
        self.assertEqual([r._asdict() for r in db.query(tables.a.select())], [row])
        # Insert multiple autoincrement rows but do not try to get the IDs
        # back immediately.
        db.insert(tables.b, {"name": "b1", "value": 10}, {"name": "b2", "value": 20})
        results = [r._asdict() for r in db.query(tables.b.select().order_by("id"))]
        self.assertEqual(len(results), 2)
        for row in results:
            self.assertIn(row["name"], ("b1", "b2"))
            self.assertIsInstance(row["id"], int)
        self.assertGreater(results[1]["id"], results[0]["id"])
        # Insert multiple autoincrement rows and get the IDs back from insert.
        rows = [{"name": "b3", "value": 30}, {"name": "b4", "value": 40}]
        ids = db.insert(tables.b, *rows, returnIds=True)
        results = [
            r._asdict() for r in db.query(
                tables.b.select().where(tables.b.columns.id > results[1]["id"])
            )
        ]
        expected = [dict(row, id=id) for row, id in zip(rows, ids)]
        self.assertCountEqual(results, expected)
        self.assertTrue(all(result["id"] is not None for result in results))
        # Insert multiple rows into a table with an autoincrement+origin
        # primary key, then use the returned IDs to insert into a dynamic
        # table.
        rows = [{"origin": db.origin, "b_id": results[0]["id"]},
                {"origin": db.origin, "b_id": None}]
        ids = db.insert(tables.c, *rows, returnIds=True)
        results = [r._asdict() for r in db.query(tables.c.select())]
        expected = [dict(row, id=id) for row, id in zip(rows, ids)]
        self.assertCountEqual(results, expected)
        self.assertTrue(all(result["id"] is not None for result in results))
        # Add the dynamic table.
        d = db.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        # Insert into it.
        rows = [{"c_origin": db.origin, "c_id": id, "a_name": "a1"} for id in ids]
        db.insert(d, *rows)
        results = [r._asdict() for r in db.query(d.select())]
        self.assertCountEqual(rows, results)
        # Insert multiple rows into a table with an autoincrement+origin
        # primary key (this is especially tricky for SQLite, but good to test
        # for all DBs), but pass in a value for the autoincrement key.
        # For extra complexity, we re-use the autoincrement value with a
        # different value for origin.
        rows2 = [{"id": 700, "origin": db.origin, "b_id": None},
                 {"id": 700, "origin": 60, "b_id": None},
                 {"id": 1, "origin": 60, "b_id": None}]
        db.insert(tables.c, *rows2)
        results = [r._asdict() for r in db.query(tables.c.select())]
        self.assertCountEqual(results, expected + rows2)
        self.assertTrue(all(result["id"] is not None for result in results))
        # Define 'SELECT COUNT(*)' query for later use.
        count = sqlalchemy.sql.select(sqlalchemy.sql.func.count())
        # Get the values we inserted into table b.
        bValues = [r._asdict() for r in db.query(tables.b.select())]
        # Remove two rows from table b by ID.
        n = db.delete(tables.b, ["id"], {"id": bValues[0]["id"]}, {"id": bValues[1]["id"]})
        self.assertEqual(n, 2)
        # Remove the other two rows from table b by name.
        n = db.delete(tables.b, ["name"], {"name": bValues[2]["name"]}, {"name": bValues[3]["name"]})
        self.assertEqual(n, 2)
        # There should now be no rows in table b.
        self.assertEqual(
            db.query(count.select_from(tables.b)).scalar(),
            0
        )
        # All b_id values in table c should now be NULL, because there's an
        # onDelete='SET NULL' foreign key.  (SQLAlchemy requires "!= None"
        # rather than "is not None" to render "IS NOT NULL", hence the noqa.)
        self.assertEqual(
            db.query(count.select_from(tables.c).where(tables.c.columns.b_id != None)).scalar(),  # noqa:E711
            0
        )
        # Remove all rows in table a (there's only one); this should remove all
        # rows in d due to onDelete='CASCADE'.
        n = db.delete(tables.a, [])
        self.assertEqual(n, 1)
        self.assertEqual(db.query(count.select_from(tables.a)).scalar(), 0)
        self.assertEqual(db.query(count.select_from(d)).scalar(), 0)
420 def testDeleteWhere(self):
421 """Tests for `Database.deleteWhere`.
422 """
423 db = self.makeEmptyDatabase(origin=1)
424 with db.declareStaticTables(create=True) as context:
425 tables = context.addTableTuple(STATIC_TABLE_SPECS)
426 db.insert(tables.b, *[{"id": i, "name": f"b{i}"} for i in range(10)])
427 count = sqlalchemy.sql.select(sqlalchemy.sql.func.count())
429 n = db.deleteWhere(tables.b, tables.b.columns.id.in_([0, 1, 2]))
430 self.assertEqual(n, 3)
431 self.assertEqual(db.query(count.select_from(tables.b)).scalar(), 7)
433 n = db.deleteWhere(tables.b, tables.b.columns.id.in_(
434 sqlalchemy.sql.select(tables.b.columns.id).where(tables.b.columns.id > 5)
435 ))
436 self.assertEqual(n, 4)
437 self.assertEqual(db.query(count.select_from(tables.b)).scalar(), 3)
439 n = db.deleteWhere(tables.b, tables.b.columns.name == "b5")
440 self.assertEqual(n, 1)
441 self.assertEqual(db.query(count.select_from(tables.b)).scalar(), 2)
443 n = db.deleteWhere(tables.b, sqlalchemy.sql.literal(True))
444 self.assertEqual(n, 2)
445 self.assertEqual(db.query(count.select_from(tables.b)).scalar(), 0)
447 def testUpdate(self):
448 """Tests for `Database.update`.
449 """
450 db = self.makeEmptyDatabase(origin=1)
451 with db.declareStaticTables(create=True) as context:
452 tables = context.addTableTuple(STATIC_TABLE_SPECS)
453 # Insert two rows into table a, both without regions.
454 db.insert(tables.a, {"name": "a1"}, {"name": "a2"})
455 # Update one of the rows with a region.
456 region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
457 n = db.update(tables.a, {"name": "k"}, {"k": "a2", "region": region})
458 self.assertEqual(n, 1)
459 sql = sqlalchemy.sql.select(tables.a.columns.name, tables.a.columns.region).select_from(tables.a)
460 self.assertCountEqual(
461 [r._asdict() for r in db.query(sql)],
462 [{"name": "a1", "region": None}, {"name": "a2", "region": region}]
463 )
    def testSync(self):
        """Tests for `Database.sync`.

        ``sync`` returns a two-element tuple: the requested ``returning``
        values (or `None`) and a flag indicating whether an insert (or, with
        ``update=True``, a dict of updated fields) actually happened.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert a row with sync, because it doesn't exist yet.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10}, returning=["id"])
        self.assertTrue(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [r._asdict() for r in db.query(tables.b.select())])
        # Repeat that operation, which should do nothing but return the
        # requested values.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10}, returning=["id"])
        self.assertFalse(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [r._asdict() for r in db.query(tables.b.select())])
        # Repeat the operation without the 'extra' arg, which should also just
        # return the existing row.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, returning=["id"])
        self.assertFalse(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [r._asdict() for r in db.query(tables.b.select())])
        # Repeat the operation with a different value in 'extra'.  That still
        # shouldn't be an error, because 'extra' is only used if we really do
        # insert.  Also drop the 'returning' argument.
        _, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 20})
        self.assertFalse(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [r._asdict() for r in db.query(tables.b.select())])
        # Repeat the operation with the correct value in 'compared' instead of
        # 'extra'.
        _, inserted = db.sync(tables.b, keys={"name": "b1"}, compared={"value": 10})
        self.assertFalse(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [r._asdict() for r in db.query(tables.b.select())])
        # Repeat the operation with an incorrect value in 'compared'; this
        # should raise.
        with self.assertRaises(DatabaseConflictError):
            db.sync(tables.b, keys={"name": "b1"}, compared={"value": 20})
        # Try to sync in a read-only database.  This should work if and only
        # if the matching row already exists.
        with self.asReadOnly(db) as rodb:
            with rodb.declareStaticTables(create=False) as context:
                tables = context.addTableTuple(STATIC_TABLE_SPECS)
            _, inserted = rodb.sync(tables.b, keys={"name": "b1"})
            self.assertFalse(inserted)
            self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                             [r._asdict() for r in rodb.query(tables.b.select())])
            with self.assertRaises(ReadOnlyDatabaseError):
                rodb.sync(tables.b, keys={"name": "b2"}, extra={"value": 20})
        # Repeat the operation with a different value in 'compared' and ask to
        # update; the second return element becomes the dict of fields that
        # were changed, holding their *previous* values.
        _, updated = db.sync(tables.b, keys={"name": "b1"}, compared={"value": 20}, update=True)
        self.assertEqual(updated, {"value": 10})
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 20}],
                         [r._asdict() for r in db.query(tables.b.select())])
    def testReplace(self):
        """Tests for `Database.replace` (insert-or-overwrite semantics keyed
        on the primary key).
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Use 'replace' to insert a single row that contains a region and
        # query to get it back.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row1 = {"name": "a1", "region": region}
        db.replace(tables.a, row1)
        self.assertEqual([r._asdict() for r in db.query(tables.a.select())], [row1])
        # Insert another row without a region.
        row2 = {"name": "a2", "region": None}
        db.replace(tables.a, row2)
        self.assertCountEqual([r._asdict() for r in db.query(tables.a.select())], [row1, row2])
        # Use replace to re-insert both of those rows again, which should do
        # nothing.
        db.replace(tables.a, row1, row2)
        self.assertCountEqual([r._asdict() for r in db.query(tables.a.select())], [row1, row2])
        # Replace row1 with a row with no region, while reinserting row2.
        row1a = {"name": "a1", "region": None}
        db.replace(tables.a, row1a, row2)
        self.assertCountEqual([r._asdict() for r in db.query(tables.a.select())], [row1a, row2])
        # Replace both rows, returning row1 to its original state, while adding
        # a new one.  Pass them in in a different order.
        row2a = {"name": "a2", "region": region}
        row3 = {"name": "a3", "region": None}
        db.replace(tables.a, row3, row2a, row1)
        self.assertCountEqual([r._asdict() for r in db.query(tables.a.select())], [row1, row2a, row3])
    def testEnsure(self):
        """Tests for `Database.ensure` (insert-if-absent semantics: existing
        rows are never modified; the return value counts rows inserted).
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Use 'ensure' to insert a single row that contains a region and
        # query to get it back.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row1 = {"name": "a1", "region": region}
        self.assertEqual(db.ensure(tables.a, row1), 1)
        self.assertEqual([r._asdict() for r in db.query(tables.a.select())], [row1])
        # Insert another row without a region.
        row2 = {"name": "a2", "region": None}
        self.assertEqual(db.ensure(tables.a, row2), 1)
        self.assertCountEqual([r._asdict() for r in db.query(tables.a.select())], [row1, row2])
        # Use ensure to re-insert both of those rows again, which should do
        # nothing.
        self.assertEqual(db.ensure(tables.a, row1, row2), 0)
        self.assertCountEqual([r._asdict() for r in db.query(tables.a.select())], [row1, row2])
        # Attempt to insert row1's key with no region, while
        # reinserting row2.  This should also do nothing (unlike 'replace',
        # 'ensure' leaves the existing row untouched).
        row1a = {"name": "a1", "region": None}
        self.assertEqual(db.ensure(tables.a, row1a, row2), 0)
        self.assertCountEqual([r._asdict() for r in db.query(tables.a.select())], [row1, row2])
        # Attempt to insert new rows for both existing keys, this time also
        # adding a new row.  Pass them in in a different order.  Only the new
        # row should be added.
        row2a = {"name": "a2", "region": region}
        row3 = {"name": "a3", "region": None}
        self.assertEqual(db.ensure(tables.a, row3, row2a, row1a), 1)
        self.assertCountEqual([r._asdict() for r in db.query(tables.a.select())], [row1, row2, row3])
    def testTransactionNesting(self):
        """Test that transactions can be nested with the behavior in the
        presence of exceptions working as documented.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert one row so we can trigger integrity errors by trying to insert
        # a duplicate of it below.
        db.insert(tables.a, {"name": "a1"})
        # First test: error recovery via explicit savepoint=True in the inner
        # transaction.
        with db.transaction():
            # This insert should succeed, and should not be rolled back because
            # the assertRaises context should catch any exception before it
            # propagates up to the outer transaction.
            db.insert(tables.a, {"name": "a2"})
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                with db.transaction(savepoint=True):
                    # This insert should succeed, but should be rolled back
                    # (the savepoint is released on the error below).
                    db.insert(tables.a, {"name": "a4"})
                    # This insert should fail (duplicate primary key), raising
                    # an exception.
                    db.insert(tables.a, {"name": "a1"})
        # Only a1 and a2 survive: a4 was rolled back with the savepoint.
        self.assertCountEqual(
            [r._asdict() for r in db.query(tables.a.select())],
            [{"name": "a1", "region": None}, {"name": "a2", "region": None}],
        )
        # Second test: error recovery via implicit savepoint=True, when the
        # innermost transaction is inside a savepoint=True transaction.
        with db.transaction():
            # This insert should succeed, and should not be rolled back
            # because the assertRaises context should catch any
            # exception before it propagates up to the outer
            # transaction.
            db.insert(tables.a, {"name": "a3"})
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                with db.transaction(savepoint=True):
                    # This insert should succeed, but should be rolled back.
                    db.insert(tables.a, {"name": "a4"})
                    with db.transaction():
                        # This insert should succeed, but should be rolled
                        # back along with the enclosing savepoint.
                        db.insert(tables.a, {"name": "a5"})
                        # This insert should fail (duplicate primary key),
                        # raising an exception.
                        db.insert(tables.a, {"name": "a1"})
        # Only a3 joins the survivors; a4 and a5 were rolled back.
        self.assertCountEqual(
            [r._asdict() for r in db.query(tables.a.select())],
            [{"name": "a1", "region": None}, {"name": "a2", "region": None}, {"name": "a3", "region": None}],
        )
639 def testTransactionLocking(self):
640 """Test that `Database.transaction` can be used to acquire a lock
641 that prohibits concurrent writes.
642 """
643 db1 = self.makeEmptyDatabase(origin=1)
644 with db1.declareStaticTables(create=True) as context:
645 tables1 = context.addTableTuple(STATIC_TABLE_SPECS)
647 async def side1(lock: Iterable[str] = ()) -> Tuple[Set[str], Set[str]]:
648 """One side of the concurrent locking test.
650 This optionally locks the table (and maybe the whole database),
651 does a select for its contents, inserts a new row, and then selects
652 again, with some waiting in between to make sure the other side has
653 a chance to _attempt_ to insert in between. If the locking is
654 enabled and works, the difference between the selects should just
655 be the insert done on this thread.
656 """
657 # Give Side2 a chance to create a connection
658 await asyncio.sleep(1.0)
659 with db1.transaction(lock=lock):
660 names1 = {row.name for row in db1.query(tables1.a.select())}
661 # Give Side2 a chance to insert (which will be blocked if
662 # we've acquired a lock).
663 await asyncio.sleep(2.0)
664 db1.insert(tables1.a, {"name": "a1"})
665 names2 = {row.name for row in db1.query(tables1.a.select())}
666 return names1, names2
668 async def side2() -> None:
669 """The other side of the concurrent locking test.
671 This side just waits a bit and then tries to insert a row into the
672 table that the other side is trying to lock. Hopefully that
673 waiting is enough to give the other side a chance to acquire the
674 lock and thus make this side block until the lock is released. If
675 this side manages to do the insert before side1 acquires the lock,
676 we'll just warn about not succeeding at testing the locking,
677 because we can only make that unlikely, not impossible.
678 """
679 def toRunInThread():
680 """SQLite locking isn't asyncio-friendly unless we actually
681 run it in another thread. And SQLite gets very unhappy if
682 we try to use a connection from multiple threads, so we have
683 to create the new connection here instead of out in the main
684 body of the test function.
685 """
686 db2 = self.getNewConnection(db1, writeable=True)
687 with db2.declareStaticTables(create=False) as context:
688 tables2 = context.addTableTuple(STATIC_TABLE_SPECS)
689 with db2.transaction():
690 db2.insert(tables2.a, {"name": "a2"})
692 await asyncio.sleep(2.0)
693 loop = asyncio.get_running_loop()
694 with ThreadPoolExecutor() as pool:
695 await loop.run_in_executor(pool, toRunInThread)
async def testProblemsWithNoLocking() -> None:
    """Run side1 and side2 with no locking, attempting to demonstrate
    the race that locking is supposed to prevent.

    With unlucky scheduling side2 may insert before side1's first
    SELECT or after its second, proving nothing either way; those
    outcomes only produce warnings because spurious test failures are
    worse than an occasionally inconclusive run.
    """
    selectTask = asyncio.create_task(side1())
    insertTask = asyncio.create_task(side2())
    beforeNames, afterNames = await selectTask
    await insertTask
    if "a2" in beforeNames:
        # side2's insert landed before side1's first SELECT.
        warnings.warn("Unlucky scheduling in no-locking test: concurrent INSERT "
                      "happened before first SELECT.")
        self.assertEqual(beforeNames, {"a2"})
        self.assertEqual(afterNames, {"a1", "a2"})
    elif "a2" not in afterNames:
        # side2's insert landed after side1's second SELECT.
        warnings.warn("Unlucky scheduling in no-locking test: concurrent INSERT "
                      "happened after second SELECT even without locking.")
        self.assertEqual(beforeNames, set())
        self.assertEqual(afterNames, {"a1"})
    else:
        # The expected case: both INSERTs happen between the two
        # SELECTs.  If this branch isn't hit almost all of the time,
        # the sleep amounts should be adjusted.
        self.assertEqual(beforeNames, set())
        self.assertEqual(afterNames, {"a1", "a2"})
# Run the unlocked scenario first to (usually) demonstrate the race.
asyncio.run(testProblemsWithNoLocking())
# Clean up after first test: remove both rows so the locking scenario
# starts from an empty table.
db1.delete(tables1.a, ["name"], {"name": "a1"}, {"name": "a2"})
async def testSolutionWithLocking() -> None:
    """Run side1 and side2 with locking, which should make side2 block
    its insert until side1 releases its lock.
    """
    # Fix: the docstring previously said "until side2 releases its
    # lock", but it is side1 that acquires (and releases) the lock on
    # table 'a'; side2 is the one being blocked.
    task1 = asyncio.create_task(side1(lock=[tables1.a]))
    task2 = asyncio.create_task(side2())
    names1, names2 = await task1
    await task2
    if "a2" in names1:
        # side2's insert slipped in before side1 acquired the lock;
        # only warn, since we can't make this impossible.
        warnings.warn("Unlucky scheduling in locking test: concurrent INSERT "
                      "happened before first SELECT.")
        self.assertEqual(names1, {"a2"})
        self.assertEqual(names2, {"a1", "a2"})
    else:
        # This is the expected case: the side2 INSERT happens after the
        # last SELECT on side1.  This can also happen due to unlucky
        # scheduling, and we have no way to detect that here, but the
        # similar "no-locking" test has at least some chance of being
        # affected by the same problem and warning about it.
        self.assertEqual(names1, set())
        self.assertEqual(names2, {"a1"})
# Re-run the scenario, this time with side1 holding a lock on table 'a'.
asyncio.run(testSolutionWithLocking())
def testTimespanDatabaseRepresentation(self):
    """Tests for `TimespanDatabaseRepresentation` and the `Database`
    methods that interact with it.
    """
    # Make some test timespans to play with, with the full suite of
    # topological relationships.
    start = astropy.time.Time('2020-01-01T00:00:00', format="isot", scale="tai")
    offset = astropy.time.TimeDelta(60, format="sec")
    timestamps = [start + offset*n for n in range(3)]
    # 'a' set: unbounded, half-bounded, instantaneous, empty, and all
    # bounded combinations of the three timestamps.
    aTimespans = [Timespan(begin=None, end=None)]
    aTimespans.extend(Timespan(begin=None, end=t) for t in timestamps)
    aTimespans.extend(Timespan(begin=t, end=None) for t in timestamps)
    aTimespans.extend(Timespan.fromInstant(t) for t in timestamps)
    aTimespans.append(Timespan.makeEmpty())
    aTimespans.extend(Timespan(begin=t1, end=t2) for t1, t2 in itertools.combinations(timestamps, 2))
    # Make another list of timespans that span the full range but don't
    # overlap.  This is a subset of the previous list.
    bTimespans = [Timespan(begin=None, end=timestamps[0])]
    bTimespans.extend(Timespan(begin=t1, end=t2) for t1, t2 in zip(timestamps[:-1], timestamps[1:]))
    bTimespans.append(Timespan(begin=timestamps[-1], end=None))
    # Make a database and create a table with that database's timespan
    # representation.  This one will have no exclusion constraint and
    # a nullable timespan.
    db = self.makeEmptyDatabase(origin=1)
    TimespanReprClass = db.getTimespanRepresentation()
    aSpec = ddl.TableSpec(
        fields=[
            ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
        ],
    )
    # The timespan representation may expand to multiple columns, so
    # field specs come from the representation class itself.
    for fieldSpec in TimespanReprClass.makeFieldSpecs(nullable=True):
        aSpec.fields.add(fieldSpec)
    with db.declareStaticTables(create=True) as context:
        aTable = context.addTable("a", aSpec)
    # Show full diffs on assertion failures; the row lists are long.
    self.maxDiff = None

    def convertRowForInsert(row: dict) -> dict:
        """Convert a row containing a Timespan instance into one suitable
        for insertion into the database.
        """
        result = row.copy()
        # Replace the Timespan value with the representation's
        # database column values (may be None).
        ts = result.pop(TimespanReprClass.NAME)
        return TimespanReprClass.update(ts, result=result)

    def convertRowFromSelect(row: dict) -> dict:
        """Convert a row from the database into one containing a Timespan.
        """
        result = row.copy()
        # Rebuild the Timespan from its column values, then drop those
        # raw columns from the row.
        timespan = TimespanReprClass.extract(result)
        for name in TimespanReprClass.getFieldNames():
            del result[name]
        result[TimespanReprClass.NAME] = timespan
        return result

    # Insert rows into table A, in chunks just to make things interesting.
    # Include one with a NULL timespan.
    aRows = [{"id": n, TimespanReprClass.NAME: t} for n, t in enumerate(aTimespans)]
    aRows.append({"id": len(aRows), TimespanReprClass.NAME: None})
    db.insert(aTable, convertRowForInsert(aRows[0]))
    db.insert(aTable, *[convertRowForInsert(r) for r in aRows[1:3]])
    db.insert(aTable, *[convertRowForInsert(r) for r in aRows[3:]])
    # Add another one with a NULL timespan, but this time by invoking
    # the server-side default.
    aRows.append({"id": len(aRows)})
    db.insert(aTable, aRows[-1])
    aRows[-1][TimespanReprClass.NAME] = None
    # Test basic round-trip through database.
    self.assertEqual(
        aRows,
        [convertRowFromSelect(row._asdict())
         for row in db.query(aTable.select().order_by(aTable.columns.id))]
    )
    # Create another table B with a not-null timespan and (if the database
    # supports it), an exclusion constraint.  Use ensureTableExists this
    # time to check that mode of table creation vs. timespans.
    bSpec = ddl.TableSpec(
        fields=[
            ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
            ddl.FieldSpec(name="key", dtype=sqlalchemy.Integer, nullable=False),
        ],
    )
    for fieldSpec in TimespanReprClass.makeFieldSpecs(nullable=False):
        bSpec.fields.add(fieldSpec)
    if TimespanReprClass.hasExclusionConstraint():
        bSpec.exclusion.add(("key", TimespanReprClass))
    bTable = db.ensureTableExists("b", bSpec)
    # Insert rows into table B, again in chunks.  Each Timespan appears
    # twice, but with different values for the 'key' field (which should
    # still be okay for any exclusion constraint we may have defined).
    bRows = [{"id": n, "key": 1, TimespanReprClass.NAME: t} for n, t in enumerate(bTimespans)]
    offset = len(bRows)
    bRows.extend({"id": n + offset, "key": 2, TimespanReprClass.NAME: t}
                 for n, t in enumerate(bTimespans))
    db.insert(bTable, *[convertRowForInsert(r) for r in bRows[:2]])
    db.insert(bTable, convertRowForInsert(bRows[2]))
    db.insert(bTable, *[convertRowForInsert(r) for r in bRows[3:]])
    # Insert a row with no timespan into table B.  This should invoke the
    # server-side default, which is a timespan over (-∞, ∞).  We set
    # key=3 to avoid upsetting an exclusion constraint that might exist.
    bRows.append({"id": len(bRows), "key": 3})
    db.insert(bTable, bRows[-1])
    bRows[-1][TimespanReprClass.NAME] = Timespan(None, None)
    # Test basic round-trip through database.
    self.assertEqual(
        bRows,
        [convertRowFromSelect(row._asdict())
         for row in db.query(bTable.select().order_by(bTable.columns.id))]
    )
    # Test that we can't insert timespan=None into this table.
    with self.assertRaises(sqlalchemy.exc.IntegrityError):
        db.insert(
            bTable,
            convertRowForInsert({"id": len(bRows), "key": 4, TimespanReprClass.NAME: None})
        )
    # IFF this database supports exclusion constraints, test that they
    # also prevent inserts.
    if TimespanReprClass.hasExclusionConstraint():
        # Each of these timespans overlaps an existing key=1 row.
        with self.assertRaises(sqlalchemy.exc.IntegrityError):
            db.insert(
                bTable,
                convertRowForInsert({
                    "id": len(bRows), "key": 1,
                    TimespanReprClass.NAME: Timespan(None, timestamps[1])
                })
            )
        with self.assertRaises(sqlalchemy.exc.IntegrityError):
            db.insert(
                bTable,
                convertRowForInsert({
                    "id": len(bRows), "key": 1,
                    TimespanReprClass.NAME: Timespan(timestamps[0], timestamps[2])
                })
            )
        with self.assertRaises(sqlalchemy.exc.IntegrityError):
            db.insert(
                bTable,
                convertRowForInsert({
                    "id": len(bRows), "key": 1,
                    TimespanReprClass.NAME: Timespan(timestamps[2], None)
                })
            )
    # Test NULL checks in SELECT queries, on both tables.
    aRepr = TimespanReprClass.fromSelectable(aTable)
    self.assertEqual(
        [row[TimespanReprClass.NAME] is None for row in aRows],
        [
            row.f for row in db.query(
                sqlalchemy.sql.select(
                    aRepr.isNull().label("f")
                ).order_by(
                    aTable.columns.id
                )
            )
        ]
    )
    bRepr = TimespanReprClass.fromSelectable(bTable)
    # Table B's timespan is NOT NULL, so isNull() must be False for all.
    self.assertEqual(
        [False for row in bRows],
        [
            row.f for row in db.query(
                sqlalchemy.sql.select(
                    bRepr.isNull().label("f")
                ).order_by(
                    bTable.columns.id
                )
            )
        ]
    )
    # Test relationships expressions that relate in-database timespans to
    # Python-literal timespans, all from the more complete 'a' set; check
    # that this is consistent with Python-only relationship tests.
    for rhsRow in aRows:
        if rhsRow[TimespanReprClass.NAME] is None:
            continue
        with self.subTest(rhsRow=rhsRow):
            # Compute the expected relationships in pure Python; NULL
            # timespans yield all-None results.
            expected = {}
            for lhsRow in aRows:
                if lhsRow[TimespanReprClass.NAME] is None:
                    expected[lhsRow["id"]] = (None, None, None, None)
                else:
                    expected[lhsRow["id"]] = (
                        lhsRow[TimespanReprClass.NAME].overlaps(rhsRow[TimespanReprClass.NAME]),
                        lhsRow[TimespanReprClass.NAME].contains(rhsRow[TimespanReprClass.NAME]),
                        lhsRow[TimespanReprClass.NAME] < rhsRow[TimespanReprClass.NAME],
                        lhsRow[TimespanReprClass.NAME] > rhsRow[TimespanReprClass.NAME],
                    )
            # Bind the Python timespan as a SQL literal for comparison.
            rhsRepr = TimespanReprClass.fromLiteral(rhsRow[TimespanReprClass.NAME])
            sql = sqlalchemy.sql.select(
                aTable.columns.id.label("lhs"),
                aRepr.overlaps(rhsRepr).label("overlaps"),
                aRepr.contains(rhsRepr).label("contains"),
                (aRepr < rhsRepr).label("less_than"),
                (aRepr > rhsRepr).label("greater_than"),
            ).select_from(aTable)
            queried = {
                row.lhs: (row.overlaps, row.contains, row.less_than, row.greater_than)
                for row in db.query(sql)
            }
            self.assertEqual(expected, queried)
    # Test relationship expressions that relate in-database timespans to
    # each other, all from the more complete 'a' set; check that this is
    # consistent with Python-only relationship tests.
    expected = {}
    for lhs, rhs in itertools.product(aRows, aRows):
        lhsT = lhs[TimespanReprClass.NAME]
        rhsT = rhs[TimespanReprClass.NAME]
        if lhsT is not None and rhsT is not None:
            expected[lhs["id"], rhs["id"]] = (
                lhsT.overlaps(rhsT),
                lhsT.contains(rhsT),
                lhsT < rhsT,
                lhsT > rhsT
            )
        else:
            expected[lhs["id"], rhs["id"]] = (None, None, None, None)
    # Self-join table A against itself (cross join via always-true
    # onclause) to compare every pair of timespans in SQL.
    lhsSubquery = aTable.alias("lhs")
    rhsSubquery = aTable.alias("rhs")
    lhsRepr = TimespanReprClass.fromSelectable(lhsSubquery)
    rhsRepr = TimespanReprClass.fromSelectable(rhsSubquery)
    sql = sqlalchemy.sql.select(
        lhsSubquery.columns.id.label("lhs"),
        rhsSubquery.columns.id.label("rhs"),
        lhsRepr.overlaps(rhsRepr).label("overlaps"),
        lhsRepr.contains(rhsRepr).label("contains"),
        (lhsRepr < rhsRepr).label("less_than"),
        (lhsRepr > rhsRepr).label("greater_than"),
    ).select_from(
        lhsSubquery.join(rhsSubquery, onclause=sqlalchemy.sql.literal(True))
    )
    queried = {
        (row.lhs, row.rhs): (row.overlaps, row.contains, row.less_than, row.greater_than)
        for row in db.query(sql)}
    self.assertEqual(expected, queried)
    # Test relationship expressions between in-database timespans and
    # Python-literal instantaneous times.
    for t in timestamps:
        with self.subTest(t=t):
            expected = {}
            for lhsRow in aRows:
                if lhsRow[TimespanReprClass.NAME] is None:
                    expected[lhsRow["id"]] = (None, None, None)
                else:
                    expected[lhsRow["id"]] = (
                        lhsRow[TimespanReprClass.NAME].contains(t),
                        lhsRow[TimespanReprClass.NAME] < t,
                        lhsRow[TimespanReprClass.NAME] > t,
                    )
            # Bind the instant as a TAI-nanoseconds SQL literal.
            rhs = sqlalchemy.sql.literal(t, type_=ddl.AstropyTimeNsecTai)
            sql = sqlalchemy.sql.select(
                aTable.columns.id.label("lhs"),
                aRepr.contains(rhs).label("contains"),
                (aRepr < rhs).label("less_than"),
                (aRepr > rhs).label("greater_than"),
            ).select_from(aTable)
            queried = {
                row.lhs: (row.contains, row.less_than, row.greater_than)
                for row in db.query(sql)
            }
            self.assertEqual(expected, queried)