Coverage for python/lsst/daf/butler/registry/tests/_database.py: 7%
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ["DatabaseTests"]
25from abc import ABC, abstractmethod
26import asyncio
27from collections import namedtuple
28from concurrent.futures import ThreadPoolExecutor
29import itertools
30from typing import ContextManager, Iterable, Set, Tuple
31import warnings
33import astropy.time
34import sqlalchemy
36from lsst.sphgeom import ConvexPolygon, UnitVector3d
37from ..interfaces import (
38 Database,
39 ReadOnlyDatabaseError,
40 DatabaseConflictError,
41 SchemaAlreadyDefinedError
42)
43from ...core import ddl, Timespan
# Named grouping of the three static test tables so they can be declared and
# passed around together (field order matters: a, b, c).
StaticTablesTuple = namedtuple("StaticTablesTuple", ["a", "b", "c"])

# Specifications for the static (declared-up-front) test tables:
#  - "a": string primary key plus a nullable region blob.
#  - "b": autoincrement primary key, unique non-null "name", nullable "value".
#  - "c": compound (autoincrement "id", "origin") primary key, with a
#    SET NULL foreign key into "b".
STATIC_TABLE_SPECS = StaticTablesTuple(
    a=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("name", dtype=sqlalchemy.String, length=16, primaryKey=True),
            ddl.FieldSpec("region", dtype=ddl.Base64Region, nbytes=128),
        ]
    ),
    b=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("id", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True),
            ddl.FieldSpec("name", dtype=sqlalchemy.String, length=16, nullable=False),
            ddl.FieldSpec("value", dtype=sqlalchemy.SmallInteger, nullable=True),
        ],
        unique=[("name",)],
    ),
    c=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("id", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True),
            ddl.FieldSpec("origin", dtype=sqlalchemy.BigInteger, primaryKey=True),
            ddl.FieldSpec("b_id", dtype=sqlalchemy.BigInteger, nullable=True),
        ],
        foreignKeys=[
            ddl.ForeignKeySpec("b", source=("b_id",), target=("id",), onDelete="SET NULL"),
        ]
    ),
)

# Specification for a table created after the static schema via
# ``Database.ensureTableExists``; CASCADE foreign keys into both "c" and "a".
DYNAMIC_TABLE_SPEC = ddl.TableSpec(
    fields=[
        ddl.FieldSpec("c_id", dtype=sqlalchemy.BigInteger, primaryKey=True),
        ddl.FieldSpec("c_origin", dtype=sqlalchemy.BigInteger, primaryKey=True),
        ddl.FieldSpec("a_name", dtype=sqlalchemy.String, length=16, nullable=False),
    ],
    foreignKeys=[
        ddl.ForeignKeySpec("c", source=("c_id", "c_origin"), target=("id", "origin"), onDelete="CASCADE"),
        ddl.ForeignKeySpec("a", source=("a_name",), target=("name",), onDelete="CASCADE"),
    ]
)

# Specification for session-scoped temporary tables; deliberately has no
# foreign keys so it can be created on a read-only connection.
TEMPORARY_TABLE_SPEC = ddl.TableSpec(
    fields=[
        ddl.FieldSpec("a_name", dtype=sqlalchemy.String, length=16, primaryKey=True),
        ddl.FieldSpec("b_id", dtype=sqlalchemy.BigInteger, primaryKey=True),
    ],
)
94class DatabaseTests(ABC):
95 """Generic tests for the `Database` interface that can be subclassed to
96 generate tests for concrete implementations.
97 """
    @abstractmethod
    def makeEmptyDatabase(self, origin: int = 0) -> Database:
        """Return an empty `Database` with the given origin, or an
        automatically-generated one if ``origin`` is `None`.

        Concrete test subclasses implement this to construct the backend
        under test.
        """
        # NOTE(review): the declared default is 0, not None; confirm whether
        # the "None" case in the docstring is still supported by subclasses.
        raise NotImplementedError()
    @abstractmethod
    def asReadOnly(self, database: Database) -> ContextManager[Database]:
        """Return a context manager for a read-only connection into the given
        database.

        The original database should be considered unusable within the context
        but safe to use again afterwards (this allows the context manager to
        block write access by temporarily changing user permissions to really
        guarantee that write operations are not performed).
        """
        raise NotImplementedError()
    @abstractmethod
    def getNewConnection(self, database: Database, *, writeable: bool) -> Database:
        """Return a new `Database` instance that points to the same underlying
        storage as the given one.

        Used by the concurrency tests, which need two independent connections
        to the same database.
        """
        raise NotImplementedError()
    def checkTable(self, spec: ddl.TableSpec, table: sqlalchemy.schema.Table):
        """Assert that ``table`` has exactly the columns named in ``spec``."""
        self.assertCountEqual(spec.fields.names, table.columns.keys())
        # Checking more than this currently seems fragile, as it might restrict
        # what Database implementations do; we don't care if the spec is
        # actually preserved in terms of types and constraints as long as we
        # can use the returned table as if it was.
