Coverage for python/lsst/daf/butler/registry/tests/_database.py: 8%
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ["DatabaseTests"]
import asyncio
import itertools
import warnings
from abc import ABC, abstractmethod
from collections import namedtuple
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from typing import ContextManager, Iterable, Iterator, Optional, Set, Tuple

import astropy.time
import sqlalchemy
from lsst.sphgeom import ConvexPolygon, UnitVector3d

from ...core import Timespan, ddl
from ..interfaces import Database, DatabaseConflictError, ReadOnlyDatabaseError, SchemaAlreadyDefinedError
# Named tuple holding the three tables that make up the static test schema.
StaticTablesTuple = namedtuple("StaticTablesTuple", ["a", "b", "c"])

# Static schema used by most of the tests below.  The three tables exercise a
# string primary key with a region blob ('a'), an autoincrement primary key
# with a unique constraint ('b'), and a compound autoincrement+origin primary
# key with an ON DELETE SET NULL foreign key ('c').
STATIC_TABLE_SPECS = StaticTablesTuple(
    a=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("name", dtype=sqlalchemy.String, length=16, primaryKey=True),
            ddl.FieldSpec("region", dtype=ddl.Base64Region, nbytes=128),
        ]
    ),
    b=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("id", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True),
            ddl.FieldSpec("name", dtype=sqlalchemy.String, length=16, nullable=False),
            ddl.FieldSpec("value", dtype=sqlalchemy.SmallInteger, nullable=True),
        ],
        unique=[("name",)],
    ),
    c=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("id", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True),
            ddl.FieldSpec("origin", dtype=sqlalchemy.BigInteger, primaryKey=True),
            ddl.FieldSpec("b_id", dtype=sqlalchemy.BigInteger, nullable=True),
        ],
        foreignKeys=[
            ddl.ForeignKeySpec("b", source=("b_id",), target=("id",), onDelete="SET NULL"),
        ],
    ),
)

# Spec for a table created at runtime via ensureTableExists; it references
# both 'a' and 'c' with ON DELETE CASCADE foreign keys.
DYNAMIC_TABLE_SPEC = ddl.TableSpec(
    fields=[
        ddl.FieldSpec("c_id", dtype=sqlalchemy.BigInteger, primaryKey=True),
        ddl.FieldSpec("c_origin", dtype=sqlalchemy.BigInteger, primaryKey=True),
        ddl.FieldSpec("a_name", dtype=sqlalchemy.String, length=16, nullable=False),
    ],
    foreignKeys=[
        ddl.ForeignKeySpec("c", source=("c_id", "c_origin"), target=("id", "origin"), onDelete="CASCADE"),
        ddl.ForeignKeySpec("a", source=("a_name",), target=("name",), onDelete="CASCADE"),
    ],
)

# Spec for session-scoped temporary tables; it declares no foreign keys
# (temporary tables are also created against read-only connections in the
# tests below).
TEMPORARY_TABLE_SPEC = ddl.TableSpec(
    fields=[
        ddl.FieldSpec("a_name", dtype=sqlalchemy.String, length=16, primaryKey=True),
        ddl.FieldSpec("b_id", dtype=sqlalchemy.BigInteger, primaryKey=True),
    ],
)
90@contextmanager
91def _patch_getExistingTable(db: Database) -> Database:
92 """Patch getExistingTable method in a database instance to test concurrent
93 creation of tables. This patch obviously depends on knowning internals of
94 ``ensureTableExists()`` implementation.
95 """
96 original_method = db.getExistingTable
98 def _getExistingTable(name: str, spec: ddl.TableSpec) -> Optional[sqlalchemy.schema.Table]:
99 # Return None on first call, but forward to original method after that
100 db.getExistingTable = original_method
101 return None
103 db.getExistingTable = _getExistingTable
104 yield db
105 db.getExistingTable = original_method
class DatabaseTests(ABC):
    """Generic tests for the `Database` interface that can be subclassed to
    generate tests for concrete implementations.

    Concrete subclasses must implement the abstract factory methods below.
    NOTE(review): the ``self.assert*`` methods used throughout are not
    defined in this class, so subclasses presumably also inherit from
    `unittest.TestCase` — confirm against the concrete test modules.
    """
    @abstractmethod
    def makeEmptyDatabase(self, origin: int = 0) -> Database:
        """Return an empty `Database` with the given origin, or an
        automatically-generated one if ``origin`` is `None`.

        NOTE(review): the declared default is ``0``, not `None`; confirm
        whether implementations actually accept `None` as described above.
        """
        raise NotImplementedError()
    @abstractmethod
    def asReadOnly(self, database: Database) -> ContextManager[Database]:
        """Return a context manager for a read-only connection into the given
        database.

        The original database should be considered unusable within the context
        but safe to use again afterwards (this allows the context manager to
        block write access by temporarily changing user permissions to really
        guarantee that write operations are not performed).

        Within the context, write operations are expected to raise
        `ReadOnlyDatabaseError` (see e.g. `testDynamicTables`).
        """
        raise NotImplementedError()
    @abstractmethod
    def getNewConnection(self, database: Database, *, writeable: bool) -> Database:
        """Return a new `Database` instance that points to the same underlying
        storage as the given one.

        ``writeable`` is keyword-only and controls whether the new connection
        may be used for writes.
        """
        raise NotImplementedError()
