Coverage for python/lsst/daf/butler/registry/tests/_database.py: 7%
491 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-10-02 08:00 +0000
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
27from __future__ import annotations
29__all__ = ["DatabaseTests"]
import asyncio
import itertools
import warnings
from abc import ABC, abstractmethod
from collections import namedtuple
from collections.abc import Iterable, Iterator
from concurrent.futures import ThreadPoolExecutor
from contextlib import AbstractContextManager, contextmanager
from typing import Any

import astropy.time
import sqlalchemy
from lsst.sphgeom import Circle, ConvexPolygon, UnitVector3d

from ...core import Timespan, ddl
from ..interfaces import Database, DatabaseConflictError, ReadOnlyDatabaseError, SchemaAlreadyDefinedError
48StaticTablesTuple = namedtuple("StaticTablesTuple", ["a", "b", "c"])
50STATIC_TABLE_SPECS = StaticTablesTuple(
51 a=ddl.TableSpec(
52 fields=[
53 ddl.FieldSpec("name", dtype=sqlalchemy.String, length=16, primaryKey=True),
54 ddl.FieldSpec("region", dtype=ddl.Base64Region, nbytes=128),
55 ]
56 ),
57 b=ddl.TableSpec(
58 fields=[
59 ddl.FieldSpec("id", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True),
60 ddl.FieldSpec("name", dtype=sqlalchemy.String, length=16, nullable=False),
61 ddl.FieldSpec("value", dtype=sqlalchemy.SmallInteger, nullable=True),
62 ],
63 unique=[("name",)],
64 ),
65 c=ddl.TableSpec(
66 fields=[
67 ddl.FieldSpec("id", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True),
68 ddl.FieldSpec("b_id", dtype=sqlalchemy.BigInteger, nullable=True),
69 ],
70 foreignKeys=[
71 ddl.ForeignKeySpec("b", source=("b_id",), target=("id",), onDelete="SET NULL"),
72 ],
73 ),
74)
76DYNAMIC_TABLE_SPEC = ddl.TableSpec(
77 fields=[
78 ddl.FieldSpec("c_id", dtype=sqlalchemy.BigInteger, primaryKey=True),
79 ddl.FieldSpec("a_name", dtype=sqlalchemy.String, length=16, nullable=False),
80 ],
81 foreignKeys=[
82 ddl.ForeignKeySpec("c", source=("c_id",), target=("id",), onDelete="CASCADE"),
83 ddl.ForeignKeySpec("a", source=("a_name",), target=("name",), onDelete="CASCADE"),
84 ],
85)
87TEMPORARY_TABLE_SPEC = ddl.TableSpec(
88 fields=[
89 ddl.FieldSpec("a_name", dtype=sqlalchemy.String, length=16, primaryKey=True),
90 ddl.FieldSpec("b_id", dtype=sqlalchemy.BigInteger, primaryKey=True),
91 ],
92)
95@contextmanager
96def _patch_getExistingTable(db: Database) -> Database:
97 """Patch getExistingTable method in a database instance to test concurrent
98 creation of tables. This patch obviously depends on knowning internals of
99 ``ensureTableExists()`` implementation.
100 """
101 original_method = db.getExistingTable
103 def _getExistingTable(name: str, spec: ddl.TableSpec) -> sqlalchemy.schema.Table | None:
104 # Return None on first call, but forward to original method after that
105 db.getExistingTable = original_method
106 return None
108 db.getExistingTable = _getExistingTable
109 yield db
110 db.getExistingTable = original_method
113class DatabaseTests(ABC):
114 """Generic tests for the `Database` interface that can be subclassed to
115 generate tests for concrete implementations.
116 """
    @abstractmethod
    def makeEmptyDatabase(self, origin: int = 0) -> Database:
        """Return an empty `Database` with the given origin, or an
        automatically-generated one if ``origin`` is `None`.

        Parameters
        ----------
        origin : `int`, optional
            Origin value for the new database.
        """
        # NOTE(review): the summary mentions ``origin`` being `None`, but the
        # default is ``0`` and the annotation is `int` — confirm whether
        # implementations are expected to accept `None`.
        raise NotImplementedError()
    @abstractmethod
    def asReadOnly(self, database: Database) -> AbstractContextManager[Database]:
        """Return a context manager for a read-only connection into the given
        database.

        The original database should be considered unusable within the context
        but safe to use again afterwards (this allows the context manager to
        block write access by temporarily changing user permissions to really
        guarantee that write operations are not performed).

        Parameters
        ----------
        database : `Database`
            Database to open a read-only view of.

        Returns
        -------
        context : `AbstractContextManager` [ `Database` ]
            Context manager whose ``as`` target is the read-only connection.
        """
        raise NotImplementedError()
    @abstractmethod
    def getNewConnection(self, database: Database, *, writeable: bool) -> Database:
        """Return a new `Database` instance that points to the same underlying
        storage as the given one.

        Parameters
        ----------
        database : `Database`
            Existing database the new connection should share storage with.
        writeable : `bool`
            Whether the new connection permits writes.
        """
        raise NotImplementedError()