132 def checkStaticSchema(self, tables: StaticTablesTuple):
133 self.checkTable(STATIC_TABLE_SPECS.a, tables.a)
134 self.checkTable(STATIC_TABLE_SPECS.b, tables.b)
135 self.checkTable(STATIC_TABLE_SPECS.c, tables.c)
    def testDeclareStaticTables(self):
        """Tests for `Database.declareStaticSchema` and the methods it
        delegates to.
        """
        # Create the static schema in a new, empty database.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        self.checkStaticSchema(tables)
        # Check that we can load that schema even from a read-only connection.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                tables = context.addTableTuple(STATIC_TABLE_SPECS)
            self.checkStaticSchema(tables)
    def testDeclareStaticTablesTwice(self):
        """Tests for `Database.declareStaticSchema` being called twice.
        """
        # Create the static schema in a new, empty database.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        self.checkStaticSchema(tables)
        # Second time it should raise
        with self.assertRaises(SchemaAlreadyDefinedError):
            with newDatabase.declareStaticTables(create=True) as context:
                tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Check schema, it should still contain all tables, and maybe some
        # extra.
        # NOTE(review): this reaches into the private ``_tableNames``
        # attribute of the declaration context; consider a public accessor.
        with newDatabase.declareStaticTables(create=False) as context:
            self.assertLessEqual(frozenset(STATIC_TABLE_SPECS._fields), context._tableNames)