139 def checkTable(self, spec: ddl.TableSpec, table: sqlalchemy.schema.Table):
140 self.assertCountEqual(spec.fields.names, table.columns.keys())
141 # Checking more than this currently seems fragile, as it might restrict
142 # what Database implementations do; we don't care if the spec is
143 # actually preserved in terms of types and constraints as long as we
144 # can use the returned table as if it was.
146 def checkStaticSchema(self, tables: StaticTablesTuple):
147 self.checkTable(STATIC_TABLE_SPECS.a, tables.a)
148 self.checkTable(STATIC_TABLE_SPECS.b, tables.b)
149 self.checkTable(STATIC_TABLE_SPECS.c, tables.c)
151 def testDeclareStaticTables(self):
152 """Tests for `Database.declareStaticSchema` and the methods it
153 delegates to.
154 """
155 # Create the static schema in a new, empty database.
156 newDatabase = self.makeEmptyDatabase()
157 with newDatabase.declareStaticTables(create=True) as context:
158 tables = context.addTableTuple(STATIC_TABLE_SPECS)
159 self.checkStaticSchema(tables)
160 # Check that we can load that schema even from a read-only connection.
161 with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
162 with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
163 tables = context.addTableTuple(STATIC_TABLE_SPECS)
164 self.checkStaticSchema(tables)
166 def testDeclareStaticTablesTwice(self):
167 """Tests for `Database.declareStaticSchema` being called twice."""
168 # Create the static schema in a new, empty database.
169 newDatabase = self.makeEmptyDatabase()
170 with newDatabase.declareStaticTables(create=True) as context:
171 tables = context.addTableTuple(STATIC_TABLE_SPECS)
172 self.checkStaticSchema(tables)
173 # Second time it should raise
174 with self.assertRaises(SchemaAlreadyDefinedError):
175 with newDatabase.declareStaticTables(create=True) as context:
176 tables = context.addTableTuple(STATIC_TABLE_SPECS)
177 # Check schema, it should still contain all tables, and maybe some
178 # extra.
179 with newDatabase.declareStaticTables(create=False) as context:
180 self.assertLessEqual(frozenset(STATIC_TABLE_SPECS._fields), context._tableNames)
182 def testRepr(self):
183 """Test that repr does not return a generic thing."""
184 newDatabase = self.makeEmptyDatabase()
185 rep = repr(newDatabase)
186 # Check that stringification works and gives us something different
187 self.assertNotEqual(rep, str(newDatabase))
188 self.assertNotIn("object at 0x", rep, "Check default repr was not used")
189 self.assertIn("://", rep)
    def testDynamicTables(self):
        """Tests for `Database.ensureTableExists` and
        `Database.getExistingTable`.

        Covers read-only failure, idempotent creation, spec-conflict
        detection, and the prohibition on creation inside a transaction.
        """
        # Need to start with the static schema.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            context.addTableTuple(STATIC_TABLE_SPECS)
        # Try to ensure the dynamic table exists in a read-only version of that
        # database, which should fail because we can't create it.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                context.addTableTuple(STATIC_TABLE_SPECS)
            with self.assertRaises(ReadOnlyDatabaseError):
                existingReadOnlyDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        # Just getting the dynamic table before it exists should return None.
        self.assertIsNone(newDatabase.getExistingTable("d", DYNAMIC_TABLE_SPEC))
        # Ensure the new table exists back in the original database, which
        # should create it.
        table = newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        self.checkTable(DYNAMIC_TABLE_SPEC, table)
        # Ensuring that it exists should just return the exact same table
        # instance again.
        self.assertIs(newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC), table)
        # Try again from the read-only database.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                context.addTableTuple(STATIC_TABLE_SPECS)
            # Just getting the dynamic table should now work...
            self.assertIsNotNone(existingReadOnlyDatabase.getExistingTable("d", DYNAMIC_TABLE_SPEC))
            # ...as should ensuring that it exists, since it now does.
            existingReadOnlyDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
            self.checkTable(DYNAMIC_TABLE_SPEC, table)
        # Trying to get the table with a different specification (at least
        # in terms of what columns are present) should raise.
        with self.assertRaises(DatabaseConflictError):
            newDatabase.ensureTableExists(
                "d",
                ddl.TableSpec(
                    fields=[ddl.FieldSpec("name", dtype=sqlalchemy.String, length=4, primaryKey=True)]
                ),
            )
        # Calling ensureTableExists inside a transaction block is an error,
        # even if it would do nothing.
        with newDatabase.transaction():
            with self.assertRaises(AssertionError):
                newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
    def testDynamicTablesConcurrency(self):
        """Tests for `Database.ensureTableExists` concurrent use."""
        # We cannot really run things concurrently in a deterministic way, here
        # we just simulate a situation when the table is created by other
        # process between the call to getExistingTable() and actual table
        # creation.
        db1 = self.makeEmptyDatabase()
        with db1.declareStaticTables(create=True) as context:
            context.addTableTuple(STATIC_TABLE_SPECS)
        self.assertIsNone(db1.getExistingTable("d", DYNAMIC_TABLE_SPEC))

        # Make a dynamic table using separate connection
        db2 = self.getNewConnection(db1, writeable=True)
        with db2.declareStaticTables(create=False) as context:
            context.addTableTuple(STATIC_TABLE_SPECS)
        table = db2.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        self.checkTable(DYNAMIC_TABLE_SPEC, table)

        # Call it again but trick it into thinking that table is not there.
        # This test depends on knowing implementation of ensureTableExists()
        # which initially calls getExistingTable() to check that table may
        # exist, the patch intercepts that call and returns None.
        # The assertion here is implicit: if the lost creation race is not
        # handled, this call raises instead of returning the table.
        with _patch_getExistingTable(db1):
            table = db1.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
    def testTemporaryTables(self):
        """Tests for `Database.makeTemporaryTable`,
        `Database.dropTemporaryTable`, and `Database.insert` with
        the ``select`` argument.
        """
        # Need to start with the static schema; also insert some test data.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            static = context.addTableTuple(STATIC_TABLE_SPECS)
        newDatabase.insert(static.a, {"name": "a1", "region": None}, {"name": "a2", "region": None})
        bIds = newDatabase.insert(
            static.b,
            {"name": "b1", "value": 11},
            {"name": "b2", "value": 12},
            {"name": "b3", "value": 13},
            returnIds=True,
        )
        # Create the table.  Temporary tables are scoped to a session.
        with newDatabase.session() as session:
            table1 = session.makeTemporaryTable(TEMPORARY_TABLE_SPEC, "e1")
            self.checkTable(TEMPORARY_TABLE_SPEC, table1)
            # Insert via a INSERT INTO ... SELECT query, with the target
            # columns identified by labeling the selected columns.
            newDatabase.insert(
                table1,
                select=sqlalchemy.sql.select(
                    static.a.columns.name.label("a_name"), static.b.columns.id.label("b_id")
                )
                .select_from(static.a.join(static.b, onclause=sqlalchemy.sql.literal(True)))
                .where(
                    sqlalchemy.sql.and_(
                        static.a.columns.name == "a1",
                        static.b.columns.value <= 12,
                    )
                ),
            )
            # Check that the inserted rows are present.
            self.assertCountEqual(
                [{"a_name": "a1", "b_id": bId} for bId in bIds[:2]],
                [row._asdict() for row in newDatabase.query(table1.select())],
            )
            # Create another one via a read-only connection to the database.
            # We _do_ allow temporary table modifications in read-only
            # databases.
            with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
                with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                    context.addTableTuple(STATIC_TABLE_SPECS)
                with existingReadOnlyDatabase.session() as session2:
                    table2 = session2.makeTemporaryTable(TEMPORARY_TABLE_SPEC)
                    self.checkTable(TEMPORARY_TABLE_SPEC, table2)
                    # Those tables should not be the same, despite having the
                    # same ddl.
                    self.assertIsNot(table1, table2)
                    # Do a slightly different insert into this table, to check
                    # that it works in a read-only database.  This time we
                    # pass column names as a kwarg to insert instead of by
                    # labeling the columns in the select.
                    existingReadOnlyDatabase.insert(
                        table2,
                        select=sqlalchemy.sql.select(static.a.columns.name, static.b.columns.id)
                        .select_from(static.a.join(static.b, onclause=sqlalchemy.sql.literal(True)))
                        .where(
                            sqlalchemy.sql.and_(
                                static.a.columns.name == "a2",
                                static.b.columns.value >= 12,
                            )
                        ),
                        names=["a_name", "b_id"],
                    )
                    # Check that the inserted rows are present.
                    self.assertCountEqual(
                        [{"a_name": "a2", "b_id": bId} for bId in bIds[1:]],
                        [row._asdict() for row in existingReadOnlyDatabase.query(table2.select())],
                    )
                    # Drop the temporary table from the read-only DB.  It's
                    # unspecified whether attempting to use it after this
                    # point is an error or just never returns any results, so
                    # we can't test what it does, only that it's not an error.
                    session2.dropTemporaryTable(table2)
            # Drop the original temporary table.
            session.dropTemporaryTable(table1)