144 def query_list(
145 self, database: Database, executable: sqlalchemy.sql.expression.SelectBase
146 ) -> list[sqlalchemy.engine.Row]:
147 """Run a SELECT or other read-only query against the database and
148 return the results as a list.
150 This is a thin wrapper around database.query() that just avoids
151 context-manager boilerplate that is usefully verbose in production code
152 but just noise in tests.
153 """
154 with database.transaction(), database.query(executable) as result:
155 return result.fetchall()
157 def query_scalar(self, database: Database, executable: sqlalchemy.sql.expression.SelectBase) -> Any:
158 """Run a SELECT query that yields a single column and row against the
159 database and return its value.
161 This is a thin wrapper around database.query() that just avoids
162 context-manager boilerplate that is usefully verbose in production code
163 but just noise in tests.
164 """
165 with database.query(executable) as result:
166 return result.scalar()
168 def checkTable(self, spec: ddl.TableSpec, table: sqlalchemy.schema.Table):
169 self.assertCountEqual(spec.fields.names, table.columns.keys())
170 # Checking more than this currently seems fragile, as it might restrict
171 # what Database implementations do; we don't care if the spec is
172 # actually preserved in terms of types and constraints as long as we
173 # can use the returned table as if it was.
175 def checkStaticSchema(self, tables: StaticTablesTuple):
176 self.checkTable(STATIC_TABLE_SPECS.a, tables.a)
177 self.checkTable(STATIC_TABLE_SPECS.b, tables.b)
178 self.checkTable(STATIC_TABLE_SPECS.c, tables.c)
180 def testDeclareStaticTables(self):
181 """Tests for `Database.declareStaticSchema` and the methods it
182 delegates to.
183 """
184 # Create the static schema in a new, empty database.
185 newDatabase = self.makeEmptyDatabase()
186 with newDatabase.declareStaticTables(create=True) as context:
187 tables = context.addTableTuple(STATIC_TABLE_SPECS)
188 self.checkStaticSchema(tables)
189 # Check that we can load that schema even from a read-only connection.
190 with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
191 with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
192 tables = context.addTableTuple(STATIC_TABLE_SPECS)
193 self.checkStaticSchema(tables)
195 def testDeclareStaticTablesTwice(self):
196 """Tests for `Database.declareStaticSchema` being called twice."""
197 # Create the static schema in a new, empty database.
198 newDatabase = self.makeEmptyDatabase()
199 with newDatabase.declareStaticTables(create=True) as context:
200 tables = context.addTableTuple(STATIC_TABLE_SPECS)
201 self.checkStaticSchema(tables)
202 # Second time it should raise
203 with self.assertRaises(SchemaAlreadyDefinedError):
204 with newDatabase.declareStaticTables(create=True) as context:
205 tables = context.addTableTuple(STATIC_TABLE_SPECS)
206 # Check schema, it should still contain all tables, and maybe some
207 # extra.
208 with newDatabase.declareStaticTables(create=False) as context:
209 self.assertLessEqual(frozenset(STATIC_TABLE_SPECS._fields), context._tableNames)
211 def testRepr(self):
212 """Test that repr does not return a generic thing."""
213 newDatabase = self.makeEmptyDatabase()
214 rep = repr(newDatabase)
215 # Check that stringification works and gives us something different
216 self.assertNotEqual(rep, str(newDatabase))
217 self.assertNotIn("object at 0x", rep, "Check default repr was not used")
218 self.assertIn("://", rep)
    def testDynamicTables(self):
        """Tests for `Database.ensureTableExists` and
        `Database.getExistingTable`.
        """
        # Need to start with the static schema.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            context.addTableTuple(STATIC_TABLE_SPECS)
        # Try to ensure the dynamic table exists in a read-only version of
        # that database, which should fail because we can't create it.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                context.addTableTuple(STATIC_TABLE_SPECS)
            with self.assertRaises(ReadOnlyDatabaseError):
                existingReadOnlyDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        # Just getting the dynamic table before it exists should return None.
        self.assertIsNone(newDatabase.getExistingTable("d", DYNAMIC_TABLE_SPEC))
        # Ensure the new table exists back in the original database, which
        # should create it.
        table = newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        self.checkTable(DYNAMIC_TABLE_SPEC, table)
        # Ensuring that it exists should just return the exact same table
        # instance again.
        self.assertIs(newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC), table)
        # Try again from the read-only database.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                context.addTableTuple(STATIC_TABLE_SPECS)
            # Just getting the dynamic table should now work...
            self.assertIsNotNone(existingReadOnlyDatabase.getExistingTable("d", DYNAMIC_TABLE_SPEC))
            # ...as should ensuring that it exists, since it now does.
            existingReadOnlyDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
            self.checkTable(DYNAMIC_TABLE_SPEC, table)
        # Trying to get the table with a different specification (at least
        # in terms of what columns are present) should raise.
        with self.assertRaises(DatabaseConflictError):
            newDatabase.ensureTableExists(
                "d",
                ddl.TableSpec(
                    fields=[ddl.FieldSpec("name", dtype=sqlalchemy.String, length=4, primaryKey=True)]
                ),
            )
        # Calling ensureTableExists inside a transaction block is an error,
        # even if it would do nothing.
        with newDatabase.transaction():
            with self.assertRaises(AssertionError):
                newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
    def testDynamicTablesConcurrency(self):
        """Tests for `Database.ensureTableExists` concurrent use."""
        # We cannot really run things concurrently in a deterministic way;
        # here we just simulate a situation when the table is created by
        # another process between the call to getExistingTable() and the
        # actual table creation.
        db1 = self.makeEmptyDatabase()
        with db1.declareStaticTables(create=True) as context:
            context.addTableTuple(STATIC_TABLE_SPECS)
        self.assertIsNone(db1.getExistingTable("d", DYNAMIC_TABLE_SPEC))

        # Make the dynamic table using a separate connection, so it really
        # does exist in storage even though db1 hasn't seen it yet.
        db2 = self.getNewConnection(db1, writeable=True)
        with db2.declareStaticTables(create=False) as context:
            context.addTableTuple(STATIC_TABLE_SPECS)
        table = db2.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        self.checkTable(DYNAMIC_TABLE_SPEC, table)

        # Call it again but trick it into thinking that table is not there.
        # This test depends on knowing the implementation of
        # ensureTableExists(), which initially calls getExistingTable() to
        # check that the table may exist; the patch intercepts that call and
        # returns None.
        with _patch_getExistingTable(db1):
            table = db1.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
    def testTemporaryTables(self):
        """Tests for `Database.temporary_table`, and `Database.insert` with
        the ``select`` argument.
        """
        # Need to start with the static schema; also insert some test data.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            static = context.addTableTuple(STATIC_TABLE_SPECS)
        newDatabase.insert(static.a, {"name": "a1", "region": None}, {"name": "a2", "region": None})
        bIds = newDatabase.insert(
            static.b,
            {"name": "b1", "value": 11},
            {"name": "b2", "value": 12},
            {"name": "b3", "value": 13},
            returnIds=True,
        )
        # Create the table.
        with newDatabase.session():
            with newDatabase.temporary_table(TEMPORARY_TABLE_SPEC, "e1") as table1:
                self.checkTable(TEMPORARY_TABLE_SPEC, table1)
                # Insert via a INSERT INTO ... SELECT query.
                newDatabase.insert(
                    table1,
                    select=sqlalchemy.sql.select(
                        static.a.columns.name.label("a_name"), static.b.columns.id.label("b_id")
                    )
                    # Cross join: any "a" row with any "b" row.
                    .select_from(static.a.join(static.b, onclause=sqlalchemy.sql.literal(True)))
                    .where(
                        sqlalchemy.sql.and_(
                            static.a.columns.name == "a1",
                            static.b.columns.value <= 12,
                        )
                    ),
                )
                # Check that the inserted rows are present.
                self.assertCountEqual(
                    [{"a_name": "a1", "b_id": bId} for bId in bIds[:2]],
                    [row._asdict() for row in self.query_list(newDatabase, table1.select())],
                )
                # Create another one via a read-only connection to the
                # database. We _do_ allow temporary table modifications in
                # read-only databases.
                with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
                    with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                        context.addTableTuple(STATIC_TABLE_SPECS)
                    with existingReadOnlyDatabase.temporary_table(TEMPORARY_TABLE_SPEC) as table2:
                        self.checkTable(TEMPORARY_TABLE_SPEC, table2)
                        # Those tables should not be the same, despite having
                        # the same ddl.
                        self.assertIsNot(table1, table2)
                        # Do a slightly different insert into this table, to
                        # check that it works in a read-only database. This
                        # time we pass column names as a kwarg to insert
                        # instead of by labeling the columns in the select.
                        existingReadOnlyDatabase.insert(
                            table2,
                            select=sqlalchemy.sql.select(static.a.columns.name, static.b.columns.id)
                            .select_from(static.a.join(static.b, onclause=sqlalchemy.sql.literal(True)))
                            .where(
                                sqlalchemy.sql.and_(
                                    static.a.columns.name == "a2",
                                    static.b.columns.value >= 12,
                                )
                            ),
                            names=["a_name", "b_id"],
                        )
                        # Check that the inserted rows are present.
                        self.assertCountEqual(
                            [{"a_name": "a2", "b_id": bId} for bId in bIds[1:]],
                            [row._asdict() for row in self.query_list(existingReadOnlyDatabase, table2.select())],
                        )
                # Exiting the context managers will drop the temporary tables
                # from the read-only DB. It's unspecified whether attempting
                # to use it after this point is an error or just never
                # returns any results, so we can't test what it does, only
                # that it's not an error.