169 def testRepr(self):
170 """Test that repr does not return a generic thing."""
171 newDatabase = self.makeEmptyDatabase()
172 rep = repr(newDatabase)
173 # Check that stringification works and gives us something different
174 self.assertNotEqual(rep, str(newDatabase))
175 self.assertNotIn("object at 0x", rep, "Check default repr was not used")
176 self.assertIn("://", rep)
    def testDynamicTables(self):
        """Tests for `Database.ensureTableExists` and
        `Database.getExistingTable`.
        """
        # Need to start with the static schema.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            context.addTableTuple(STATIC_TABLE_SPECS)
        # Try to ensure the dynamic table exists in a read-only version of
        # that database, which should fail because we can't create it.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                context.addTableTuple(STATIC_TABLE_SPECS)
            with self.assertRaises(ReadOnlyDatabaseError):
                existingReadOnlyDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        # Just getting the dynamic table before it exists should return None.
        self.assertIsNone(newDatabase.getExistingTable("d", DYNAMIC_TABLE_SPEC))
        # Ensure the new table exists back in the original database, which
        # should create it.
        table = newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        self.checkTable(DYNAMIC_TABLE_SPEC, table)
        # Ensuring that it exists should just return the exact same table
        # instance again.
        self.assertIs(newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC), table)
        # Try again from the read-only database.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                context.addTableTuple(STATIC_TABLE_SPECS)
            # Just getting the dynamic table should now work...
            self.assertIsNotNone(existingReadOnlyDatabase.getExistingTable("d", DYNAMIC_TABLE_SPEC))
            # ...as should ensuring that it exists, since it now does.
            existingReadOnlyDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
            self.checkTable(DYNAMIC_TABLE_SPEC, table)
        # Trying to get the table with a different specification (at least
        # in terms of what columns are present) should raise.
        with self.assertRaises(DatabaseConflictError):
            newDatabase.ensureTableExists(
                "d",
                ddl.TableSpec(
                    fields=[ddl.FieldSpec("name", dtype=sqlalchemy.String, length=4, primaryKey=True)]
                )
            )
        # Calling ensureTableExists inside a transaction block is an error,
        # even if it would do nothing.
        with newDatabase.transaction():
            with self.assertRaises(AssertionError):
                newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
    def testTemporaryTables(self):
        """Tests for `Database.makeTemporaryTable`,
        `Database.dropTemporaryTable`, and `Database.insert` with
        the ``select`` argument.
        """
        # Need to start with the static schema; also insert some test data.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            static = context.addTableTuple(STATIC_TABLE_SPECS)
        newDatabase.insert(static.a,
                           {"name": "a1", "region": None},
                           {"name": "a2", "region": None})
        bIds = newDatabase.insert(static.b,
                                  {"name": "b1", "value": 11},
                                  {"name": "b2", "value": 12},
                                  {"name": "b3", "value": 13},
                                  returnIds=True)
        # Create the table.
        with newDatabase.session() as session:
            table1 = session.makeTemporaryTable(TEMPORARY_TABLE_SPEC, "e1")
            self.checkTable(TEMPORARY_TABLE_SPEC, table1)
            # Insert via an INSERT INTO ... SELECT query: cross-join a and b,
            # restricted to a1 rows with value <= 12.
            newDatabase.insert(
                table1,
                select=sqlalchemy.sql.select(
                    static.a.columns.name.label("a_name"), static.b.columns.id.label("b_id")
                ).select_from(
                    static.a.join(static.b, onclause=sqlalchemy.sql.literal(True))
                ).where(
                    sqlalchemy.sql.and_(
                        static.a.columns.name == "a1",
                        static.b.columns.value <= 12,
                    )
                )
            )
            # Check that the inserted rows are present.
            self.assertCountEqual(
                [{"a_name": "a1", "b_id": bId} for bId in bIds[:2]],
                [row._asdict() for row in newDatabase.query(table1.select())]
            )
            # Create another one via a read-only connection to the database.
            # We _do_ allow temporary table modifications in read-only
            # databases.
            with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
                with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                    context.addTableTuple(STATIC_TABLE_SPECS)
                with existingReadOnlyDatabase.session() as session2:
                    table2 = session2.makeTemporaryTable(TEMPORARY_TABLE_SPEC)
                    self.checkTable(TEMPORARY_TABLE_SPEC, table2)
                    # Those tables should not be the same, despite having the
                    # same ddl.
                    self.assertIsNot(table1, table2)
                    # Do a slightly different insert into this table, to check
                    # that it works in a read-only database.  This time we
                    # pass column names as a kwarg to insert instead of by
                    # labeling the columns in the select.
                    existingReadOnlyDatabase.insert(
                        table2,
                        select=sqlalchemy.sql.select(
                            static.a.columns.name, static.b.columns.id
                        ).select_from(
                            static.a.join(static.b, onclause=sqlalchemy.sql.literal(True))
                        ).where(
                            sqlalchemy.sql.and_(
                                static.a.columns.name == "a2",
                                static.b.columns.value >= 12,
                            )
                        ),
                        names=["a_name", "b_id"],
                    )
                    # Check that the inserted rows are present.
                    self.assertCountEqual(
                        [{"a_name": "a2", "b_id": bId} for bId in bIds[1:]],
                        [row._asdict() for row in existingReadOnlyDatabase.query(table2.select())]
                    )
                    # Drop the temporary table from the read-only DB.  It's
                    # unspecified whether attempting to use it after this
                    # point is an error or just never returns any results, so
                    # we can't test what it does, only that it's not an error.
                    session2.dropTemporaryTable(table2)
            # Drop the original temporary table.
            session.dropTemporaryTable(table1)
    def testSchemaSeparation(self):
        """Test that creating two different `Database` instances allows us
        to create different tables with the same name in each.
        """
        db1 = self.makeEmptyDatabase(origin=1)
        with db1.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        self.checkStaticSchema(tables)

        db2 = self.makeEmptyDatabase(origin=2)
        # Make the DDL here intentionally different so we'll definitely
        # notice if db1 and db2 are pointing at the same schema.
        spec = ddl.TableSpec(fields=[ddl.FieldSpec("id", dtype=sqlalchemy.Integer, primaryKey=True)])
        with db2.declareStaticTables(create=True) as context:
            table = context.addTable("a", spec)
        self.checkTable(spec, table)
    def testInsertQueryDelete(self):
        """Test the `Database.insert`, `Database.query`, and `Database.delete`
        methods, as well as the `Base64Region` type and the ``onDelete``
        argument to `ddl.ForeignKeySpec`.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert a single, non-autoincrement row that contains a region and
        # query to get it back.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row = {"name": "a1", "region": region}
        db.insert(tables.a, row)
        self.assertEqual([r._asdict() for r in db.query(tables.a.select())], [row])
        # Insert multiple autoincrement rows but do not try to get the IDs
        # back immediately.
        db.insert(tables.b, {"name": "b1", "value": 10}, {"name": "b2", "value": 20})
        results = [r._asdict() for r in db.query(tables.b.select().order_by("id"))]
        self.assertEqual(len(results), 2)
        for row in results:
            self.assertIn(row["name"], ("b1", "b2"))
            self.assertIsInstance(row["id"], int)
        self.assertGreater(results[1]["id"], results[0]["id"])
        # Insert multiple autoincrement rows and get the IDs back from insert.
        rows = [{"name": "b3", "value": 30}, {"name": "b4", "value": 40}]
        ids = db.insert(tables.b, *rows, returnIds=True)
        results = [
            r._asdict() for r in db.query(
                tables.b.select().where(tables.b.columns.id > results[1]["id"])
            )
        ]
        expected = [dict(row, id=id) for row, id in zip(rows, ids)]
        self.assertCountEqual(results, expected)
        self.assertTrue(all(result["id"] is not None for result in results))
        # Insert multiple rows into a table with an autoincrement+origin
        # primary key, then use the returned IDs to insert into a dynamic
        # table.
        rows = [{"origin": db.origin, "b_id": results[0]["id"]},
                {"origin": db.origin, "b_id": None}]
        ids = db.insert(tables.c, *rows, returnIds=True)
        results = [r._asdict() for r in db.query(tables.c.select())]
        expected = [dict(row, id=id) for row, id in zip(rows, ids)]
        self.assertCountEqual(results, expected)
        self.assertTrue(all(result["id"] is not None for result in results))
        # Add the dynamic table.
        d = db.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        # Insert into it.
        rows = [{"c_origin": db.origin, "c_id": id, "a_name": "a1"} for id in ids]
        db.insert(d, *rows)
        results = [r._asdict() for r in db.query(d.select())]
        self.assertCountEqual(rows, results)
        # Insert multiple rows into a table with an autoincrement+origin
        # primary key (this is especially tricky for SQLite, but good to test
        # for all DBs), but pass in a value for the autoincrement key.
        # For extra complexity, we re-use the autoincrement value with a
        # different value for origin.
        rows2 = [{"id": 700, "origin": db.origin, "b_id": None},
                 {"id": 700, "origin": 60, "b_id": None},
                 {"id": 1, "origin": 60, "b_id": None}]
        db.insert(tables.c, *rows2)
        results = [r._asdict() for r in db.query(tables.c.select())]
        self.assertCountEqual(results, expected + rows2)
        self.assertTrue(all(result["id"] is not None for result in results))

        # Define 'SELECT COUNT(*)' query for later use.
        count = sqlalchemy.sql.select(sqlalchemy.sql.func.count())
        # Get the values we inserted into table b.
        bValues = [r._asdict() for r in db.query(tables.b.select())]
        # Remove two rows from table b by ID.
        n = db.delete(tables.b, ["id"], {"id": bValues[0]["id"]}, {"id": bValues[1]["id"]})
        self.assertEqual(n, 2)
        # Remove the other two rows from table b by name.
        n = db.delete(tables.b, ["name"], {"name": bValues[2]["name"]}, {"name": bValues[3]["name"]})
        self.assertEqual(n, 2)
        # There should now be no rows in table b.
        self.assertEqual(
            db.query(count.select_from(tables.b)).scalar(),
            0
        )
        # All b_id values in table c should now be NULL, because there's an
        # onDelete='SET NULL' foreign key.
        self.assertEqual(
            db.query(count.select_from(tables.c).where(tables.c.columns.b_id != None)).scalar(),  # noqa:E711
            0
        )
        # Remove all rows in table a (there's only one); this should remove all
        # rows in d due to onDelete='CASCADE'.
        n = db.delete(tables.a, [])
        self.assertEqual(n, 1)
        self.assertEqual(db.query(count.select_from(tables.a)).scalar(), 0)
        self.assertEqual(db.query(count.select_from(d)).scalar(), 0)
    def testUpdate(self):
        """Tests for `Database.update`.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert two rows into table a, both without regions.
        db.insert(tables.a, {"name": "a1"}, {"name": "a2"})
        # Update one of the rows with a region.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        # {"name": "k"} apparently maps the "name" column to the "k" key in
        # the values dict, selecting the row to update -- TODO confirm
        # against Database.update's documentation.
        n = db.update(tables.a, {"name": "k"}, {"k": "a2", "region": region})
        self.assertEqual(n, 1)
        sql = sqlalchemy.sql.select(tables.a.columns.name, tables.a.columns.region).select_from(tables.a)
        self.assertCountEqual(
            [r._asdict() for r in db.query(sql)],
            [{"name": "a1", "region": None}, {"name": "a2", "region": region}]
        )
    def testSync(self):
        """Tests for `Database.sync`.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert a row with sync, because it doesn't exist yet.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10}, returning=["id"])
        self.assertTrue(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [r._asdict() for r in db.query(tables.b.select())])
        # Repeat that operation, which should do nothing but return the
        # requested values.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10}, returning=["id"])
        self.assertFalse(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [r._asdict() for r in db.query(tables.b.select())])
        # Repeat the operation without the 'extra' arg, which should also just
        # return the existing row.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, returning=["id"])
        self.assertFalse(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [r._asdict() for r in db.query(tables.b.select())])
        # Repeat the operation with a different value in 'extra'.  That still
        # shouldn't be an error, because 'extra' is only used if we really do
        # insert.  Also drop the 'returning' argument.
        _, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 20})
        self.assertFalse(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [r._asdict() for r in db.query(tables.b.select())])
        # Repeat the operation with the correct value in 'compared' instead of
        # 'extra'.
        _, inserted = db.sync(tables.b, keys={"name": "b1"}, compared={"value": 10})
        self.assertFalse(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [r._asdict() for r in db.query(tables.b.select())])
        # Repeat the operation with an incorrect value in 'compared'; this
        # should raise.
        with self.assertRaises(DatabaseConflictError):
            db.sync(tables.b, keys={"name": "b1"}, compared={"value": 20})
        # Try to sync in a read-only database.  This should work if and only
        # if the matching row already exists.
        with self.asReadOnly(db) as rodb:
            with rodb.declareStaticTables(create=False) as context:
                tables = context.addTableTuple(STATIC_TABLE_SPECS)
            _, inserted = rodb.sync(tables.b, keys={"name": "b1"})
            self.assertFalse(inserted)
            self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                             [r._asdict() for r in rodb.query(tables.b.select())])
            with self.assertRaises(ReadOnlyDatabaseError):
                rodb.sync(tables.b, keys={"name": "b2"}, extra={"value": 20})
        # Repeat the operation with a different value in 'compared' and ask to
        # update.  The returned dict holds the values that were replaced.
        _, updated = db.sync(tables.b, keys={"name": "b1"}, compared={"value": 20}, update=True)
        self.assertEqual(updated, {"value": 10})
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 20}],
                         [r._asdict() for r in db.query(tables.b.select())])
    def testReplace(self):
        """Tests for `Database.replace`.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Use 'replace' to insert a single row that contains a region and
        # query to get it back.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row1 = {"name": "a1", "region": region}
        db.replace(tables.a, row1)
        self.assertEqual([r._asdict() for r in db.query(tables.a.select())], [row1])
        # Insert another row without a region.
        row2 = {"name": "a2", "region": None}
        db.replace(tables.a, row2)
        self.assertCountEqual([r._asdict() for r in db.query(tables.a.select())], [row1, row2])
        # Use replace to re-insert both of those rows again, which should do
        # nothing.
        db.replace(tables.a, row1, row2)
        self.assertCountEqual([r._asdict() for r in db.query(tables.a.select())], [row1, row2])
        # Replace row1 with a row with no region, while reinserting row2.
        row1a = {"name": "a1", "region": None}
        db.replace(tables.a, row1a, row2)
        self.assertCountEqual([r._asdict() for r in db.query(tables.a.select())], [row1a, row2])
        # Replace both rows, returning row1 to its original state, while adding
        # a new one.  Pass them in in a different order.
        row2a = {"name": "a2", "region": region}
        row3 = {"name": "a3", "region": None}
        db.replace(tables.a, row3, row2a, row1)
        self.assertCountEqual([r._asdict() for r in db.query(tables.a.select())], [row1, row2a, row3])
    def testEnsure(self):
        """Tests for `Database.ensure`.

        Unlike `Database.replace`, `ensure` never overwrites an existing row;
        it only inserts rows whose keys are not already present, and returns
        the number of rows actually inserted.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Use 'ensure' to insert a single row that contains a region and
        # query to get it back.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row1 = {"name": "a1", "region": region}
        self.assertEqual(db.ensure(tables.a, row1), 1)
        self.assertEqual([r._asdict() for r in db.query(tables.a.select())], [row1])
        # Insert another row without a region.
        row2 = {"name": "a2", "region": None}
        self.assertEqual(db.ensure(tables.a, row2), 1)
        self.assertCountEqual([r._asdict() for r in db.query(tables.a.select())], [row1, row2])
        # Use ensure to re-insert both of those rows again, which should do
        # nothing.
        self.assertEqual(db.ensure(tables.a, row1, row2), 0)
        self.assertCountEqual([r._asdict() for r in db.query(tables.a.select())], [row1, row2])
        # Attempt to insert row1's key with no region, while
        # reinserting row2.  This should also do nothing.
        row1a = {"name": "a1", "region": None}
        self.assertEqual(db.ensure(tables.a, row1a, row2), 0)
        self.assertCountEqual([r._asdict() for r in db.query(tables.a.select())], [row1, row2])
        # Attempt to insert new rows for both existing keys, this time also
        # adding a new row.  Pass them in in a different order.  Only the new
        # row should be added.
        row2a = {"name": "a2", "region": region}
        row3 = {"name": "a3", "region": None}
        self.assertEqual(db.ensure(tables.a, row3, row2a, row1a), 1)
        self.assertCountEqual([r._asdict() for r in db.query(tables.a.select())], [row1, row2, row3])
    def testTransactionNesting(self):
        """Test that transactions can be nested with the behavior in the
        presence of exceptions working as documented.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert one row so we can trigger integrity errors by trying to insert
        # a duplicate of it below.
        db.insert(tables.a, {"name": "a1"})
        # First test: error recovery via explicit savepoint=True in the inner
        # transaction.
        with db.transaction():
            # This insert should succeed, and should not be rolled back because
            # the assertRaises context should catch any exception before it
            # propagates up to the outer transaction.
            db.insert(tables.a, {"name": "a2"})
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                with db.transaction(savepoint=True):
                    # This insert should succeed, but should be rolled back.
                    db.insert(tables.a, {"name": "a4"})
                    # This insert should fail (duplicate primary key), raising
                    # an exception.
                    db.insert(tables.a, {"name": "a1"})
        self.assertCountEqual(
            [r._asdict() for r in db.query(tables.a.select())],
            [{"name": "a1", "region": None}, {"name": "a2", "region": None}],
        )
        # Second test: error recovery via implicit savepoint=True, when the
        # innermost transaction is inside a savepoint=True transaction.
        with db.transaction():
            # This insert should succeed, and should not be rolled back
            # because the assertRaises context should catch any
            # exception before it propagates up to the outer
            # transaction.
            db.insert(tables.a, {"name": "a3"})
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                with db.transaction(savepoint=True):
                    # This insert should succeed, but should be rolled back.
                    db.insert(tables.a, {"name": "a4"})
                    with db.transaction():
                        # This insert should succeed, but should be rolled
                        # back.
                        db.insert(tables.a, {"name": "a5"})
                        # This insert should fail (duplicate primary key),
                        # raising an exception.
                        db.insert(tables.a, {"name": "a1"})
        self.assertCountEqual(
            [r._asdict() for r in db.query(tables.a.select())],
            [{"name": "a1", "region": None}, {"name": "a2", "region": None}, {"name": "a3", "region": None}],
        )