345 def testSchemaSeparation(self):
346 """Test that creating two different `Database` instances allows us
347 to create different tables with the same name in each.
348 """
349 db1 = self.makeEmptyDatabase(origin=1)
350 with db1.declareStaticTables(create=True) as context:
351 tables = context.addTableTuple(STATIC_TABLE_SPECS)
352 self.checkStaticSchema(tables)
354 db2 = self.makeEmptyDatabase(origin=2)
355 # Make the DDL here intentionally different so we'll definitely
356 # notice if db1 and db2 are pointing at the same schema.
357 spec = ddl.TableSpec(fields=[ddl.FieldSpec("id", dtype=sqlalchemy.Integer, primaryKey=True)])
358 with db2.declareStaticTables(create=True) as context:
359 # Make the DDL here intentionally different so we'll definitely
360 # notice if db1 and db2 are pointing at the same schema.
361 table = context.addTable("a", spec)
362 self.checkTable(spec, table)
    def testInsertQueryDelete(self):
        """Test the `Database.insert`, `Database.query`, and `Database.delete`
        methods, as well as the `Base64Region` type and the ``onDelete``
        argument to `ddl.ForeignKeySpec`.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert a single, non-autoincrement row that contains a region and
        # query to get it back.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row = {"name": "a1", "region": region}
        db.insert(tables.a, row)
        self.assertEqual([r._asdict() for r in db.query(tables.a.select())], [row])
        # Insert multiple autoincrement rows but do not try to get the IDs
        # back immediately.
        db.insert(tables.b, {"name": "b1", "value": 10}, {"name": "b2", "value": 20})
        results = [r._asdict() for r in db.query(tables.b.select().order_by("id"))]
        self.assertEqual(len(results), 2)
        for row in results:
            self.assertIn(row["name"], ("b1", "b2"))
            self.assertIsInstance(row["id"], int)
        self.assertGreater(results[1]["id"], results[0]["id"])
        # Insert multiple autoincrement rows and get the IDs back from insert.
        rows = [{"name": "b3", "value": 30}, {"name": "b4", "value": 40}]
        ids = db.insert(tables.b, *rows, returnIds=True)
        results = [
            r._asdict() for r in db.query(tables.b.select().where(tables.b.columns.id > results[1]["id"]))
        ]
        expected = [dict(row, id=id) for row, id in zip(rows, ids)]
        self.assertCountEqual(results, expected)
        self.assertTrue(all(result["id"] is not None for result in results))
        # Insert multiple rows into a table with an autoincrement+origin
        # primary key, then use the returned IDs to insert into a dynamic
        # table.
        rows = [{"origin": db.origin, "b_id": results[0]["id"]}, {"origin": db.origin, "b_id": None}]
        ids = db.insert(tables.c, *rows, returnIds=True)
        results = [r._asdict() for r in db.query(tables.c.select())]
        expected = [dict(row, id=id) for row, id in zip(rows, ids)]
        self.assertCountEqual(results, expected)
        self.assertTrue(all(result["id"] is not None for result in results))
        # Add the dynamic table.
        d = db.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        # Insert into it.
        rows = [{"c_origin": db.origin, "c_id": id, "a_name": "a1"} for id in ids]
        db.insert(d, *rows)
        results = [r._asdict() for r in db.query(d.select())]
        self.assertCountEqual(rows, results)
        # Insert multiple rows into a table with an autoincrement+origin
        # primary key (this is especially tricky for SQLite, but good to test
        # for all DBs), but pass in a value for the autoincrement key.
        # For extra complexity, we re-use the autoincrement value with a
        # different value for origin.
        rows2 = [
            {"id": 700, "origin": db.origin, "b_id": None},
            {"id": 700, "origin": 60, "b_id": None},
            {"id": 1, "origin": 60, "b_id": None},
        ]
        db.insert(tables.c, *rows2)
        results = [r._asdict() for r in db.query(tables.c.select())]
        self.assertCountEqual(results, expected + rows2)
        self.assertTrue(all(result["id"] is not None for result in results))

        # Define 'SELECT COUNT(*)' query for later use.
        count = sqlalchemy.sql.select(sqlalchemy.sql.func.count())
        # Get the values we inserted into table b.
        bValues = [r._asdict() for r in db.query(tables.b.select())]
        # Remove two rows from table b by ID.
        n = db.delete(tables.b, ["id"], {"id": bValues[0]["id"]}, {"id": bValues[1]["id"]})
        self.assertEqual(n, 2)
        # Remove the other two rows from table b by name.
        n = db.delete(tables.b, ["name"], {"name": bValues[2]["name"]}, {"name": bValues[3]["name"]})
        self.assertEqual(n, 2)
        # There should now be no rows in table b.
        self.assertEqual(db.query(count.select_from(tables.b)).scalar(), 0)
        # All b_id values in table c should now be NULL, because there's an
        # onDelete='SET NULL' foreign key.
        self.assertEqual(
            db.query(count.select_from(tables.c).where(tables.c.columns.b_id != None)).scalar(),  # noqa:E711
            0,
        )
        # Remove all rows in table a (there's only one); this should remove all
        # rows in d due to onDelete='CASCADE'.
        n = db.delete(tables.a, [])
        self.assertEqual(n, 1)
        self.assertEqual(db.query(count.select_from(tables.a)).scalar(), 0)
        self.assertEqual(db.query(count.select_from(d)).scalar(), 0)
452 def testDeleteWhere(self):
453 """Tests for `Database.deleteWhere`."""
454 db = self.makeEmptyDatabase(origin=1)
455 with db.declareStaticTables(create=True) as context:
456 tables = context.addTableTuple(STATIC_TABLE_SPECS)
457 db.insert(tables.b, *[{"id": i, "name": f"b{i}"} for i in range(10)])
458 count = sqlalchemy.sql.select(sqlalchemy.sql.func.count())
460 n = db.deleteWhere(tables.b, tables.b.columns.id.in_([0, 1, 2]))
461 self.assertEqual(n, 3)
462 self.assertEqual(db.query(count.select_from(tables.b)).scalar(), 7)
464 n = db.deleteWhere(
465 tables.b,
466 tables.b.columns.id.in_(
467 sqlalchemy.sql.select(tables.b.columns.id).where(tables.b.columns.id > 5)
468 ),
469 )
470 self.assertEqual(n, 4)
471 self.assertEqual(db.query(count.select_from(tables.b)).scalar(), 3)
473 n = db.deleteWhere(tables.b, tables.b.columns.name == "b5")
474 self.assertEqual(n, 1)
475 self.assertEqual(db.query(count.select_from(tables.b)).scalar(), 2)
477 n = db.deleteWhere(tables.b, sqlalchemy.sql.literal(True))
478 self.assertEqual(n, 2)
479 self.assertEqual(db.query(count.select_from(tables.b)).scalar(), 0)
481 def testUpdate(self):
482 """Tests for `Database.update`."""
483 db = self.makeEmptyDatabase(origin=1)
484 with db.declareStaticTables(create=True) as context:
485 tables = context.addTableTuple(STATIC_TABLE_SPECS)
486 # Insert two rows into table a, both without regions.
487 db.insert(tables.a, {"name": "a1"}, {"name": "a2"})
488 # Update one of the rows with a region.
489 region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
490 n = db.update(tables.a, {"name": "k"}, {"k": "a2", "region": region})
491 self.assertEqual(n, 1)
492 sql = sqlalchemy.sql.select(tables.a.columns.name, tables.a.columns.region).select_from(tables.a)
493 self.assertCountEqual(
494 [r._asdict() for r in db.query(sql)],
495 [{"name": "a1", "region": None}, {"name": "a2", "region": region}],
496 )
    def testSync(self):
        """Tests for `Database.sync`.

        Exercises insert-if-missing, the ``extra``/``compared``/``returning``
        arguments, read-only behavior, and ``update=True``.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert a row with sync, because it doesn't exist yet.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10}, returning=["id"])
        self.assertTrue(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in db.query(tables.b.select())],
        )
        # Repeat that operation, which should do nothing but return the
        # requested values.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10}, returning=["id"])
        self.assertFalse(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in db.query(tables.b.select())],
        )
        # Repeat the operation without the 'extra' arg, which should also just
        # return the existing row.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, returning=["id"])
        self.assertFalse(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in db.query(tables.b.select())],
        )
        # Repeat the operation with a different value in 'extra'.  That still
        # shouldn't be an error, because 'extra' is only used if we really do
        # insert.  Also drop the 'returning' argument.
        _, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 20})
        self.assertFalse(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in db.query(tables.b.select())],
        )
        # Repeat the operation with the correct value in 'compared' instead of
        # 'extra'.
        _, inserted = db.sync(tables.b, keys={"name": "b1"}, compared={"value": 10})
        self.assertFalse(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in db.query(tables.b.select())],
        )
        # Repeat the operation with an incorrect value in 'compared'; this
        # should raise.
        with self.assertRaises(DatabaseConflictError):
            db.sync(tables.b, keys={"name": "b1"}, compared={"value": 20})
        # Try to sync in a read-only database.  This should work if and only
        # if the matching row already exists.
        with self.asReadOnly(db) as rodb:
            with rodb.declareStaticTables(create=False) as context:
                tables = context.addTableTuple(STATIC_TABLE_SPECS)
            _, inserted = rodb.sync(tables.b, keys={"name": "b1"})
            self.assertFalse(inserted)
            self.assertEqual(
                [{"id": values["id"], "name": "b1", "value": 10}],
                [r._asdict() for r in rodb.query(tables.b.select())],
            )
            with self.assertRaises(ReadOnlyDatabaseError):
                rodb.sync(tables.b, keys={"name": "b2"}, extra={"value": 20})
        # Repeat the operation with a different value in 'compared' and ask to
        # update.  The returned dict holds the values that were replaced.
        _, updated = db.sync(tables.b, keys={"name": "b1"}, compared={"value": 20}, update=True)
        self.assertEqual(updated, {"value": 10})
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 20}],
            [r._asdict() for r in db.query(tables.b.select())],
        )