369 def testSchemaSeparation(self):
370 """Test that creating two different `Database` instances allows us
371 to create different tables with the same name in each.
372 """
373 db1 = self.makeEmptyDatabase(origin=1)
374 with db1.declareStaticTables(create=True) as context:
375 tables = context.addTableTuple(STATIC_TABLE_SPECS)
376 self.checkStaticSchema(tables)
378 db2 = self.makeEmptyDatabase(origin=2)
379 # Make the DDL here intentionally different so we'll definitely
380 # notice if db1 and db2 are pointing at the same schema.
381 spec = ddl.TableSpec(fields=[ddl.FieldSpec("id", dtype=sqlalchemy.Integer, primaryKey=True)])
382 with db2.declareStaticTables(create=True) as context:
383 # Make the DDL here intentionally different so we'll definitely
384 # notice if db1 and db2 are pointing at the same schema.
385 table = context.addTable("a", spec)
386 self.checkTable(spec, table)
    def testInsertQueryDelete(self):
        """Test the `Database.insert`, `Database.query`, and `Database.delete`
        methods, as well as the `Base64Region` type and the ``onDelete``
        argument to `ddl.ForeignKeySpec`.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert a single, non-autoincrement row that contains a region and
        # query to get it back.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row = {"name": "a1", "region": region}
        db.insert(tables.a, row)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row])
        # Insert multiple autoincrement rows but do not try to get the IDs
        # back immediately.
        db.insert(tables.b, {"name": "b1", "value": 10}, {"name": "b2", "value": 20})
        results = [r._asdict() for r in self.query_list(db, tables.b.select().order_by("id"))]
        self.assertEqual(len(results), 2)
        for row in results:
            self.assertIn(row["name"], ("b1", "b2"))
            self.assertIsInstance(row["id"], int)
        self.assertGreater(results[1]["id"], results[0]["id"])
        # Insert multiple autoincrement rows and get the IDs back from
        # insert.
        rows = [{"name": "b3", "value": 30}, {"name": "b4", "value": 40}]
        ids = db.insert(tables.b, *rows, returnIds=True)
        results = [
            r._asdict()
            for r in self.query_list(db, tables.b.select().where(tables.b.columns.id > results[1]["id"]))
        ]
        expected = [dict(row, id=id) for row, id in zip(rows, ids, strict=True)]
        self.assertCountEqual(results, expected)
        self.assertTrue(all(result["id"] is not None for result in results))
        # Insert multiple rows into a table with an autoincrement primary
        # key, then use the returned IDs to insert into a dynamic table.
        rows = [{"b_id": results[0]["id"]}, {"b_id": None}]
        ids = db.insert(tables.c, *rows, returnIds=True)
        results = [r._asdict() for r in self.query_list(db, tables.c.select())]
        expected = [dict(row, id=id) for row, id in zip(rows, ids, strict=True)]
        self.assertCountEqual(results, expected)
        self.assertTrue(all(result["id"] is not None for result in results))
        # Add the dynamic table.
        d = db.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        # Insert into it.
        rows = [{"c_id": id, "a_name": "a1"} for id in ids]
        db.insert(d, *rows)
        results = [r._asdict() for r in self.query_list(db, d.select())]
        self.assertCountEqual(rows, results)
        # Insert multiple rows into a table with an autoincrement primary
        # key, but pass in a value for the autoincrement key.
        rows2 = [
            {"id": 700, "b_id": None},
            {"id": 701, "b_id": None},
        ]
        db.insert(tables.c, *rows2)
        results = [r._asdict() for r in self.query_list(db, tables.c.select())]
        self.assertCountEqual(results, expected + rows2)
        self.assertTrue(all(result["id"] is not None for result in results))

        # Define 'SELECT COUNT(*)' query for later use.
        count = sqlalchemy.sql.select(sqlalchemy.sql.func.count())
        # Get the values we inserted into table b.
        bValues = [r._asdict() for r in self.query_list(db, tables.b.select())]
        # Remove two rows from table b by ID.
        n = db.delete(tables.b, ["id"], {"id": bValues[0]["id"]}, {"id": bValues[1]["id"]})
        self.assertEqual(n, 2)
        # Remove the other two rows from table b by name.
        n = db.delete(tables.b, ["name"], {"name": bValues[2]["name"]}, {"name": bValues[3]["name"]})
        self.assertEqual(n, 2)
        # There should now be no rows in table b.
        self.assertEqual(self.query_scalar(db, count.select_from(tables.b)), 0)
        # All b_id values in table c should now be NULL, because there's an
        # onDelete='SET NULL' foreign key.
        self.assertEqual(
            self.query_scalar(
                db,
                count.select_from(tables.c).where(tables.c.columns.b_id != None),  # noqa:E711
            ),
            0,
        )
        # Remove all rows in table a (there's only one); this should remove
        # all rows in d due to onDelete='CASCADE'.
        n = db.delete(tables.a, [])
        self.assertEqual(n, 1)
        self.assertEqual(self.query_scalar(db, count.select_from(tables.a)), 0)
        self.assertEqual(self.query_scalar(db, count.select_from(d)), 0)
475 def testDeleteWhere(self):
476 """Tests for `Database.deleteWhere`."""
477 db = self.makeEmptyDatabase(origin=1)
478 with db.declareStaticTables(create=True) as context:
479 tables = context.addTableTuple(STATIC_TABLE_SPECS)
480 db.insert(tables.b, *[{"id": i, "name": f"b{i}"} for i in range(10)])
481 count = sqlalchemy.sql.select(sqlalchemy.sql.func.count())
483 n = db.deleteWhere(tables.b, tables.b.columns.id.in_([0, 1, 2]))
484 self.assertEqual(n, 3)
485 self.assertEqual(self.query_scalar(db, count.select_from(tables.b)), 7)
487 n = db.deleteWhere(
488 tables.b,
489 tables.b.columns.id.in_(
490 sqlalchemy.sql.select(tables.b.columns.id).where(tables.b.columns.id > 5)
491 ),
492 )
493 self.assertEqual(n, 4)
494 self.assertEqual(self.query_scalar(db, count.select_from(tables.b)), 3)
496 n = db.deleteWhere(tables.b, tables.b.columns.name == "b5")
497 self.assertEqual(n, 1)
498 self.assertEqual(self.query_scalar(db, count.select_from(tables.b)), 2)
500 n = db.deleteWhere(tables.b, sqlalchemy.sql.literal(True))
501 self.assertEqual(n, 2)
502 self.assertEqual(self.query_scalar(db, count.select_from(tables.b)), 0)
504 def testUpdate(self):
505 """Tests for `Database.update`."""
506 db = self.makeEmptyDatabase(origin=1)
507 with db.declareStaticTables(create=True) as context:
508 tables = context.addTableTuple(STATIC_TABLE_SPECS)
509 # Insert two rows into table a, both without regions.
510 db.insert(tables.a, {"name": "a1"}, {"name": "a2"})
511 # Update one of the rows with a region.
512 region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
513 n = db.update(tables.a, {"name": "k"}, {"k": "a2", "region": region})
514 self.assertEqual(n, 1)
515 sql = sqlalchemy.sql.select(tables.a.columns.name, tables.a.columns.region).select_from(tables.a)
516 self.assertCountEqual(
517 [r._asdict() for r in self.query_list(db, sql)],
518 [{"name": "a1", "region": None}, {"name": "a2", "region": region}],
519 )
    def testSync(self):
        """Tests for `Database.sync` (get-or-insert with optional compare
        and update).
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert a row with sync, because it doesn't exist yet.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10}, returning=["id"])
        self.assertTrue(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )
        # Repeat that operation, which should do nothing but return the
        # requested values.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10}, returning=["id"])
        self.assertFalse(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )
        # Repeat the operation without the 'extra' arg, which should also
        # just return the existing row.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, returning=["id"])
        self.assertFalse(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )
        # Repeat the operation with a different value in 'extra'. That still
        # shouldn't be an error, because 'extra' is only used if we really
        # do insert. Also drop the 'returning' argument.
        _, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 20})
        self.assertFalse(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )
        # Repeat the operation with the correct value in 'compared' instead
        # of 'extra'.
        _, inserted = db.sync(tables.b, keys={"name": "b1"}, compared={"value": 10})
        self.assertFalse(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )
        # Repeat the operation with an incorrect value in 'compared'; this
        # should raise.
        with self.assertRaises(DatabaseConflictError):
            db.sync(tables.b, keys={"name": "b1"}, compared={"value": 20})
        # Try to sync in a read-only database. This should work if and only
        # if the matching row already exists.
        with self.asReadOnly(db) as rodb:
            with rodb.declareStaticTables(create=False) as context:
                tables = context.addTableTuple(STATIC_TABLE_SPECS)
            _, inserted = rodb.sync(tables.b, keys={"name": "b1"})
            self.assertFalse(inserted)
            self.assertEqual(
                [{"id": values["id"], "name": "b1", "value": 10}],
                [r._asdict() for r in self.query_list(rodb, tables.b.select())],
            )
            with self.assertRaises(ReadOnlyDatabaseError):
                rodb.sync(tables.b, keys={"name": "b2"}, extra={"value": 20})
        # Repeat the operation with a different value in 'compared' and ask
        # to update; the returned dict holds the values that were replaced.
        _, updated = db.sync(tables.b, keys={"name": "b1"}, compared={"value": 20}, update=True)
        self.assertEqual(updated, {"value": 10})
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 20}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )
    def testReplace(self):
        """Tests for `Database.replace`."""
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Use 'replace' to insert a single row that contains a region and
        # query to get it back.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row1 = {"name": "a1", "region": region}
        db.replace(tables.a, row1)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1])
        # Insert another row without a region.
        row2 = {"name": "a2", "region": None}
        db.replace(tables.a, row2)
        self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2])
        # Use replace to re-insert both of those rows again, which should do
        # nothing.
        db.replace(tables.a, row1, row2)
        self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2])
        # Replace row1 with a row with no region, while reinserting row2;
        # the existing row with the same primary key is overwritten.
        row1a = {"name": "a1", "region": None}
        db.replace(tables.a, row1a, row2)
        self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1a, row2])
        # Replace both rows, returning row1 to its original state, while
        # adding a new one. Pass them in in a different order.
        row2a = {"name": "a2", "region": region}
        row3 = {"name": "a3", "region": None}
        db.replace(tables.a, row3, row2a, row1)
        self.assertCountEqual(
            [r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2a, row3]
        )
    def testEnsure(self):
        """Tests for `Database.ensure`."""
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Use 'ensure' to insert a single row that contains a region and
        # query to get it back.  The return value is the number of rows
        # actually inserted.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row1 = {"name": "a1", "region": region}
        self.assertEqual(db.ensure(tables.a, row1), 1)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1])
        # Insert another row without a region.
        row2 = {"name": "a2", "region": None}
        self.assertEqual(db.ensure(tables.a, row2), 1)
        self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2])
        # Use ensure to re-insert both of those rows again, which should do
        # nothing.
        self.assertEqual(db.ensure(tables.a, row1, row2), 0)
        self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2])
        # Attempt to insert row1's key with no region, while reinserting
        # row2. This should also do nothing (unlike replace, ensure never
        # overwrites an existing row).
        row1a = {"name": "a1", "region": None}
        self.assertEqual(db.ensure(tables.a, row1a, row2), 0)
        self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2])
        # Attempt to insert new rows for both existing keys, this time also
        # adding a new row. Pass them in in a different order. Only the new
        # row should be added.
        row2a = {"name": "a2", "region": region}
        row3 = {"name": "a3", "region": None}
        self.assertEqual(db.ensure(tables.a, row3, row2a, row1a), 1)
        self.assertCountEqual(
            [r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2, row3]
        )
        # Add some data to a table with both a primary key and a different
        # unique constraint.
        row_b = {"id": 5, "name": "five", "value": 50}
        db.insert(tables.b, row_b)
        # Attempt ensure with primary_key_only=False and a conflict for the
        # non-PK constraint. This should do nothing.
        db.ensure(tables.b, {"id": 10, "name": "five", "value": 200})
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.b.select())], [row_b])
        # Now use primary_key_only=True with conflict in only the non-PK
        # field. This should be an integrity error and nothing should
        # change.
        with self.assertRaises(sqlalchemy.exc.IntegrityError):
            db.ensure(tables.b, {"id": 10, "name": "five", "value": 200}, primary_key_only=True)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.b.select())], [row_b])
        # With primary_key_only=True a conflict in the primary key is
        # ignored regardless of whether there is a conflict elsewhere.
        db.ensure(tables.b, {"id": 5, "name": "ten", "value": 100}, primary_key_only=True)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.b.select())], [row_b])
        db.ensure(tables.b, {"id": 5, "name": "five", "value": 100}, primary_key_only=True)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.b.select())], [row_b])
    def testTransactionNesting(self):
        """Test that transactions can be nested with the behavior in the
        presence of exceptions working as documented.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert one row so we can trigger integrity errors by trying to
        # insert a duplicate of it below.
        db.insert(tables.a, {"name": "a1"})
        # First test: error recovery via explicit savepoint=True in the
        # inner transaction.
        with db.transaction():
            # This insert should succeed, and should not be rolled back
            # because the assertRaises context should catch any exception
            # before it propagates up to the outer transaction.
            db.insert(tables.a, {"name": "a2"})
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                with db.transaction(savepoint=True):
                    # This insert should succeed, but should be rolled back.
                    db.insert(tables.a, {"name": "a4"})
                    # This insert should fail (duplicate primary key),
                    # raising an exception.
                    db.insert(tables.a, {"name": "a1"})
        self.assertCountEqual(
            [r._asdict() for r in self.query_list(db, tables.a.select())],
            [{"name": "a1", "region": None}, {"name": "a2", "region": None}],
        )
        # Second test: error recovery via implicit savepoint=True, when the
        # innermost transaction is inside a savepoint=True transaction.
        with db.transaction():
            # This insert should succeed, and should not be rolled back
            # because the assertRaises context should catch any
            # exception before it propagates up to the outer
            # transaction.
            db.insert(tables.a, {"name": "a3"})
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                with db.transaction(savepoint=True):
                    # This insert should succeed, but should be rolled back.
                    db.insert(tables.a, {"name": "a4"})
                    with db.transaction():
                        # This insert should succeed, but should be rolled
                        # back.
                        db.insert(tables.a, {"name": "a5"})
                        # This insert should fail (duplicate primary key),
                        # raising an exception.
                        db.insert(tables.a, {"name": "a1"})
        self.assertCountEqual(
            [r._asdict() for r in self.query_list(db, tables.a.select())],
            [{"name": "a1", "region": None}, {"name": "a2", "region": None}, {"name": "a3", "region": None}],
        )
    def testTransactionLocking(self):
        """Test that `Database.transaction` can be used to acquire a lock
        that prohibits concurrent writes.
        """
        db1 = self.makeEmptyDatabase(origin=1)
        with db1.declareStaticTables(create=True) as context:
            tables1 = context.addTableTuple(STATIC_TABLE_SPECS)

        async def side1(lock: Iterable[str] = ()) -> tuple[set[str], set[str]]:
            """One side of the concurrent locking test.

            This optionally locks the table (and maybe the whole database),
            does a select for its contents, inserts a new row, and then selects
            again, with some waiting in between to make sure the other side has
            a chance to _attempt_ to insert in between.  If the locking is
            enabled and works, the difference between the selects should just
            be the insert done on this thread.

            Parameters
            ----------
            lock : `~collections.abc.Iterable`
                Tables to lock, forwarded to `Database.transaction`; empty
                (the default) means no locking.

            Returns
            -------
            names : `tuple` [ `set`, `set` ]
                The sets of 'name' values seen by the first and second
                SELECT, respectively.
            """
            # Give Side2 a chance to create a connection.
            await asyncio.sleep(1.0)
            with db1.transaction(lock=lock):
                names1 = {row.name for row in self.query_list(db1, tables1.a.select())}
                # Give Side2 a chance to insert (which will be blocked if
                # we've acquired a lock).
                await asyncio.sleep(2.0)
                db1.insert(tables1.a, {"name": "a1"})
                names2 = {row.name for row in self.query_list(db1, tables1.a.select())}
            return names1, names2

        async def side2() -> None:
            """Other side of the concurrent locking test.

            This side just waits a bit and then tries to insert a row into the
            table that the other side is trying to lock.  Hopefully that
            waiting is enough to give the other side a chance to acquire the
            lock and thus make this side block until the lock is released.  If
            this side manages to do the insert before side1 acquires the lock,
            we'll just warn about not succeeding at testing the locking,
            because we can only make that unlikely, not impossible.
            """

            def toRunInThread():
                """Create new SQLite connection for use in thread.

                SQLite locking isn't asyncio-friendly unless we actually
                run it in another thread.  And SQLite gets very unhappy if
                we try to use a connection from multiple threads, so we have
                to create the new connection here instead of out in the main
                body of the test function.
                """
                db2 = self.getNewConnection(db1, writeable=True)
                with db2.declareStaticTables(create=False) as context:
                    tables2 = context.addTableTuple(STATIC_TABLE_SPECS)
                with db2.transaction():
                    db2.insert(tables2.a, {"name": "a2"})

            # Wait long enough for side1 to (usually) take its lock first.
            await asyncio.sleep(2.0)
            loop = asyncio.get_running_loop()
            with ThreadPoolExecutor() as pool:
                await loop.run_in_executor(pool, toRunInThread)

        async def testProblemsWithNoLocking() -> None:
            """Run side1 and side2 with no locking, attempting to demonstrate
            the problem that locking is supposed to solve.  If we get unlucky
            with scheduling, side2 will just happen to insert after side1 is
            done, and we won't have anything definitive.  We just warn in that
            case because we really don't want spurious test failures.
            """
            task1 = asyncio.create_task(side1())
            task2 = asyncio.create_task(side2())
            names1, names2 = await task1
            await task2
            if "a2" in names1:
                warnings.warn(
                    "Unlucky scheduling in no-locking test: concurrent INSERT happened before first SELECT.",
                    stacklevel=1,
                )
                self.assertEqual(names1, {"a2"})
                self.assertEqual(names2, {"a1", "a2"})
            elif "a2" not in names2:
                warnings.warn(
                    "Unlucky scheduling in no-locking test: concurrent INSERT "
                    "happened after second SELECT even without locking.",
                    stacklevel=1,
                )
                self.assertEqual(names1, set())
                self.assertEqual(names2, {"a1"})
            else:
                # This is the expected case: both INSERTS happen between the
                # two SELECTS.  If we don't get this almost all of the time we
                # should adjust the sleep amounts.
                self.assertEqual(names1, set())
                self.assertEqual(names2, {"a1", "a2"})

        asyncio.run(testProblemsWithNoLocking())

        # Clean up after first test.
        db1.delete(tables1.a, ["name"], {"name": "a1"}, {"name": "a2"})

        async def testSolutionWithLocking() -> None:
            """Run side1 and side2 with locking, which should make side2 block
            its insert until side1 releases its lock.
            """
            task1 = asyncio.create_task(side1(lock=[tables1.a]))
            task2 = asyncio.create_task(side2())
            names1, names2 = await task1
            await task2
            if "a2" in names1:
                warnings.warn(
                    "Unlucky scheduling in locking test: concurrent INSERT happened before first SELECT.",
                    stacklevel=1,
                )
                self.assertEqual(names1, {"a2"})
                self.assertEqual(names2, {"a1", "a2"})
            else:
                # This is the expected case: the side2 INSERT happens after the
                # last SELECT on side1.  This can also happen due to unlucky
                # scheduling, and we have no way to detect that here, but the
                # similar "no-locking" test has at least some chance of being
                # affected by the same problem and warning about it.
                self.assertEqual(names1, set())
                self.assertEqual(names2, {"a1"})

        asyncio.run(testSolutionWithLocking())
    def testTimespanDatabaseRepresentation(self):
        """Tests for `TimespanDatabaseRepresentation` and the `Database`
        methods that interact with it.
        """
        # Make some test timespans to play with, with the full suite of
        # topological relationships.
        start = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
        offset = astropy.time.TimeDelta(60, format="sec")
        timestamps = [start + offset * n for n in range(3)]
        # 'a' set: unbounded, half-bounded, instantaneous, empty, and finite
        # timespans, covering all interesting relationship cases.
        aTimespans = [Timespan(begin=None, end=None)]
        aTimespans.extend(Timespan(begin=None, end=t) for t in timestamps)
        aTimespans.extend(Timespan(begin=t, end=None) for t in timestamps)
        aTimespans.extend(Timespan.fromInstant(t) for t in timestamps)
        aTimespans.append(Timespan.makeEmpty())
        aTimespans.extend(Timespan(begin=t1, end=t2) for t1, t2 in itertools.combinations(timestamps, 2))
        # Make another list of timespans that span the full range but don't
        # overlap.  This is a subset of the previous list.
        bTimespans = [Timespan(begin=None, end=timestamps[0])]
        bTimespans.extend(
            Timespan(begin=t1, end=t2) for t1, t2 in zip(timestamps[:-1], timestamps[1:], strict=True)
        )
        bTimespans.append(Timespan(begin=timestamps[-1], end=None))
        # Make a database and create a table with that database's timespan
        # representation.  This one will have no exclusion constraint and
        # a nullable timespan.
        db = self.makeEmptyDatabase(origin=1)
        TimespanReprClass = db.getTimespanRepresentation()
        aSpec = ddl.TableSpec(
            fields=[
                ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
            ],
        )
        for fieldSpec in TimespanReprClass.makeFieldSpecs(nullable=True):
            aSpec.fields.add(fieldSpec)
        with db.declareStaticTables(create=True) as context:
            aTable = context.addTable("a", aSpec)
        # Show full diffs on failure; these comparisons involve long lists.
        self.maxDiff = None

        def convertRowForInsert(row: dict) -> dict:
            """Convert a row containing a Timespan instance into one suitable
            for insertion into the database.

            Parameters
            ----------
            row : `dict`
                Original row, with a `Timespan` (or `None`) stored under the
                representation's logical field name.

            Returns
            -------
            row : `dict`
                New row with the timespan expanded into the representation's
                actual database columns.
            """
            result = row.copy()
            ts = result.pop(TimespanReprClass.NAME)
            return TimespanReprClass.update(ts, result=result)

        def convertRowFromSelect(row: dict) -> dict:
            """Convert a row from the database into one containing a Timespan.

            Parameters
            ----------
            row : `dict`
                Original row.

            Returns
            -------
            row : `dict`
                The updated row.
            """
            result = row.copy()
            timespan = TimespanReprClass.extract(result)
            for name in TimespanReprClass.getFieldNames():
                del result[name]
            result[TimespanReprClass.NAME] = timespan
            return result

        # Insert rows into table A, in chunks just to make things interesting.
        # Include one with a NULL timespan.
        aRows = [{"id": n, TimespanReprClass.NAME: t} for n, t in enumerate(aTimespans)]
        aRows.append({"id": len(aRows), TimespanReprClass.NAME: None})
        db.insert(aTable, convertRowForInsert(aRows[0]))
        db.insert(aTable, *[convertRowForInsert(r) for r in aRows[1:3]])
        db.insert(aTable, *[convertRowForInsert(r) for r in aRows[3:]])
        # Add another one with a NULL timespan, but this time by invoking
        # the server-side default.
        aRows.append({"id": len(aRows)})
        db.insert(aTable, aRows[-1])
        aRows[-1][TimespanReprClass.NAME] = None
        # Test basic round-trip through database.
        self.assertEqual(
            aRows,
            [
                convertRowFromSelect(row._asdict())
                for row in self.query_list(db, aTable.select().order_by(aTable.columns.id))
            ],
        )
        # Create another table B with a not-null timespan and (if the database
        # supports it), an exclusion constraint.  Use ensureTableExists this
        # time to check that mode of table creation vs. timespans.
        bSpec = ddl.TableSpec(
            fields=[
                ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
                ddl.FieldSpec(name="key", dtype=sqlalchemy.Integer, nullable=False),
            ],
        )
        for fieldSpec in TimespanReprClass.makeFieldSpecs(nullable=False):
            bSpec.fields.add(fieldSpec)
        if TimespanReprClass.hasExclusionConstraint():
            bSpec.exclusion.add(("key", TimespanReprClass))
        bTable = db.ensureTableExists("b", bSpec)
        # Insert rows into table B, again in chunks.  Each Timespan appears
        # twice, but with different values for the 'key' field (which should
        # still be okay for any exclusion constraint we may have defined).
        bRows = [{"id": n, "key": 1, TimespanReprClass.NAME: t} for n, t in enumerate(bTimespans)]
        offset = len(bRows)
        bRows.extend(
            {"id": n + offset, "key": 2, TimespanReprClass.NAME: t} for n, t in enumerate(bTimespans)
        )
        db.insert(bTable, *[convertRowForInsert(r) for r in bRows[:2]])
        db.insert(bTable, convertRowForInsert(bRows[2]))
        db.insert(bTable, *[convertRowForInsert(r) for r in bRows[3:]])
        # Insert a row with no timespan into table B.  This should invoke the
        # server-side default, which is a timespan over (-∞, ∞).  We set
        # key=3 to avoid upsetting an exclusion constraint that might exist.
        bRows.append({"id": len(bRows), "key": 3})
        db.insert(bTable, bRows[-1])
        bRows[-1][TimespanReprClass.NAME] = Timespan(None, None)
        # Test basic round-trip through database.
        self.assertEqual(
            bRows,
            [
                convertRowFromSelect(row._asdict())
                for row in self.query_list(db, bTable.select().order_by(bTable.columns.id))
            ],
        )
        # Test that we can't insert timespan=None into this table.
        with self.assertRaises(sqlalchemy.exc.IntegrityError):
            db.insert(bTable, convertRowForInsert({"id": len(bRows), "key": 4, TimespanReprClass.NAME: None}))
        # IFF this database supports exclusion constraints, test that they
        # also prevent inserts.
        if TimespanReprClass.hasExclusionConstraint():
            # Overlaps an existing key=1 row from the left (half-bounded).
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                db.insert(
                    bTable,
                    convertRowForInsert(
                        {"id": len(bRows), "key": 1, TimespanReprClass.NAME: Timespan(None, timestamps[1])}
                    ),
                )
            # Overlaps existing key=1 rows in the middle (finite).
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                db.insert(
                    bTable,
                    convertRowForInsert(
                        {
                            "id": len(bRows),
                            "key": 1,
                            TimespanReprClass.NAME: Timespan(timestamps[0], timestamps[2]),
                        }
                    ),
                )
            # Overlaps an existing key=1 row from the right (half-bounded).
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                db.insert(
                    bTable,
                    convertRowForInsert(
                        {"id": len(bRows), "key": 1, TimespanReprClass.NAME: Timespan(timestamps[2], None)}
                    ),
                )
        # Test NULL checks in SELECT queries, on both tables.
        aRepr = TimespanReprClass.from_columns(aTable.columns)
        self.assertEqual(
            [row[TimespanReprClass.NAME] is None for row in aRows],
            [
                row.f
                for row in self.query_list(
                    db, sqlalchemy.sql.select(aRepr.isNull().label("f")).order_by(aTable.columns.id)
                )
            ],
        )
        bRepr = TimespanReprClass.from_columns(bTable.columns)
        self.assertEqual(
            [False for row in bRows],
            [
                row.f
                for row in self.query_list(
                    db, sqlalchemy.sql.select(bRepr.isNull().label("f")).order_by(bTable.columns.id)
                )
            ],
        )
        # Test relationships expressions that relate in-database timespans to
        # Python-literal timespans, all from the more complete 'a' set; check
        # that this is consistent with Python-only relationship tests.
        for rhsRow in aRows:
            if rhsRow[TimespanReprClass.NAME] is None:
                continue
            with self.subTest(rhsRow=rhsRow):
                # Expected relationships computed entirely in Python.
                expected = {}
                for lhsRow in aRows:
                    if lhsRow[TimespanReprClass.NAME] is None:
                        expected[lhsRow["id"]] = (None, None, None, None)
                    else:
                        expected[lhsRow["id"]] = (
                            lhsRow[TimespanReprClass.NAME].overlaps(rhsRow[TimespanReprClass.NAME]),
                            lhsRow[TimespanReprClass.NAME].contains(rhsRow[TimespanReprClass.NAME]),
                            lhsRow[TimespanReprClass.NAME] < rhsRow[TimespanReprClass.NAME],
                            lhsRow[TimespanReprClass.NAME] > rhsRow[TimespanReprClass.NAME],
                        )
                rhsRepr = TimespanReprClass.fromLiteral(rhsRow[TimespanReprClass.NAME])
                sql = sqlalchemy.sql.select(
                    aTable.columns.id.label("lhs"),
                    aRepr.overlaps(rhsRepr).label("overlaps"),
                    aRepr.contains(rhsRepr).label("contains"),
                    (aRepr < rhsRepr).label("less_than"),
                    (aRepr > rhsRepr).label("greater_than"),
                ).select_from(aTable)
                queried = {
                    row.lhs: (row.overlaps, row.contains, row.less_than, row.greater_than)
                    for row in self.query_list(db, sql)
                }
                self.assertEqual(expected, queried)
        # Test relationship expressions that relate in-database timespans to
        # each other, all from the more complete 'a' set; check that this is
        # consistent with Python-only relationship tests.
        expected = {}
        for lhs, rhs in itertools.product(aRows, aRows):
            lhsT = lhs[TimespanReprClass.NAME]
            rhsT = rhs[TimespanReprClass.NAME]
            if lhsT is not None and rhsT is not None:
                expected[lhs["id"], rhs["id"]] = (
                    lhsT.overlaps(rhsT),
                    lhsT.contains(rhsT),
                    lhsT < rhsT,
                    lhsT > rhsT,
                )
            else:
                expected[lhs["id"], rhs["id"]] = (None, None, None, None)
        # Cross-join the table with itself to compare every pair in SQL.
        lhsSubquery = aTable.alias("lhs")
        rhsSubquery = aTable.alias("rhs")
        lhsRepr = TimespanReprClass.from_columns(lhsSubquery.columns)
        rhsRepr = TimespanReprClass.from_columns(rhsSubquery.columns)
        sql = sqlalchemy.sql.select(
            lhsSubquery.columns.id.label("lhs"),
            rhsSubquery.columns.id.label("rhs"),
            lhsRepr.overlaps(rhsRepr).label("overlaps"),
            lhsRepr.contains(rhsRepr).label("contains"),
            (lhsRepr < rhsRepr).label("less_than"),
            (lhsRepr > rhsRepr).label("greater_than"),
        ).select_from(lhsSubquery.join(rhsSubquery, onclause=sqlalchemy.sql.literal(True)))
        queried = {
            (row.lhs, row.rhs): (row.overlaps, row.contains, row.less_than, row.greater_than)
            for row in self.query_list(db, sql)
        }
        self.assertEqual(expected, queried)
        # Test relationship expressions between in-database timespans and
        # Python-literal instantaneous times.
        for t in timestamps:
            with self.subTest(t=t):
                expected = {}
                for lhsRow in aRows:
                    if lhsRow[TimespanReprClass.NAME] is None:
                        expected[lhsRow["id"]] = (None, None, None, None)
                    else:
                        expected[lhsRow["id"]] = (
                            lhsRow[TimespanReprClass.NAME].contains(t),
                            lhsRow[TimespanReprClass.NAME].overlaps(t),
                            lhsRow[TimespanReprClass.NAME] < t,
                            lhsRow[TimespanReprClass.NAME] > t,
                        )
                rhs = sqlalchemy.sql.literal(t, type_=ddl.AstropyTimeNsecTai)
                sql = sqlalchemy.sql.select(
                    aTable.columns.id.label("lhs"),
                    aRepr.contains(rhs).label("contains"),
                    aRepr.overlaps(rhs).label("overlaps_point"),
                    (aRepr < rhs).label("less_than"),
                    (aRepr > rhs).label("greater_than"),
                ).select_from(aTable)
                queried = {
                    row.lhs: (row.contains, row.overlaps_point, row.less_than, row.greater_than)
                    for row in self.query_list(db, sql)
                }
                self.assertEqual(expected, queried)
1126 def testConstantRows(self):
1127 """Test Database.constant_rows."""
1128 new_db = self.makeEmptyDatabase()
1129 with new_db.declareStaticTables(create=True) as context:
1130 static = context.addTableTuple(STATIC_TABLE_SPECS)
1131 b_ids = new_db.insert(
1132 static.b,
1133 {"name": "b1", "value": 11},
1134 {"name": "b2", "value": 12},
1135 {"name": "b3", "value": 13},
1136 returnIds=True,
1137 )
1138 values_spec = ddl.TableSpec(
1139 [
1140 ddl.FieldSpec(name="b", dtype=sqlalchemy.BigInteger),
1141 ddl.FieldSpec(name="s", dtype=sqlalchemy.String(8)),
1142 ddl.FieldSpec(name="r", dtype=ddl.Base64Region()),
1143 ],
1144 )
1145 values_data = [
1146 {"b": b_ids[0], "s": "b1", "r": None},
1147 {"b": b_ids[2], "s": "b3", "r": Circle.empty()},
1148 ]
1149 values = new_db.constant_rows(values_spec.fields, *values_data)
1150 select_values_alone = sqlalchemy.sql.select(
1151 values.columns["b"], values.columns["s"], values.columns["r"]
1152 )
1153 self.assertCountEqual(
1154 [row._mapping for row in self.query_list(new_db, select_values_alone)],
1155 values_data,
1156 )
1157 select_values_joined = sqlalchemy.sql.select(
1158 values.columns["s"].label("name"), static.b.columns["value"].label("value")
1159 ).select_from(values.join(static.b, onclause=static.b.columns["id"] == values.columns["b"]))
1160 self.assertCountEqual(
1161 [row._mapping for row in self.query_list(new_db, select_values_joined)],
1162 [{"value": 11, "name": "b1"}, {"value": 13, "name": "b3"}],
1163 )