612 def testTransactionLocking(self):
613 """Test that `Database.transaction` can be used to acquire a lock
614 that prohibits concurrent writes.
615 """
616 db1 = self.makeEmptyDatabase(origin=1)
617 with db1.declareStaticTables(create=True) as context:
618 tables1 = context.addTableTuple(STATIC_TABLE_SPECS)
620 async def side1(lock: Iterable[str] = ()) -> Tuple[Set[str], Set[str]]:
621 """One side of the concurrent locking test.
623 This optionally locks the table (and maybe the whole database),
624 does a select for its contents, inserts a new row, and then selects
625 again, with some waiting in between to make sure the other side has
626 a chance to _attempt_ to insert in between. If the locking is
627 enabled and works, the difference between the selects should just
628 be the insert done on this thread.
629 """
630 # Give Side2 a chance to create a connection
631 await asyncio.sleep(1.0)
632 with db1.transaction(lock=lock):
633 names1 = {row.name for row in db1.query(tables1.a.select())}
634 # Give Side2 a chance to insert (which will be blocked if
635 # we've acquired a lock).
636 await asyncio.sleep(2.0)
637 db1.insert(tables1.a, {"name": "a1"})
638 names2 = {row.name for row in db1.query(tables1.a.select())}
639 return names1, names2
641 async def side2() -> None:
642 """The other side of the concurrent locking test.
644 This side just waits a bit and then tries to insert a row into the
645 table that the other side is trying to lock. Hopefully that
646 waiting is enough to give the other side a chance to acquire the
647 lock and thus make this side block until the lock is released. If
648 this side manages to do the insert before side1 acquires the lock,
649 we'll just warn about not succeeding at testing the locking,
650 because we can only make that unlikely, not impossible.
651 """
652 def toRunInThread():
653 """SQLite locking isn't asyncio-friendly unless we actually
654 run it in another thread. And SQLite gets very unhappy if
655 we try to use a connection from multiple threads, so we have
656 to create the new connection here instead of out in the main
657 body of the test function.
658 """
659 db2 = self.getNewConnection(db1, writeable=True)
660 with db2.declareStaticTables(create=False) as context:
661 tables2 = context.addTableTuple(STATIC_TABLE_SPECS)
662 with db2.transaction():
663 db2.insert(tables2.a, {"name": "a2"})
665 await asyncio.sleep(2.0)
666 loop = asyncio.get_running_loop()
667 with ThreadPoolExecutor() as pool:
668 await loop.run_in_executor(pool, toRunInThread)
670 async def testProblemsWithNoLocking() -> None:
671 """Run side1 and side2 with no locking, attempting to demonstrate
672 the problem that locking is supposed to solve. If we get unlucky
673 with scheduling, side2 will just happen to insert after side1 is
674 done, and we won't have anything definitive. We just warn in that
675 case because we really don't want spurious test failures.
676 """
677 task1 = asyncio.create_task(side1())
678 task2 = asyncio.create_task(side2())
680 names1, names2 = await task1
681 await task2
682 if "a2" in names1:
683 warnings.warn("Unlucky scheduling in no-locking test: concurrent INSERT "
684 "happened before first SELECT.")
685 self.assertEqual(names1, {"a2"})
686 self.assertEqual(names2, {"a1", "a2"})
687 elif "a2" not in names2:
688 warnings.warn("Unlucky scheduling in no-locking test: concurrent INSERT "
689 "happened after second SELECT even without locking.")
690 self.assertEqual(names1, set())
691 self.assertEqual(names2, {"a1"})
692 else:
693 # This is the expected case: both INSERTS happen between the
694 # two SELECTS. If we don't get this almost all of the time we
695 # should adjust the sleep amounts.
696 self.assertEqual(names1, set())
697 self.assertEqual(names2, {"a1", "a2"})
699 asyncio.run(testProblemsWithNoLocking())
701 # Clean up after first test.
702 db1.delete(tables1.a, ["name"], {"name": "a1"}, {"name": "a2"})
        async def testSolutionWithLocking() -> None:
            """Run side1 and side2 with locking, which should make side2 block
            its insert until side1 releases its lock.
            """
            # Same as the no-locking variant, but side1 now locks table 'a'
            # for the duration of its transaction.
            task1 = asyncio.create_task(side1(lock=[tables1.a]))
            task2 = asyncio.create_task(side2())
            names1, names2 = await task1
            await task2
            if "a2" in names1:
                # Unlucky scheduling: side2's INSERT ran before side1 even
                # took its lock; nothing to verify beyond consistency.
                warnings.warn("Unlucky scheduling in locking test: concurrent INSERT "
                              "happened before first SELECT.")
                self.assertEqual(names1, {"a2"})
                self.assertEqual(names2, {"a1", "a2"})
            else:
                # This is the expected case: the side2 INSERT happens after the
                # last SELECT on side1. This can also happen due to unlucky
                # scheduling, and we have no way to detect that here, but the
                # similar "no-locking" test has at least some chance of being
                # affected by the same problem and warning about it.
                self.assertEqual(names1, set())
                self.assertEqual(names2, {"a1"})