569 def testReplace(self):
570 """Tests for `Database.replace`."""
571 db = self.makeEmptyDatabase(origin=1)
572 with db.declareStaticTables(create=True) as context:
573 tables = context.addTableTuple(STATIC_TABLE_SPECS)
574 # Use 'replace' to insert a single row that contains a region and
575 # query to get it back.
576 region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
577 row1 = {"name": "a1", "region": region}
578 db.replace(tables.a, row1)
579 self.assertEqual([r._asdict() for r in db.query(tables.a.select())], [row1])
580 # Insert another row without a region.
581 row2 = {"name": "a2", "region": None}
582 db.replace(tables.a, row2)
583 self.assertCountEqual([r._asdict() for r in db.query(tables.a.select())], [row1, row2])
584 # Use replace to re-insert both of those rows again, which should do
585 # nothing.
586 db.replace(tables.a, row1, row2)
587 self.assertCountEqual([r._asdict() for r in db.query(tables.a.select())], [row1, row2])
588 # Replace row1 with a row with no region, while reinserting row2.
589 row1a = {"name": "a1", "region": None}
590 db.replace(tables.a, row1a, row2)
591 self.assertCountEqual([r._asdict() for r in db.query(tables.a.select())], [row1a, row2])
592 # Replace both rows, returning row1 to its original state, while adding
593 # a new one. Pass them in in a different order.
594 row2a = {"name": "a2", "region": region}
595 row3 = {"name": "a3", "region": None}
596 db.replace(tables.a, row3, row2a, row1)
597 self.assertCountEqual([r._asdict() for r in db.query(tables.a.select())], [row1, row2a, row3])
    def testEnsure(self):
        """Tests for `Database.ensure`.

        ``ensure`` returns the number of rows actually inserted; rows whose
        keys already exist are left untouched (never overwritten).
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Use 'ensure' to insert a single row that contains a region and
        # query to get it back.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row1 = {"name": "a1", "region": region}
        self.assertEqual(db.ensure(tables.a, row1), 1)
        self.assertEqual([r._asdict() for r in db.query(tables.a.select())], [row1])
        # Insert another row without a region.
        row2 = {"name": "a2", "region": None}
        self.assertEqual(db.ensure(tables.a, row2), 1)
        self.assertCountEqual([r._asdict() for r in db.query(tables.a.select())], [row1, row2])
        # Use ensure to re-insert both of those rows again, which should do
        # nothing.
        self.assertEqual(db.ensure(tables.a, row1, row2), 0)
        self.assertCountEqual([r._asdict() for r in db.query(tables.a.select())], [row1, row2])
        # Attempt to insert row1's key with no region, while
        # reinserting row2.  This should also do nothing.
        row1a = {"name": "a1", "region": None}
        self.assertEqual(db.ensure(tables.a, row1a, row2), 0)
        self.assertCountEqual([r._asdict() for r in db.query(tables.a.select())], [row1, row2])
        # Attempt to insert new rows for both existing keys, this time also
        # adding a new row.  Pass them in in a different order.  Only the new
        # row should be added.
        row2a = {"name": "a2", "region": region}
        row3 = {"name": "a3", "region": None}
        self.assertEqual(db.ensure(tables.a, row3, row2a, row1a), 1)
        self.assertCountEqual([r._asdict() for r in db.query(tables.a.select())], [row1, row2, row3])
    def testTransactionNesting(self):
        """Test that transactions can be nested with the behavior in the
        presence of exceptions working as documented.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert one row so we can trigger integrity errors by trying to insert
        # a duplicate of it below.
        db.insert(tables.a, {"name": "a1"})
        # First test: error recovery via explicit savepoint=True in the inner
        # transaction.
        with db.transaction():
            # This insert should succeed, and should not be rolled back because
            # the assertRaises context should catch any exception before it
            # propagates up to the outer transaction.
            db.insert(tables.a, {"name": "a2"})
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                with db.transaction(savepoint=True):
                    # This insert should succeed, but should be rolled back.
                    db.insert(tables.a, {"name": "a4"})
                    # This insert should fail (duplicate primary key), raising
                    # an exception.
                    db.insert(tables.a, {"name": "a1"})
        # Only a1 (pre-existing) and a2 (outer transaction) should remain.
        self.assertCountEqual(
            [r._asdict() for r in db.query(tables.a.select())],
            [{"name": "a1", "region": None}, {"name": "a2", "region": None}],
        )
        # Second test: error recovery via implicit savepoint=True, when the
        # innermost transaction is inside a savepoint=True transaction.
        with db.transaction():
            # This insert should succeed, and should not be rolled back
            # because the assertRaises context should catch any
            # exception before it propagates up to the outer
            # transaction.
            db.insert(tables.a, {"name": "a3"})
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                with db.transaction(savepoint=True):
                    # This insert should succeed, but should be rolled back.
                    db.insert(tables.a, {"name": "a4"})
                    with db.transaction():
                        # This insert should succeed, but should be rolled
                        # back.
                        db.insert(tables.a, {"name": "a5"})
                        # This insert should fail (duplicate primary key),
                        # raising an exception.
                        db.insert(tables.a, {"name": "a1"})
        # Everything inside the savepoint (a4, a5) was rolled back; a3 from
        # the outer transaction survives.
        self.assertCountEqual(
            [r._asdict() for r in db.query(tables.a.select())],
            [{"name": "a1", "region": None}, {"name": "a2", "region": None}, {"name": "a3", "region": None}],
        )
    def testTransactionLocking(self):
        """Test that `Database.transaction` can be used to acquire a lock
        that prohibits concurrent writes.

        NOTE(review): this test is inherently timing-based — it relies on
        ``asyncio.sleep`` calls to order events across two connections — so
        the non-deterministic outcomes only produce warnings, never failures.
        """
        db1 = self.makeEmptyDatabase(origin=1)
        with db1.declareStaticTables(create=True) as context:
            tables1 = context.addTableTuple(STATIC_TABLE_SPECS)

        async def side1(lock: Iterable[str] = ()) -> Tuple[Set[str], Set[str]]:
            """One side of the concurrent locking test.

            This optionally locks the table (and maybe the whole database),
            does a select for its contents, inserts a new row, and then selects
            again, with some waiting in between to make sure the other side has
            a chance to _attempt_ to insert in between.  If the locking is
            enabled and works, the difference between the selects should just
            be the insert done on this thread.

            Parameters
            ----------
            lock : `Iterable` [ `str` ]
                Tables to lock for the duration of the transaction; forwarded
                to `Database.transaction`.  Empty (the default) means no
                locking.

            Returns
            -------
            names1, names2 : `set` [ `str` ], `set` [ `str` ]
                The ``name`` column values seen by the SELECTs before and
                after this side's INSERT, respectively.
            """
            # Give Side2 a chance to create a connection
            await asyncio.sleep(1.0)
            with db1.transaction(lock=lock):
                names1 = {row.name for row in db1.query(tables1.a.select())}
                # Give Side2 a chance to insert (which will be blocked if
                # we've acquired a lock).
                await asyncio.sleep(2.0)
                db1.insert(tables1.a, {"name": "a1"})
                names2 = {row.name for row in db1.query(tables1.a.select())}
            return names1, names2

        async def side2() -> None:
            """The other side of the concurrent locking test.

            This side just waits a bit and then tries to insert a row into the
            table that the other side is trying to lock.  Hopefully that
            waiting is enough to give the other side a chance to acquire the
            lock and thus make this side block until the lock is released.  If
            this side manages to do the insert before side1 acquires the lock,
            we'll just warn about not succeeding at testing the locking,
            because we can only make that unlikely, not impossible.
            """

            def toRunInThread():
                """SQLite locking isn't asyncio-friendly unless we actually
                run it in another thread.  And SQLite gets very unhappy if
                we try to use a connection from multiple threads, so we have
                to create the new connection here instead of out in the main
                body of the test function.
                """
                db2 = self.getNewConnection(db1, writeable=True)
                with db2.declareStaticTables(create=False) as context:
                    tables2 = context.addTableTuple(STATIC_TABLE_SPECS)
                with db2.transaction():
                    db2.insert(tables2.a, {"name": "a2"})

            await asyncio.sleep(2.0)
            loop = asyncio.get_running_loop()
            # Run the blocking DB work in a worker thread so it cannot stall
            # the event loop (and hence side1's sleeps).
            with ThreadPoolExecutor() as pool:
                await loop.run_in_executor(pool, toRunInThread)

        async def testProblemsWithNoLocking() -> None:
            """Run side1 and side2 with no locking, attempting to demonstrate
            the problem that locking is supposed to solve.  If we get unlucky
            with scheduling, side2 will just happen to insert after side1 is
            done, and we won't have anything definitive.  We just warn in that
            case because we really don't want spurious test failures.
            """
            task1 = asyncio.create_task(side1())
            task2 = asyncio.create_task(side2())
            names1, names2 = await task1
            await task2
            if "a2" in names1:
                warnings.warn(
                    "Unlucky scheduling in no-locking test: concurrent INSERT "
                    "happened before first SELECT."
                )
                self.assertEqual(names1, {"a2"})
                self.assertEqual(names2, {"a1", "a2"})
            elif "a2" not in names2:
                warnings.warn(
                    "Unlucky scheduling in no-locking test: concurrent INSERT "
                    "happened after second SELECT even without locking."
                )
                self.assertEqual(names1, set())
                self.assertEqual(names2, {"a1"})
            else:
                # This is the expected case: both INSERTS happen between the
                # two SELECTS.  If we don't get this almost all of the time we
                # should adjust the sleep amounts.
                self.assertEqual(names1, set())
                self.assertEqual(names2, {"a1", "a2"})

        asyncio.run(testProblemsWithNoLocking())
        # Clean up after first test.
        db1.delete(tables1.a, ["name"], {"name": "a1"}, {"name": "a2"})

        async def testSolutionWithLocking() -> None:
            """Run side1 and side2 with locking, which should make side2 block
            its insert until side1 releases its lock.
            """
            task1 = asyncio.create_task(side1(lock=[tables1.a]))
            task2 = asyncio.create_task(side2())
            names1, names2 = await task1
            await task2
            if "a2" in names1:
                warnings.warn(
                    "Unlucky scheduling in locking test: concurrent INSERT happened before first SELECT."
                )
                self.assertEqual(names1, {"a2"})
                self.assertEqual(names2, {"a1", "a2"})
            else:
                # This is the expected case: the side2 INSERT happens after the
                # last SELECT on side1.  This can also happen due to unlucky
                # scheduling, and we have no way to detect that here, but the
                # similar "no-locking" test has at least some chance of being
                # affected by the same problem and warning about it.
                self.assertEqual(names1, set())
                self.assertEqual(names2, {"a1"})

        asyncio.run(testSolutionWithLocking())
    def testTimespanDatabaseRepresentation(self):
        """Tests for `TimespanDatabaseRepresentation` and the `Database`
        methods that interact with it.

        Exercises round-tripping timespans through the database, NULL and
        server-side-default handling, exclusion constraints (where supported),
        and SQL translations of the timespan relationship operators against
        their pure-Python equivalents.
        """
        # Make some test timespans to play with, with the full suite of
        # topological relationships.
        start = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
        offset = astropy.time.TimeDelta(60, format="sec")
        timestamps = [start + offset * n for n in range(3)]
        aTimespans = [Timespan(begin=None, end=None)]
        aTimespans.extend(Timespan(begin=None, end=t) for t in timestamps)
        aTimespans.extend(Timespan(begin=t, end=None) for t in timestamps)
        aTimespans.extend(Timespan.fromInstant(t) for t in timestamps)
        aTimespans.append(Timespan.makeEmpty())
        aTimespans.extend(Timespan(begin=t1, end=t2) for t1, t2 in itertools.combinations(timestamps, 2))
        # Make another list of timespans that span the full range but don't
        # overlap.  This is a subset of the previous list.
        bTimespans = [Timespan(begin=None, end=timestamps[0])]
        bTimespans.extend(Timespan(begin=t1, end=t2) for t1, t2 in zip(timestamps[:-1], timestamps[1:]))
        bTimespans.append(Timespan(begin=timestamps[-1], end=None))
        # Make a database and create a table with that database's timespan
        # representation.  This one will have no exclusion constraint and
        # a nullable timespan.
        db = self.makeEmptyDatabase(origin=1)
        TimespanReprClass = db.getTimespanRepresentation()
        aSpec = ddl.TableSpec(
            fields=[
                ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
            ],
        )
        for fieldSpec in TimespanReprClass.makeFieldSpecs(nullable=True):
            aSpec.fields.add(fieldSpec)
        with db.declareStaticTables(create=True) as context:
            aTable = context.addTable("a", aSpec)
        self.maxDiff = None

        def convertRowForInsert(row: dict) -> dict:
            """Convert a row containing a Timespan instance into one suitable
            for insertion into the database.

            The Timespan value (keyed by ``TimespanReprClass.NAME``) is
            replaced by the representation-specific database column values.
            """
            result = row.copy()
            ts = result.pop(TimespanReprClass.NAME)
            return TimespanReprClass.update(ts, result=result)

        def convertRowFromSelect(row: dict) -> dict:
            """Convert a row from the database into one containing a Timespan.

            Parameters
            ----------
            row : `dict`
                Original row.

            Returns
            -------
            row : `dict`
                The updated row, with the representation-specific columns
                collapsed back into a single Timespan entry.
            """
            result = row.copy()
            timespan = TimespanReprClass.extract(result)
            for name in TimespanReprClass.getFieldNames():
                del result[name]
            result[TimespanReprClass.NAME] = timespan
            return result

        # Insert rows into table A, in chunks just to make things interesting.
        # Include one with a NULL timespan.
        aRows = [{"id": n, TimespanReprClass.NAME: t} for n, t in enumerate(aTimespans)]
        aRows.append({"id": len(aRows), TimespanReprClass.NAME: None})
        db.insert(aTable, convertRowForInsert(aRows[0]))
        db.insert(aTable, *[convertRowForInsert(r) for r in aRows[1:3]])
        db.insert(aTable, *[convertRowForInsert(r) for r in aRows[3:]])
        # Add another one with a NULL timespan, but this time by invoking
        # the server-side default.
        aRows.append({"id": len(aRows)})
        db.insert(aTable, aRows[-1])
        aRows[-1][TimespanReprClass.NAME] = None
        # Test basic round-trip through database.
        self.assertEqual(
            aRows,
            [
                convertRowFromSelect(row._asdict())
                for row in db.query(aTable.select().order_by(aTable.columns.id))
            ],
        )
        # Create another table B with a not-null timespan and (if the database
        # supports it), an exclusion constraint.  Use ensureTableExists this
        # time to check that mode of table creation vs. timespans.
        bSpec = ddl.TableSpec(
            fields=[
                ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
                ddl.FieldSpec(name="key", dtype=sqlalchemy.Integer, nullable=False),
            ],
        )
        for fieldSpec in TimespanReprClass.makeFieldSpecs(nullable=False):
            bSpec.fields.add(fieldSpec)
        if TimespanReprClass.hasExclusionConstraint():
            bSpec.exclusion.add(("key", TimespanReprClass))
        bTable = db.ensureTableExists("b", bSpec)
        # Insert rows into table B, again in chunks.  Each Timespan appears
        # twice, but with different values for the 'key' field (which should
        # still be okay for any exclusion constraint we may have defined).
        bRows = [{"id": n, "key": 1, TimespanReprClass.NAME: t} for n, t in enumerate(bTimespans)]
        offset = len(bRows)
        bRows.extend(
            {"id": n + offset, "key": 2, TimespanReprClass.NAME: t} for n, t in enumerate(bTimespans)
        )
        db.insert(bTable, *[convertRowForInsert(r) for r in bRows[:2]])
        db.insert(bTable, convertRowForInsert(bRows[2]))
        db.insert(bTable, *[convertRowForInsert(r) for r in bRows[3:]])
        # Insert a row with no timespan into table B.  This should invoke the
        # server-side default, which is a timespan over (-∞, ∞).  We set
        # key=3 to avoid upsetting an exclusion constraint that might exist.
        bRows.append({"id": len(bRows), "key": 3})
        db.insert(bTable, bRows[-1])
        bRows[-1][TimespanReprClass.NAME] = Timespan(None, None)
        # Test basic round-trip through database.
        self.assertEqual(
            bRows,
            [
                convertRowFromSelect(row._asdict())
                for row in db.query(bTable.select().order_by(bTable.columns.id))
            ],
        )
        # Test that we can't insert timespan=None into this table.
        with self.assertRaises(sqlalchemy.exc.IntegrityError):
            db.insert(bTable, convertRowForInsert({"id": len(bRows), "key": 4, TimespanReprClass.NAME: None}))
        # IFF this database supports exclusion constraints, test that they
        # also prevent inserts.
        if TimespanReprClass.hasExclusionConstraint():
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                db.insert(
                    bTable,
                    convertRowForInsert(
                        {"id": len(bRows), "key": 1, TimespanReprClass.NAME: Timespan(None, timestamps[1])}
                    ),
                )
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                db.insert(
                    bTable,
                    convertRowForInsert(
                        {
                            "id": len(bRows),
                            "key": 1,
                            TimespanReprClass.NAME: Timespan(timestamps[0], timestamps[2]),
                        }
                    ),
                )
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                db.insert(
                    bTable,
                    convertRowForInsert(
                        {"id": len(bRows), "key": 1, TimespanReprClass.NAME: Timespan(timestamps[2], None)}
                    ),
                )
        # Test NULL checks in SELECT queries, on both tables.
        aRepr = TimespanReprClass.fromSelectable(aTable)
        self.assertEqual(
            [row[TimespanReprClass.NAME] is None for row in aRows],
            [
                row.f
                for row in db.query(
                    sqlalchemy.sql.select(aRepr.isNull().label("f")).order_by(aTable.columns.id)
                )
            ],
        )
        bRepr = TimespanReprClass.fromSelectable(bTable)
        self.assertEqual(
            # Table B's timespan column is NOT NULL, so isNull() must be
            # False for every row.
            [False for row in bRows],
            [
                row.f
                for row in db.query(
                    sqlalchemy.sql.select(bRepr.isNull().label("f")).order_by(bTable.columns.id)
                )
            ],
        )
        # Test relationships expressions that relate in-database timespans to
        # Python-literal timespans, all from the more complete 'a' set; check
        # that this is consistent with Python-only relationship tests.
        for rhsRow in aRows:
            if rhsRow[TimespanReprClass.NAME] is None:
                continue
            with self.subTest(rhsRow=rhsRow):
                expected = {}
                for lhsRow in aRows:
                    if lhsRow[TimespanReprClass.NAME] is None:
                        # NULL timespans compare as NULL (None) on all four
                        # operators.
                        expected[lhsRow["id"]] = (None, None, None, None)
                    else:
                        expected[lhsRow["id"]] = (
                            lhsRow[TimespanReprClass.NAME].overlaps(rhsRow[TimespanReprClass.NAME]),
                            lhsRow[TimespanReprClass.NAME].contains(rhsRow[TimespanReprClass.NAME]),
                            lhsRow[TimespanReprClass.NAME] < rhsRow[TimespanReprClass.NAME],
                            lhsRow[TimespanReprClass.NAME] > rhsRow[TimespanReprClass.NAME],
                        )
                rhsRepr = TimespanReprClass.fromLiteral(rhsRow[TimespanReprClass.NAME])
                sql = sqlalchemy.sql.select(
                    aTable.columns.id.label("lhs"),
                    aRepr.overlaps(rhsRepr).label("overlaps"),
                    aRepr.contains(rhsRepr).label("contains"),
                    (aRepr < rhsRepr).label("less_than"),
                    (aRepr > rhsRepr).label("greater_than"),
                ).select_from(aTable)
                queried = {
                    row.lhs: (row.overlaps, row.contains, row.less_than, row.greater_than)
                    for row in db.query(sql)
                }
                self.assertEqual(expected, queried)
        # Test relationship expressions that relate in-database timespans to
        # each other, all from the more complete 'a' set; check that this is
        # consistent with Python-only relationship tests.
        expected = {}
        for lhs, rhs in itertools.product(aRows, aRows):
            lhsT = lhs[TimespanReprClass.NAME]
            rhsT = rhs[TimespanReprClass.NAME]
            if lhsT is not None and rhsT is not None:
                expected[lhs["id"], rhs["id"]] = (
                    lhsT.overlaps(rhsT),
                    lhsT.contains(rhsT),
                    lhsT < rhsT,
                    lhsT > rhsT,
                )
            else:
                expected[lhs["id"], rhs["id"]] = (None, None, None, None)
        lhsSubquery = aTable.alias("lhs")
        rhsSubquery = aTable.alias("rhs")
        lhsRepr = TimespanReprClass.fromSelectable(lhsSubquery)
        rhsRepr = TimespanReprClass.fromSelectable(rhsSubquery)
        sql = sqlalchemy.sql.select(
            lhsSubquery.columns.id.label("lhs"),
            rhsSubquery.columns.id.label("rhs"),
            lhsRepr.overlaps(rhsRepr).label("overlaps"),
            lhsRepr.contains(rhsRepr).label("contains"),
            (lhsRepr < rhsRepr).label("less_than"),
            (lhsRepr > rhsRepr).label("greater_than"),
            # Cross join (ON TRUE) to compare every pair of rows.
        ).select_from(lhsSubquery.join(rhsSubquery, onclause=sqlalchemy.sql.literal(True)))
        queried = {
            (row.lhs, row.rhs): (row.overlaps, row.contains, row.less_than, row.greater_than)
            for row in db.query(sql)
        }
        self.assertEqual(expected, queried)
        # Test relationship expressions between in-database timespans and
        # Python-literal instantaneous times.
        for t in timestamps:
            with self.subTest(t=t):
                expected = {}
                for lhsRow in aRows:
                    if lhsRow[TimespanReprClass.NAME] is None:
                        expected[lhsRow["id"]] = (None, None, None)
                    else:
                        expected[lhsRow["id"]] = (
                            lhsRow[TimespanReprClass.NAME].contains(t),
                            lhsRow[TimespanReprClass.NAME] < t,
                            lhsRow[TimespanReprClass.NAME] > t,
                        )
                rhs = sqlalchemy.sql.literal(t, type_=ddl.AstropyTimeNsecTai)
                sql = sqlalchemy.sql.select(
                    aTable.columns.id.label("lhs"),
                    aRepr.contains(rhs).label("contains"),
                    (aRepr < rhs).label("less_than"),
                    (aRepr > rhs).label("greater_than"),
                ).select_from(aTable)
                queried = {row.lhs: (row.contains, row.less_than, row.greater_than) for row in db.query(sql)}
                self.assertEqual(expected, queried)