727 asyncio.run(testSolutionWithLocking())
    def testTimespanDatabaseRepresentation(self):
        """Tests for `TimespanDatabaseRepresentation` and the `Database`
        methods that interact with it.

        Exercises round-tripping timespans (including NULL values and
        server-side defaults) through two tables, NULL checks in SELECT
        queries, and the overlaps/contains/</>/ordering operators between
        in-database timespans, literal timespans, and literal instants.
        """
        # Make some test timespans to play with, with the full suite of
        # topological relationships.
        start = astropy.time.Time('2020-01-01T00:00:00', format="isot", scale="tai")
        offset = astropy.time.TimeDelta(60, format="sec")
        timestamps = [start + offset*n for n in range(3)]
        aTimespans = [Timespan(begin=None, end=None)]
        aTimespans.extend(Timespan(begin=None, end=t) for t in timestamps)
        aTimespans.extend(Timespan(begin=t, end=None) for t in timestamps)
        aTimespans.extend(Timespan.fromInstant(t) for t in timestamps)
        aTimespans.append(Timespan.makeEmpty())
        aTimespans.extend(Timespan(begin=t1, end=t2) for t1, t2 in itertools.combinations(timestamps, 2))
        # Make another list of timespans that span the full range but don't
        # overlap. This is a subset of the previous list.
        bTimespans = [Timespan(begin=None, end=timestamps[0])]
        bTimespans.extend(Timespan(begin=t1, end=t2) for t1, t2 in zip(timestamps[:-1], timestamps[1:]))
        bTimespans.append(Timespan(begin=timestamps[-1], end=None))
        # Make a database and create a table with that database's timespan
        # representation. This one will have no exclusion constraint and
        # a nullable timespan.
        db = self.makeEmptyDatabase(origin=1)
        TimespanReprClass = db.getTimespanRepresentation()
        aSpec = ddl.TableSpec(
            fields=[
                ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
            ],
        )
        for fieldSpec in TimespanReprClass.makeFieldSpecs(nullable=True):
            aSpec.fields.add(fieldSpec)
        with db.declareStaticTables(create=True) as context:
            aTable = context.addTable("a", aSpec)
        # Show full diffs on assertion failures; these row lists are long.
        self.maxDiff = None

        def convertRowForInsert(row: dict) -> dict:
            """Convert a row containing a Timespan instance into one suitable
            for insertion into the database.
            """
            result = row.copy()
            # Replace the single Timespan entry with the representation's
            # (possibly multi-column) database form.
            ts = result.pop(TimespanReprClass.NAME)
            return TimespanReprClass.update(ts, result=result)

        def convertRowFromSelect(row: dict) -> dict:
            """Convert a row from the database into one containing a Timespan.
            """
            result = row.copy()
            timespan = TimespanReprClass.extract(result)
            # Strip the representation's raw columns and restore a single
            # Timespan entry; the inverse of convertRowForInsert.
            for name in TimespanReprClass.getFieldNames():
                del result[name]
            result[TimespanReprClass.NAME] = timespan
            return result

        # Insert rows into table A, in chunks just to make things interesting.
        # Include one with a NULL timespan.
        aRows = [{"id": n, TimespanReprClass.NAME: t} for n, t in enumerate(aTimespans)]
        aRows.append({"id": len(aRows), TimespanReprClass.NAME: None})
        db.insert(aTable, convertRowForInsert(aRows[0]))
        db.insert(aTable, *[convertRowForInsert(r) for r in aRows[1:3]])
        db.insert(aTable, *[convertRowForInsert(r) for r in aRows[3:]])
        # Add another one with a NULL timespan, but this time by invoking
        # the server-side default.
        aRows.append({"id": len(aRows)})
        # Deliberately not converted: the timespan columns are omitted so the
        # server-side default applies.
        db.insert(aTable, aRows[-1])
        aRows[-1][TimespanReprClass.NAME] = None
        # Test basic round-trip through database.
        self.assertEqual(
            aRows,
            [convertRowFromSelect(row._asdict())
             for row in db.query(aTable.select().order_by(aTable.columns.id))]
        )
        # Create another table B with a not-null timespan and (if the database
        # supports it), an exclusion constraint. Use ensureTableExists this
        # time to check that mode of table creation vs. timespans.
        bSpec = ddl.TableSpec(
            fields=[
                ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
                ddl.FieldSpec(name="key", dtype=sqlalchemy.Integer, nullable=False),
            ],
        )
        for fieldSpec in TimespanReprClass.makeFieldSpecs(nullable=False):
            bSpec.fields.add(fieldSpec)
        if TimespanReprClass.hasExclusionConstraint():
            bSpec.exclusion.add(("key", TimespanReprClass))
        bTable = db.ensureTableExists("b", bSpec)
        # Insert rows into table B, again in chunks. Each Timespan appears
        # twice, but with different values for the 'key' field (which should
        # still be okay for any exclusion constraint we may have defined).
        bRows = [{"id": n, "key": 1, TimespanReprClass.NAME: t} for n, t in enumerate(bTimespans)]
        offset = len(bRows)
        bRows.extend({"id": n + offset, "key": 2, TimespanReprClass.NAME: t}
                     for n, t in enumerate(bTimespans))
        db.insert(bTable, *[convertRowForInsert(r) for r in bRows[:2]])
        db.insert(bTable, convertRowForInsert(bRows[2]))
        db.insert(bTable, *[convertRowForInsert(r) for r in bRows[3:]])
        # Insert a row with no timespan into table B. This should invoke the
        # server-side default, which is a timespan over (-∞, ∞). We set
        # key=3 to avoid upsetting an exclusion constraint that might exist.
        bRows.append({"id": len(bRows), "key": 3})
        db.insert(bTable, bRows[-1])
        bRows[-1][TimespanReprClass.NAME] = Timespan(None, None)
        # Test basic round-trip through database.
        self.assertEqual(
            bRows,
            [convertRowFromSelect(row._asdict())
             for row in db.query(bTable.select().order_by(bTable.columns.id))]
        )
        # Test that we can't insert timespan=None into this table.
        with self.assertRaises(sqlalchemy.exc.IntegrityError):
            db.insert(
                bTable,
                convertRowForInsert({"id": len(bRows), "key": 4, TimespanReprClass.NAME: None})
            )
        # IFF this database supports exclusion constraints, test that they
        # also prevent inserts.
        if TimespanReprClass.hasExclusionConstraint():
            # Each candidate row reuses key=1 and overlaps existing key=1
            # timespans, so the exclusion constraint must reject all three.
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                db.insert(
                    bTable,
                    convertRowForInsert({
                        "id": len(bRows), "key": 1,
                        TimespanReprClass.NAME: Timespan(None, timestamps[1])
                    })
                )
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                db.insert(
                    bTable,
                    convertRowForInsert({
                        "id": len(bRows), "key": 1,
                        TimespanReprClass.NAME: Timespan(timestamps[0], timestamps[2])
                    })
                )
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                db.insert(
                    bTable,
                    convertRowForInsert({
                        "id": len(bRows), "key": 1,
                        TimespanReprClass.NAME: Timespan(timestamps[2], None)
                    })
                )
        # Test NULL checks in SELECT queries, on both tables.
        aRepr = TimespanReprClass.fromSelectable(aTable)
        self.assertEqual(
            [row[TimespanReprClass.NAME] is None for row in aRows],
            [
                row.f for row in db.query(
                    sqlalchemy.sql.select(
                        aRepr.isNull().label("f")
                    ).order_by(
                        aTable.columns.id
                    )
                )
            ]
        )
        bRepr = TimespanReprClass.fromSelectable(bTable)
        # Table B's timespan fields are not nullable, so isNull() must be
        # False for every row.
        self.assertEqual(
            [False for row in bRows],
            [
                row.f for row in db.query(
                    sqlalchemy.sql.select(
                        bRepr.isNull().label("f")
                    ).order_by(
                        bTable.columns.id
                    )
                )
            ]
        )
        # Test relationships expressions that relate in-database timespans to
        # Python-literal timespans, all from the more complete 'a' set; check
        # that this is consistent with Python-only relationship tests.
        for rhsRow in aRows:
            if rhsRow[TimespanReprClass.NAME] is None:
                continue
            with self.subTest(rhsRow=rhsRow):
                # Expected values come from the pure-Python Timespan
                # operators; rows with a NULL timespan yield NULL (None)
                # for every comparison.
                expected = {}
                for lhsRow in aRows:
                    if lhsRow[TimespanReprClass.NAME] is None:
                        expected[lhsRow["id"]] = (None, None, None, None)
                    else:
                        expected[lhsRow["id"]] = (
                            lhsRow[TimespanReprClass.NAME].overlaps(rhsRow[TimespanReprClass.NAME]),
                            lhsRow[TimespanReprClass.NAME].contains(rhsRow[TimespanReprClass.NAME]),
                            lhsRow[TimespanReprClass.NAME] < rhsRow[TimespanReprClass.NAME],
                            lhsRow[TimespanReprClass.NAME] > rhsRow[TimespanReprClass.NAME],
                        )
                rhsRepr = TimespanReprClass.fromLiteral(rhsRow[TimespanReprClass.NAME])
                sql = sqlalchemy.sql.select(
                    aTable.columns.id.label("lhs"),
                    aRepr.overlaps(rhsRepr).label("overlaps"),
                    aRepr.contains(rhsRepr).label("contains"),
                    (aRepr < rhsRepr).label("less_than"),
                    (aRepr > rhsRepr).label("greater_than"),
                ).select_from(aTable)
                queried = {
                    row.lhs: (row.overlaps, row.contains, row.less_than, row.greater_than)
                    for row in db.query(sql)
                }
                self.assertEqual(expected, queried)
        # Test relationship expressions that relate in-database timespans to
        # each other, all from the more complete 'a' set; check that this is
        # consistent with Python-only relationship tests.
        expected = {}
        for lhs, rhs in itertools.product(aRows, aRows):
            lhsT = lhs[TimespanReprClass.NAME]
            rhsT = rhs[TimespanReprClass.NAME]
            if lhsT is not None and rhsT is not None:
                expected[lhs["id"], rhs["id"]] = (
                    lhsT.overlaps(rhsT),
                    lhsT.contains(rhsT),
                    lhsT < rhsT,
                    lhsT > rhsT
                )
            else:
                # Any NULL operand makes every comparison NULL.
                expected[lhs["id"], rhs["id"]] = (None, None, None, None)
        lhsSubquery = aTable.alias("lhs")
        rhsSubquery = aTable.alias("rhs")
        lhsRepr = TimespanReprClass.fromSelectable(lhsSubquery)
        rhsRepr = TimespanReprClass.fromSelectable(rhsSubquery)
        sql = sqlalchemy.sql.select(
            lhsSubquery.columns.id.label("lhs"),
            rhsSubquery.columns.id.label("rhs"),
            lhsRepr.overlaps(rhsRepr).label("overlaps"),
            lhsRepr.contains(rhsRepr).label("contains"),
            (lhsRepr < rhsRepr).label("less_than"),
            (lhsRepr > rhsRepr).label("greater_than"),
        ).select_from(
            # JOIN ... ON TRUE, i.e. a cross join: compare every 'a' row
            # against every 'a' row.
            lhsSubquery.join(rhsSubquery, onclause=sqlalchemy.sql.literal(True))
        )
        queried = {
            (row.lhs, row.rhs): (row.overlaps, row.contains, row.less_than, row.greater_than)
            for row in db.query(sql)}
        self.assertEqual(expected, queried)
        # Test relationship expressions between in-database timespans and
        # Python-literal instantaneous times.
        for t in timestamps:
            with self.subTest(t=t):
                expected = {}
                for lhsRow in aRows:
                    if lhsRow[TimespanReprClass.NAME] is None:
                        expected[lhsRow["id"]] = (None, None, None)
                    else:
                        expected[lhsRow["id"]] = (
                            lhsRow[TimespanReprClass.NAME].contains(t),
                            lhsRow[TimespanReprClass.NAME] < t,
                            lhsRow[TimespanReprClass.NAME] > t,
                        )
                # Bind the astropy Time as a literal with the same custom
                # column type the tables use for timespan endpoints.
                rhs = sqlalchemy.sql.literal(t, type_=ddl.AstropyTimeNsecTai)
                sql = sqlalchemy.sql.select(
                    aTable.columns.id.label("lhs"),
                    aRepr.contains(rhs).label("contains"),
                    (aRepr < rhs).label("less_than"),
                    (aRepr > rhs).label("greater_than"),
                ).select_from(aTable)
                queried = {
                    row.lhs: (row.contains, row.less_than, row.greater_than)
                    for row in db.query(sql)
                }
                self.assertEqual(expected, queried)