Coverage for python/lsst/daf/butler/registry/tests/_database.py: 8%
492 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-01 11:00 +0000
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
27from __future__ import annotations
__all__ = ["DatabaseTests"]

import asyncio
import itertools
import warnings
from abc import ABC, abstractmethod
from collections import namedtuple
from collections.abc import Iterable, Iterator
from concurrent.futures import ThreadPoolExecutor
from contextlib import AbstractContextManager, contextmanager
from typing import Any

import astropy.time
import sqlalchemy
from lsst.sphgeom import Circle, ConvexPolygon, UnitVector3d

from ... import ddl
from ..._timespan import Timespan
from ..interfaces import Database, DatabaseConflictError, ReadOnlyDatabaseError, SchemaAlreadyDefinedError
# Named tuple holding the three tables of the "static" test schema; the same
# shape is used both for the specs below and for the SQLAlchemy table objects
# returned by ``context.addTableTuple``.
StaticTablesTuple = namedtuple("StaticTablesTuple", ["a", "b", "c"])

# Static test schema:
#   a: string primary key plus a nullable Base64-encoded region column.
#   b: autoincrement integer primary key, a non-null name with a separate
#      unique constraint, and a nullable small-integer value.
#   c: autoincrement integer primary key and a nullable foreign key to b
#      with ON DELETE SET NULL.
STATIC_TABLE_SPECS = StaticTablesTuple(
    a=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("name", dtype=sqlalchemy.String, length=16, primaryKey=True),
            ddl.FieldSpec("region", dtype=ddl.Base64Region, nbytes=128),
        ]
    ),
    b=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("id", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True),
            ddl.FieldSpec("name", dtype=sqlalchemy.String, length=16, nullable=False),
            ddl.FieldSpec("value", dtype=sqlalchemy.SmallInteger, nullable=True),
        ],
        unique=[("name",)],
    ),
    c=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("id", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True),
            ddl.FieldSpec("b_id", dtype=sqlalchemy.BigInteger, nullable=True),
        ],
        foreignKeys=[
            ddl.ForeignKeySpec("b", source=("b_id",), target=("id",), onDelete="SET NULL"),
        ],
    ),
)

# Spec for a table created after the static schema via ensureTableExists();
# both foreign keys cascade deletes from their target tables.
DYNAMIC_TABLE_SPEC = ddl.TableSpec(
    fields=[
        ddl.FieldSpec("c_id", dtype=sqlalchemy.BigInteger, primaryKey=True),
        ddl.FieldSpec("a_name", dtype=sqlalchemy.String, length=16, nullable=False),
    ],
    foreignKeys=[
        ddl.ForeignKeySpec("c", source=("c_id",), target=("id",), onDelete="CASCADE"),
        ddl.ForeignKeySpec("a", source=("a_name",), target=("name",), onDelete="CASCADE"),
    ],
)

# Spec for session-scoped temporary tables; a compound primary key and
# (deliberately) no foreign keys, since temporary tables may be created on
# read-only connections.
TEMPORARY_TABLE_SPEC = ddl.TableSpec(
    fields=[
        ddl.FieldSpec("a_name", dtype=sqlalchemy.String, length=16, primaryKey=True),
        ddl.FieldSpec("b_id", dtype=sqlalchemy.BigInteger, primaryKey=True),
    ],
)
97@contextmanager
98def _patch_getExistingTable(db: Database) -> Database:
99 """Patch getExistingTable method in a database instance to test concurrent
100 creation of tables. This patch obviously depends on knowning internals of
101 ``ensureTableExists()`` implementation.
102 """
103 original_method = db.getExistingTable
105 def _getExistingTable(name: str, spec: ddl.TableSpec) -> sqlalchemy.schema.Table | None:
106 # Return None on first call, but forward to original method after that
107 db.getExistingTable = original_method
108 return None
110 db.getExistingTable = _getExistingTable
111 yield db
112 db.getExistingTable = original_method
115class DatabaseTests(ABC):
116 """Generic tests for the `Database` interface that can be subclassed to
117 generate tests for concrete implementations.
118 """
    @abstractmethod
    def makeEmptyDatabase(self, origin: int = 0) -> Database:
        """Return an empty `Database` with the given origin, or an
        automatically-generated one if ``origin`` is `None`.

        Parameters
        ----------
        origin : `int`, optional
            Origin value for the new database.

        Returns
        -------
        database : `Database`
            A newly-created database with no tables.
        """
        raise NotImplementedError()
    @abstractmethod
    def asReadOnly(self, database: Database) -> AbstractContextManager[Database]:
        """Return a context manager for a read-only connection into the given
        database.

        The original database should be considered unusable within the context
        but safe to use again afterwards (this allows the context manager to
        block write access by temporarily changing user permissions to really
        guarantee that write operations are not performed).

        Parameters
        ----------
        database : `Database`
            Database to be accessed read-only.

        Returns
        -------
        context : `AbstractContextManager` [ `Database` ]
            Context manager whose value is the read-only connection.
        """
        raise NotImplementedError()
    @abstractmethod
    def getNewConnection(self, database: Database, *, writeable: bool) -> Database:
        """Return a new `Database` instance that points to the same underlying
        storage as the given one.

        Parameters
        ----------
        database : `Database`
            Existing database whose storage the new connection should share.
        writeable : `bool`
            Whether the new connection should permit writes.

        Returns
        -------
        database : `Database`
            New connection to the same storage.
        """
        raise NotImplementedError()
146 def query_list(
147 self, database: Database, executable: sqlalchemy.sql.expression.SelectBase
148 ) -> list[sqlalchemy.engine.Row]:
149 """Run a SELECT or other read-only query against the database and
150 return the results as a list.
152 This is a thin wrapper around database.query() that just avoids
153 context-manager boilerplate that is usefully verbose in production code
154 but just noise in tests.
155 """
156 with database.transaction(), database.query(executable) as result:
157 return result.fetchall()
159 def query_scalar(self, database: Database, executable: sqlalchemy.sql.expression.SelectBase) -> Any:
160 """Run a SELECT query that yields a single column and row against the
161 database and return its value.
163 This is a thin wrapper around database.query() that just avoids
164 context-manager boilerplate that is usefully verbose in production code
165 but just noise in tests.
166 """
167 with database.query(executable) as result:
168 return result.scalar()
170 def checkTable(self, spec: ddl.TableSpec, table: sqlalchemy.schema.Table):
171 self.assertCountEqual(spec.fields.names, table.columns.keys())
172 # Checking more than this currently seems fragile, as it might restrict
173 # what Database implementations do; we don't care if the spec is
174 # actually preserved in terms of types and constraints as long as we
175 # can use the returned table as if it was.
177 def checkStaticSchema(self, tables: StaticTablesTuple):
178 self.checkTable(STATIC_TABLE_SPECS.a, tables.a)
179 self.checkTable(STATIC_TABLE_SPECS.b, tables.b)
180 self.checkTable(STATIC_TABLE_SPECS.c, tables.c)
182 def testDeclareStaticTables(self):
183 """Tests for `Database.declareStaticSchema` and the methods it
184 delegates to.
185 """
186 # Create the static schema in a new, empty database.
187 newDatabase = self.makeEmptyDatabase()
188 with newDatabase.declareStaticTables(create=True) as context:
189 tables = context.addTableTuple(STATIC_TABLE_SPECS)
190 self.checkStaticSchema(tables)
191 # Check that we can load that schema even from a read-only connection.
192 with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
193 with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
194 tables = context.addTableTuple(STATIC_TABLE_SPECS)
195 self.checkStaticSchema(tables)
197 def testDeclareStaticTablesTwice(self):
198 """Tests for `Database.declareStaticSchema` being called twice."""
199 # Create the static schema in a new, empty database.
200 newDatabase = self.makeEmptyDatabase()
201 with newDatabase.declareStaticTables(create=True) as context:
202 tables = context.addTableTuple(STATIC_TABLE_SPECS)
203 self.checkStaticSchema(tables)
204 # Second time it should raise
205 with self.assertRaises(SchemaAlreadyDefinedError):
206 with newDatabase.declareStaticTables(create=True) as context:
207 tables = context.addTableTuple(STATIC_TABLE_SPECS)
208 # Check schema, it should still contain all tables, and maybe some
209 # extra.
210 with newDatabase.declareStaticTables(create=False) as context:
211 self.assertLessEqual(frozenset(STATIC_TABLE_SPECS._fields), context._tableNames)
213 def testRepr(self):
214 """Test that repr does not return a generic thing."""
215 newDatabase = self.makeEmptyDatabase()
216 rep = repr(newDatabase)
217 # Check that stringification works and gives us something different
218 self.assertNotEqual(rep, str(newDatabase))
219 self.assertNotIn("object at 0x", rep, "Check default repr was not used")
220 self.assertIn("://", rep)
    def testDynamicTables(self):
        """Tests for `Database.ensureTableExists` and
        `Database.getExistingTable`.
        """
        # Need to start with the static schema.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            context.addTableTuple(STATIC_TABLE_SPECS)
        # Try to ensure the dynamic table exists in a read-only version of
        # that database, which should fail because we can't create it.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                context.addTableTuple(STATIC_TABLE_SPECS)
            with self.assertRaises(ReadOnlyDatabaseError):
                existingReadOnlyDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        # Just getting the dynamic table before it exists should return None.
        self.assertIsNone(newDatabase.getExistingTable("d", DYNAMIC_TABLE_SPEC))
        # Ensure the new table exists back in the original database, which
        # should create it.
        table = newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        self.checkTable(DYNAMIC_TABLE_SPEC, table)
        # Ensuring that it exists should just return the exact same table
        # instance again.
        self.assertIs(newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC), table)
        # Try again from the read-only database.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                context.addTableTuple(STATIC_TABLE_SPECS)
            # Just getting the dynamic table should now work...
            self.assertIsNotNone(existingReadOnlyDatabase.getExistingTable("d", DYNAMIC_TABLE_SPEC))
            # ...as should ensuring that it exists, since it now does.
            existingReadOnlyDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
            self.checkTable(DYNAMIC_TABLE_SPEC, table)
        # Trying to get the table with a different specification (at least
        # in terms of what columns are present) should raise.
        with self.assertRaises(DatabaseConflictError):
            newDatabase.ensureTableExists(
                "d",
                ddl.TableSpec(
                    fields=[ddl.FieldSpec("name", dtype=sqlalchemy.String, length=4, primaryKey=True)]
                ),
            )
        # Calling ensureTableExists inside a transaction block is an error,
        # even if it would do nothing.
        with newDatabase.transaction():
            with self.assertRaises(AssertionError):
                newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
    def testDynamicTablesConcurrency(self):
        """Tests for `Database.ensureTableExists` concurrent use."""
        # We cannot really run things concurrently in a deterministic way,
        # here we just simulate a situation when the table is created by
        # another process between the call to getExistingTable() and the
        # actual table creation.
        db1 = self.makeEmptyDatabase()
        with db1.declareStaticTables(create=True) as context:
            context.addTableTuple(STATIC_TABLE_SPECS)
        self.assertIsNone(db1.getExistingTable("d", DYNAMIC_TABLE_SPEC))

        # Make a dynamic table using a separate connection to the same
        # underlying storage.
        db2 = self.getNewConnection(db1, writeable=True)
        with db2.declareStaticTables(create=False) as context:
            context.addTableTuple(STATIC_TABLE_SPECS)
        table = db2.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        self.checkTable(DYNAMIC_TABLE_SPEC, table)

        # Call it again but trick it into thinking that the table is not
        # there. This test depends on knowing the implementation of
        # ensureTableExists(), which initially calls getExistingTable() to
        # check that the table may exist; the patch intercepts that call and
        # returns None, so ensureTableExists() must cope with the CREATE
        # racing against the one done through db2 above.
        with _patch_getExistingTable(db1):
            table = db1.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
    def testTemporaryTables(self):
        """Tests for `Database.temporary_table`, and `Database.insert` with
        the ``select`` argument.
        """
        # Need to start with the static schema; also insert some test data.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            static = context.addTableTuple(STATIC_TABLE_SPECS)
        newDatabase.insert(static.a, {"name": "a1", "region": None}, {"name": "a2", "region": None})
        bIds = newDatabase.insert(
            static.b,
            {"name": "b1", "value": 11},
            {"name": "b2", "value": 12},
            {"name": "b3", "value": 13},
            returnIds=True,
        )
        # Create the temporary table inside a session so it has a defined
        # lifetime.
        with newDatabase.session():
            with newDatabase.temporary_table(TEMPORARY_TABLE_SPEC, "e1") as table1:
                self.checkTable(TEMPORARY_TABLE_SPEC, table1)
                # Insert via a INSERT INTO ... SELECT query; the CROSS JOIN
                # (join on a literal TRUE) pairs every a-row with every
                # b-row before the WHERE clause filters.
                newDatabase.insert(
                    table1,
                    select=sqlalchemy.sql.select(
                        static.a.columns.name.label("a_name"), static.b.columns.id.label("b_id")
                    )
                    .select_from(static.a.join(static.b, onclause=sqlalchemy.sql.literal(True)))
                    .where(
                        sqlalchemy.sql.and_(
                            static.a.columns.name == "a1",
                            static.b.columns.value <= 12,
                        )
                    ),
                )
                # Check that the inserted rows are present.
                self.assertCountEqual(
                    [{"a_name": "a1", "b_id": bId} for bId in bIds[:2]],
                    [row._asdict() for row in self.query_list(newDatabase, table1.select())],
                )
                # Create another one via a read-only connection to the
                # database. We _do_ allow temporary table modifications in
                # read-only databases.
                with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
                    with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                        context.addTableTuple(STATIC_TABLE_SPECS)
                    with existingReadOnlyDatabase.temporary_table(TEMPORARY_TABLE_SPEC) as table2:
                        self.checkTable(TEMPORARY_TABLE_SPEC, table2)
                        # Those tables should not be the same, despite having
                        # the same ddl.
                        self.assertIsNot(table1, table2)
                        # Do a slightly different insert into this table, to
                        # check that it works in a read-only database. This
                        # time we pass column names as a kwarg to insert
                        # instead of by labeling the columns in the select.
                        existingReadOnlyDatabase.insert(
                            table2,
                            select=sqlalchemy.sql.select(static.a.columns.name, static.b.columns.id)
                            .select_from(static.a.join(static.b, onclause=sqlalchemy.sql.literal(True)))
                            .where(
                                sqlalchemy.sql.and_(
                                    static.a.columns.name == "a2",
                                    static.b.columns.value >= 12,
                                )
                            ),
                            names=["a_name", "b_id"],
                        )
                        # Check that the inserted rows are present.
                        self.assertCountEqual(
                            [{"a_name": "a2", "b_id": bId} for bId in bIds[1:]],
                            [row._asdict() for row in self.query_list(existingReadOnlyDatabase, table2.select())],
                        )
                # Exiting the context managers will drop the temporary tables
                # from the read-only DB. It's unspecified whether attempting
                # to use it after this point is an error or just never
                # returns any results, so we can't test what it does, only
                # that it's not an error.
371 def testSchemaSeparation(self):
372 """Test that creating two different `Database` instances allows us
373 to create different tables with the same name in each.
374 """
375 db1 = self.makeEmptyDatabase(origin=1)
376 with db1.declareStaticTables(create=True) as context:
377 tables = context.addTableTuple(STATIC_TABLE_SPECS)
378 self.checkStaticSchema(tables)
380 db2 = self.makeEmptyDatabase(origin=2)
381 # Make the DDL here intentionally different so we'll definitely
382 # notice if db1 and db2 are pointing at the same schema.
383 spec = ddl.TableSpec(fields=[ddl.FieldSpec("id", dtype=sqlalchemy.Integer, primaryKey=True)])
384 with db2.declareStaticTables(create=True) as context:
385 # Make the DDL here intentionally different so we'll definitely
386 # notice if db1 and db2 are pointing at the same schema.
387 table = context.addTable("a", spec)
388 self.checkTable(spec, table)
    def testInsertQueryDelete(self):
        """Test the `Database.insert`, `Database.query`, and `Database.delete`
        methods, as well as the `Base64Region` type and the ``onDelete``
        argument to `ddl.ForeignKeySpec`.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert a single, non-autoincrement row that contains a region and
        # query to get it back.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row = {"name": "a1", "region": region}
        db.insert(tables.a, row)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row])
        # Insert multiple autoincrement rows but do not try to get the IDs
        # back immediately.
        db.insert(tables.b, {"name": "b1", "value": 10}, {"name": "b2", "value": 20})
        results = [r._asdict() for r in self.query_list(db, tables.b.select().order_by("id"))]
        self.assertEqual(len(results), 2)
        for row in results:
            self.assertIn(row["name"], ("b1", "b2"))
            self.assertIsInstance(row["id"], int)
        self.assertGreater(results[1]["id"], results[0]["id"])
        # Insert multiple autoincrement rows and get the IDs back from insert.
        rows = [{"name": "b3", "value": 30}, {"name": "b4", "value": 40}]
        ids = db.insert(tables.b, *rows, returnIds=True)
        results = [
            r._asdict()
            for r in self.query_list(db, tables.b.select().where(tables.b.columns.id > results[1]["id"]))
        ]
        expected = [dict(row, id=id) for row, id in zip(rows, ids, strict=True)]
        self.assertCountEqual(results, expected)
        self.assertTrue(all(result["id"] is not None for result in results))
        # Insert multiple rows into a table with an autoincrement primary key,
        # then use the returned IDs to insert into a dynamic table.
        rows = [{"b_id": results[0]["id"]}, {"b_id": None}]
        ids = db.insert(tables.c, *rows, returnIds=True)
        results = [r._asdict() for r in self.query_list(db, tables.c.select())]
        expected = [dict(row, id=id) for row, id in zip(rows, ids, strict=True)]
        self.assertCountEqual(results, expected)
        self.assertTrue(all(result["id"] is not None for result in results))
        # Add the dynamic table.
        d = db.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        # Insert into it.
        rows = [{"c_id": id, "a_name": "a1"} for id in ids]
        db.insert(d, *rows)
        results = [r._asdict() for r in self.query_list(db, d.select())]
        self.assertCountEqual(rows, results)
        # Insert multiple rows into a table with an autoincrement primary key,
        # but pass in a value for the autoincrement key.
        rows2 = [
            {"id": 700, "b_id": None},
            {"id": 701, "b_id": None},
        ]
        db.insert(tables.c, *rows2)
        results = [r._asdict() for r in self.query_list(db, tables.c.select())]
        self.assertCountEqual(results, expected + rows2)
        self.assertTrue(all(result["id"] is not None for result in results))

        # Define 'SELECT COUNT(*)' query for later use.
        count = sqlalchemy.sql.select(sqlalchemy.sql.func.count())
        # Get the values we inserted into table b.
        bValues = [r._asdict() for r in self.query_list(db, tables.b.select())]
        # Remove two rows from table b by ID.
        n = db.delete(tables.b, ["id"], {"id": bValues[0]["id"]}, {"id": bValues[1]["id"]})
        self.assertEqual(n, 2)
        # Remove the other two rows from table b by name.
        n = db.delete(tables.b, ["name"], {"name": bValues[2]["name"]}, {"name": bValues[3]["name"]})
        self.assertEqual(n, 2)
        # There should now be no rows in table b.
        self.assertEqual(self.query_scalar(db, count.select_from(tables.b)), 0)
        # All b_id values in table c should now be NULL, because there's an
        # onDelete='SET NULL' foreign key.
        self.assertEqual(
            self.query_scalar(
                db,
                count.select_from(tables.c).where(tables.c.columns.b_id != None),  # noqa:E711
            ),
            0,
        )
        # Remove all rows in table a (there's only one); this should remove
        # all rows in d due to onDelete='CASCADE'.
        n = db.delete(tables.a, [])
        self.assertEqual(n, 1)
        self.assertEqual(self.query_scalar(db, count.select_from(tables.a)), 0)
        self.assertEqual(self.query_scalar(db, count.select_from(d)), 0)
477 def testDeleteWhere(self):
478 """Tests for `Database.deleteWhere`."""
479 db = self.makeEmptyDatabase(origin=1)
480 with db.declareStaticTables(create=True) as context:
481 tables = context.addTableTuple(STATIC_TABLE_SPECS)
482 db.insert(tables.b, *[{"id": i, "name": f"b{i}"} for i in range(10)])
483 count = sqlalchemy.sql.select(sqlalchemy.sql.func.count())
485 n = db.deleteWhere(tables.b, tables.b.columns.id.in_([0, 1, 2]))
486 self.assertEqual(n, 3)
487 self.assertEqual(self.query_scalar(db, count.select_from(tables.b)), 7)
489 n = db.deleteWhere(
490 tables.b,
491 tables.b.columns.id.in_(
492 sqlalchemy.sql.select(tables.b.columns.id).where(tables.b.columns.id > 5)
493 ),
494 )
495 self.assertEqual(n, 4)
496 self.assertEqual(self.query_scalar(db, count.select_from(tables.b)), 3)
498 n = db.deleteWhere(tables.b, tables.b.columns.name == "b5")
499 self.assertEqual(n, 1)
500 self.assertEqual(self.query_scalar(db, count.select_from(tables.b)), 2)
502 n = db.deleteWhere(tables.b, sqlalchemy.sql.literal(True))
503 self.assertEqual(n, 2)
504 self.assertEqual(self.query_scalar(db, count.select_from(tables.b)), 0)
506 def testUpdate(self):
507 """Tests for `Database.update`."""
508 db = self.makeEmptyDatabase(origin=1)
509 with db.declareStaticTables(create=True) as context:
510 tables = context.addTableTuple(STATIC_TABLE_SPECS)
511 # Insert two rows into table a, both without regions.
512 db.insert(tables.a, {"name": "a1"}, {"name": "a2"})
513 # Update one of the rows with a region.
514 region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
515 n = db.update(tables.a, {"name": "k"}, {"k": "a2", "region": region})
516 self.assertEqual(n, 1)
517 sql = sqlalchemy.sql.select(tables.a.columns.name, tables.a.columns.region).select_from(tables.a)
518 self.assertCountEqual(
519 [r._asdict() for r in self.query_list(db, sql)],
520 [{"name": "a1", "region": None}, {"name": "a2", "region": region}],
521 )
    def testSync(self):
        """Tests for `Database.sync`."""
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert a row with sync, because it doesn't exist yet.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10}, returning=["id"])
        self.assertTrue(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )
        # Repeat that operation, which should do nothing but return the
        # requested values.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10}, returning=["id"])
        self.assertFalse(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )
        # Repeat the operation without the 'extra' arg, which should also just
        # return the existing row.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, returning=["id"])
        self.assertFalse(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )
        # Repeat the operation with a different value in 'extra'. That still
        # shouldn't be an error, because 'extra' is only used if we really do
        # insert. Also drop the 'returning' argument.
        _, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 20})
        self.assertFalse(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )
        # Repeat the operation with the correct value in 'compared' instead of
        # 'extra'.
        _, inserted = db.sync(tables.b, keys={"name": "b1"}, compared={"value": 10})
        self.assertFalse(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )
        # Repeat the operation with an incorrect value in 'compared'; this
        # should raise.
        with self.assertRaises(DatabaseConflictError):
            db.sync(tables.b, keys={"name": "b1"}, compared={"value": 20})
        # Try to sync in a read-only database. This should work if and only
        # if the matching row already exists.
        with self.asReadOnly(db) as rodb:
            with rodb.declareStaticTables(create=False) as context:
                tables = context.addTableTuple(STATIC_TABLE_SPECS)
            _, inserted = rodb.sync(tables.b, keys={"name": "b1"})
            self.assertFalse(inserted)
            self.assertEqual(
                [{"id": values["id"], "name": "b1", "value": 10}],
                [r._asdict() for r in self.query_list(rodb, tables.b.select())],
            )
            with self.assertRaises(ReadOnlyDatabaseError):
                rodb.sync(tables.b, keys={"name": "b2"}, extra={"value": 20})
        # Repeat the operation with a different value in 'compared' and ask to
        # update; the returned dict holds the values that were changed.
        _, updated = db.sync(tables.b, keys={"name": "b1"}, compared={"value": 20}, update=True)
        self.assertEqual(updated, {"value": 10})
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 20}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )
    def testReplace(self):
        """Tests for `Database.replace`."""
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Use 'replace' to insert a single row that contains a region and
        # query to get it back.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row1 = {"name": "a1", "region": region}
        db.replace(tables.a, row1)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1])
        # Insert another row without a region.
        row2 = {"name": "a2", "region": None}
        db.replace(tables.a, row2)
        self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2])
        # Use replace to re-insert both of those rows again, which should do
        # nothing.
        db.replace(tables.a, row1, row2)
        self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2])
        # Replace row1 with a row with no region, while reinserting row2.
        # Unlike 'ensure', 'replace' overwrites the existing row.
        row1a = {"name": "a1", "region": None}
        db.replace(tables.a, row1a, row2)
        self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1a, row2])
        # Replace both rows, returning row1 to its original state, while
        # adding a new one. Pass them in in a different order.
        row2a = {"name": "a2", "region": region}
        row3 = {"name": "a3", "region": None}
        db.replace(tables.a, row3, row2a, row1)
        self.assertCountEqual(
            [r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2a, row3]
        )
    def testEnsure(self):
        """Tests for `Database.ensure`."""
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Use 'ensure' to insert a single row that contains a region and
        # query to get it back; the return value is the number of rows
        # actually inserted.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row1 = {"name": "a1", "region": region}
        self.assertEqual(db.ensure(tables.a, row1), 1)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1])
        # Insert another row without a region.
        row2 = {"name": "a2", "region": None}
        self.assertEqual(db.ensure(tables.a, row2), 1)
        self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2])
        # Use ensure to re-insert both of those rows again, which should do
        # nothing.
        self.assertEqual(db.ensure(tables.a, row1, row2), 0)
        self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2])
        # Attempt to insert row1's key with no region, while
        # reinserting row2. This should also do nothing (unlike 'replace',
        # 'ensure' never overwrites an existing row).
        row1a = {"name": "a1", "region": None}
        self.assertEqual(db.ensure(tables.a, row1a, row2), 0)
        self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2])
        # Attempt to insert new rows for both existing keys, this time also
        # adding a new row. Pass them in in a different order. Only the new
        # row should be added.
        row2a = {"name": "a2", "region": region}
        row3 = {"name": "a3", "region": None}
        self.assertEqual(db.ensure(tables.a, row3, row2a, row1a), 1)
        self.assertCountEqual(
            [r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2, row3]
        )
        # Add some data to a table with both a primary key and a different
        # unique constraint.
        row_b = {"id": 5, "name": "five", "value": 50}
        db.insert(tables.b, row_b)
        # Attempt ensure with primary_key_only=False and a conflict for the
        # non-PK constraint. This should do nothing.
        db.ensure(tables.b, {"id": 10, "name": "five", "value": 200})
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.b.select())], [row_b])
        # Now use primary_key_only=True with conflict in only the non-PK
        # field. This should be an integrity error and nothing should change.
        with self.assertRaises(sqlalchemy.exc.IntegrityError):
            db.ensure(tables.b, {"id": 10, "name": "five", "value": 200}, primary_key_only=True)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.b.select())], [row_b])
        # With primary_key_only=True a conflict in the primary key is ignored
        # regardless of whether there is a conflict elsewhere.
        db.ensure(tables.b, {"id": 5, "name": "ten", "value": 100}, primary_key_only=True)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.b.select())], [row_b])
        db.ensure(tables.b, {"id": 5, "name": "five", "value": 100}, primary_key_only=True)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.b.select())], [row_b])
    def testTransactionNesting(self):
        """Test that transactions can be nested with the behavior in the
        presence of exceptions working as documented.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert one row so we can trigger integrity errors by trying to
        # insert a duplicate of it below.
        db.insert(tables.a, {"name": "a1"})
        # First test: error recovery via explicit savepoint=True in the inner
        # transaction.
        with db.transaction():
            # This insert should succeed, and should not be rolled back
            # because the assertRaises context should catch any exception
            # before it propagates up to the outer transaction.
            db.insert(tables.a, {"name": "a2"})
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                with db.transaction(savepoint=True):
                    # This insert should succeed, but should be rolled back.
                    db.insert(tables.a, {"name": "a4"})
                    # This insert should fail (duplicate primary key), raising
                    # an exception.
                    db.insert(tables.a, {"name": "a1"})
        self.assertCountEqual(
            [r._asdict() for r in self.query_list(db, tables.a.select())],
            [{"name": "a1", "region": None}, {"name": "a2", "region": None}],
        )
        # Second test: error recovery via implicit savepoint=True, when the
        # innermost transaction is inside a savepoint=True transaction.
        with db.transaction():
            # This insert should succeed, and should not be rolled back
            # because the assertRaises context should catch any
            # exception before it propagates up to the outer
            # transaction.
            db.insert(tables.a, {"name": "a3"})
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                with db.transaction(savepoint=True):
                    # This insert should succeed, but should be rolled back.
                    db.insert(tables.a, {"name": "a4"})
                    with db.transaction():
                        # This insert should succeed, but should be rolled
                        # back.
                        db.insert(tables.a, {"name": "a5"})
                        # This insert should fail (duplicate primary key),
                        # raising an exception.
                        db.insert(tables.a, {"name": "a1"})
        self.assertCountEqual(
            [r._asdict() for r in self.query_list(db, tables.a.select())],
            [{"name": "a1", "region": None}, {"name": "a2", "region": None}, {"name": "a3", "region": None}],
        )
def testTransactionLocking(self):
    """Test that `Database.transaction` can be used to acquire a lock
    that prohibits concurrent writes.

    Two coroutines share the event loop: ``side1`` (optionally) locks the
    table and inserts on the main connection, while ``side2`` inserts via a
    second connection from a worker thread.  Deliberate ``asyncio.sleep``
    calls stagger the two sides; because the scheduling is only
    probabilistic, inconclusive orderings produce warnings rather than
    failures.
    """
    db1 = self.makeEmptyDatabase(origin=1)
    with db1.declareStaticTables(create=True) as context:
        tables1 = context.addTableTuple(STATIC_TABLE_SPECS)

    async def side1(lock: Iterable[str] = ()) -> tuple[set[str], set[str]]:
        """One side of the concurrent locking test.

        This optionally locks the table (and maybe the whole database),
        does a select for its contents, inserts a new row, and then selects
        again, with some waiting in between to make sure the other side has
        a chance to _attempt_ to insert in between.  If the locking is
        enabled and works, the difference between the selects should just
        be the insert done on this thread.

        Parameters
        ----------
        lock : `~collections.abc.Iterable` [ `str` ]
            Tables to lock for the duration of the transaction; empty for
            no locking.

        Returns
        -------
        names : `tuple` [ `set` [ `str` ], `set` [ `str` ] ]
            The 'name' values seen by the first and second SELECT,
            respectively.
        """
        # Give Side2 a chance to create a connection
        await asyncio.sleep(1.0)
        with db1.transaction(lock=lock):
            names1 = {row.name for row in self.query_list(db1, tables1.a.select())}
            # Give Side2 a chance to insert (which will be blocked if
            # we've acquired a lock).
            await asyncio.sleep(2.0)
            db1.insert(tables1.a, {"name": "a1"})
            names2 = {row.name for row in self.query_list(db1, tables1.a.select())}
        return names1, names2

    async def side2() -> None:
        """Other side of the concurrent locking test.

        This side just waits a bit and then tries to insert a row into the
        table that the other side is trying to lock.  Hopefully that
        waiting is enough to give the other side a chance to acquire the
        lock and thus make this side block until the lock is released.  If
        this side manages to do the insert before side1 acquires the lock,
        we'll just warn about not succeeding at testing the locking,
        because we can only make that unlikely, not impossible.
        """

        def toRunInThread():
            """Create new SQLite connection for use in thread.

            SQLite locking isn't asyncio-friendly unless we actually
            run it in another thread.  And SQLite gets very unhappy if
            we try to use a connection from multiple threads, so we have
            to create the new connection here instead of out in the main
            body of the test function.
            """
            db2 = self.getNewConnection(db1, writeable=True)
            with db2.declareStaticTables(create=False) as context:
                tables2 = context.addTableTuple(STATIC_TABLE_SPECS)
            with db2.transaction():
                db2.insert(tables2.a, {"name": "a2"})

        await asyncio.sleep(2.0)
        loop = asyncio.get_running_loop()
        # Run the blocking insert in a worker thread so a held lock cannot
        # deadlock the event loop itself.
        with ThreadPoolExecutor() as pool:
            await loop.run_in_executor(pool, toRunInThread)

    async def testProblemsWithNoLocking() -> None:
        """Run side1 and side2 with no locking, attempting to demonstrate
        the problem that locking is supposed to solve.  If we get unlucky
        with scheduling, side2 will just happen to insert after side1 is
        done, and we won't have anything definitive.  We just warn in that
        case because we really don't want spurious test failures.
        """
        task1 = asyncio.create_task(side1())
        task2 = asyncio.create_task(side2())
        names1, names2 = await task1
        await task2
        if "a2" in names1:
            warnings.warn(
                "Unlucky scheduling in no-locking test: concurrent INSERT happened before first SELECT.",
                stacklevel=1,
            )
            self.assertEqual(names1, {"a2"})
            self.assertEqual(names2, {"a1", "a2"})
        elif "a2" not in names2:
            warnings.warn(
                "Unlucky scheduling in no-locking test: concurrent INSERT "
                "happened after second SELECT even without locking.",
                stacklevel=1,
            )
            self.assertEqual(names1, set())
            self.assertEqual(names2, {"a1"})
        else:
            # This is the expected case: both INSERTS happen between the
            # two SELECTS.  If we don't get this almost all of the time we
            # should adjust the sleep amounts.
            self.assertEqual(names1, set())
            self.assertEqual(names2, {"a1", "a2"})

    asyncio.run(testProblemsWithNoLocking())
    # Clean up after first test.
    db1.delete(tables1.a, ["name"], {"name": "a1"}, {"name": "a2"})

    async def testSolutionWithLocking() -> None:
        """Run side1 and side2 with locking, which should make side2 block
        its insert until side1 releases its lock.
        """
        task1 = asyncio.create_task(side1(lock=[tables1.a]))
        task2 = asyncio.create_task(side2())
        names1, names2 = await task1
        await task2
        if "a2" in names1:
            warnings.warn(
                "Unlucky scheduling in locking test: concurrent INSERT happened before first SELECT.",
                stacklevel=1,
            )
            self.assertEqual(names1, {"a2"})
            self.assertEqual(names2, {"a1", "a2"})
        else:
            # This is the expected case: the side2 INSERT happens after the
            # last SELECT on side1.  This can also happen due to unlucky
            # scheduling, and we have no way to detect that here, but the
            # similar "no-locking" test has at least some chance of being
            # affected by the same problem and warning about it.
            self.assertEqual(names1, set())
            self.assertEqual(names2, {"a1"})

    asyncio.run(testSolutionWithLocking())
def testTimespanDatabaseRepresentation(self):
    """Tests for `TimespanDatabaseRepresentation` and the `Database`
    methods that interact with it.
    """
    # Make some test timespans to play with, with the full suite of
    # topological relationships.
    start = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
    offset = astropy.time.TimeDelta(60, format="sec")
    timestamps = [start + offset * n for n in range(3)]
    # 'a' timespans: unbounded, half-bounded, instantaneous, empty, and
    # all bounded pairs — one of every topological flavor.
    aTimespans = [Timespan(begin=None, end=None)]
    aTimespans.extend(Timespan(begin=None, end=t) for t in timestamps)
    aTimespans.extend(Timespan(begin=t, end=None) for t in timestamps)
    aTimespans.extend(Timespan.fromInstant(t) for t in timestamps)
    aTimespans.append(Timespan.makeEmpty())
    aTimespans.extend(Timespan(begin=t1, end=t2) for t1, t2 in itertools.combinations(timestamps, 2))
    # Make another list of timespans that span the full range but don't
    # overlap.  This is a subset of the previous list.
    bTimespans = [Timespan(begin=None, end=timestamps[0])]
    bTimespans.extend(
        Timespan(begin=t1, end=t2) for t1, t2 in zip(timestamps[:-1], timestamps[1:], strict=True)
    )
    bTimespans.append(Timespan(begin=timestamps[-1], end=None))
    # Make a database and create a table with that database's timespan
    # representation.  This one will have no exclusion constraint and
    # a nullable timespan.
    db = self.makeEmptyDatabase(origin=1)
    TimespanReprClass = db.getTimespanRepresentation()
    aSpec = ddl.TableSpec(
        fields=[
            ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
        ],
    )
    for fieldSpec in TimespanReprClass.makeFieldSpecs(nullable=True):
        aSpec.fields.add(fieldSpec)
    with db.declareStaticTables(create=True) as context:
        aTable = context.addTable("a", aSpec)
    # Show full diffs on the large assertEqual comparisons below.
    self.maxDiff = None

    def convertRowForInsert(row: dict) -> dict:
        """Convert a row containing a Timespan instance into one suitable
        for insertion into the database.
        """
        result = row.copy()
        # Replace the single Timespan entry with the representation's
        # database columns.
        ts = result.pop(TimespanReprClass.NAME)
        return TimespanReprClass.update(ts, result=result)

    def convertRowFromSelect(row: dict) -> dict:
        """Convert a row from the database into one containing a Timespan.

        Parameters
        ----------
        row : `dict`
            Original row.

        Returns
        -------
        row : `dict`
            The updated row.
        """
        result = row.copy()
        timespan = TimespanReprClass.extract(result)
        # Drop the representation's raw columns in favor of the single
        # reconstructed Timespan entry.
        for name in TimespanReprClass.getFieldNames():
            del result[name]
        result[TimespanReprClass.NAME] = timespan
        return result

    # Insert rows into table A, in chunks just to make things interesting.
    # Include one with a NULL timespan.
    aRows = [{"id": n, TimespanReprClass.NAME: t} for n, t in enumerate(aTimespans)]
    aRows.append({"id": len(aRows), TimespanReprClass.NAME: None})
    db.insert(aTable, convertRowForInsert(aRows[0]))
    db.insert(aTable, *[convertRowForInsert(r) for r in aRows[1:3]])
    db.insert(aTable, *[convertRowForInsert(r) for r in aRows[3:]])
    # Add another one with a NULL timespan, but this time by invoking
    # the server-side default.
    aRows.append({"id": len(aRows)})
    db.insert(aTable, aRows[-1])
    # Patch the local copy so it matches what the database should now hold.
    aRows[-1][TimespanReprClass.NAME] = None
    # Test basic round-trip through database.
    self.assertEqual(
        aRows,
        [
            convertRowFromSelect(row._asdict())
            for row in self.query_list(db, aTable.select().order_by(aTable.columns.id))
        ],
    )
    # Create another table B with a not-null timespan and (if the database
    # supports it), an exclusion constraint.  Use ensureTableExists this
    # time to check that mode of table creation vs. timespans.
    bSpec = ddl.TableSpec(
        fields=[
            ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
            ddl.FieldSpec(name="key", dtype=sqlalchemy.Integer, nullable=False),
        ],
    )
    for fieldSpec in TimespanReprClass.makeFieldSpecs(nullable=False):
        bSpec.fields.add(fieldSpec)
    if TimespanReprClass.hasExclusionConstraint():
        bSpec.exclusion.add(("key", TimespanReprClass))
    bTable = db.ensureTableExists("b", bSpec)
    # Insert rows into table B, again in chunks.  Each Timespan appears
    # twice, but with different values for the 'key' field (which should
    # still be okay for any exclusion constraint we may have defined).
    bRows = [{"id": n, "key": 1, TimespanReprClass.NAME: t} for n, t in enumerate(bTimespans)]
    offset = len(bRows)
    bRows.extend(
        {"id": n + offset, "key": 2, TimespanReprClass.NAME: t} for n, t in enumerate(bTimespans)
    )
    db.insert(bTable, *[convertRowForInsert(r) for r in bRows[:2]])
    db.insert(bTable, convertRowForInsert(bRows[2]))
    db.insert(bTable, *[convertRowForInsert(r) for r in bRows[3:]])
    # Insert a row with no timespan into table B.  This should invoke the
    # server-side default, which is a timespan over (-∞, ∞).  We set
    # key=3 to avoid upsetting an exclusion constraint that might exist.
    bRows.append({"id": len(bRows), "key": 3})
    db.insert(bTable, bRows[-1])
    bRows[-1][TimespanReprClass.NAME] = Timespan(None, None)
    # Test basic round-trip through database.
    self.assertEqual(
        bRows,
        [
            convertRowFromSelect(row._asdict())
            for row in self.query_list(db, bTable.select().order_by(bTable.columns.id))
        ],
    )
    # Test that we can't insert timespan=None into this table.
    with self.assertRaises(sqlalchemy.exc.IntegrityError):
        db.insert(bTable, convertRowForInsert({"id": len(bRows), "key": 4, TimespanReprClass.NAME: None}))
    # IFF this database supports exclusion constraints, test that they
    # also prevent inserts.
    if TimespanReprClass.hasExclusionConstraint():
        # Overlaps bTimespans[0]'s (-∞, timestamps[0]) span with key=1.
        with self.assertRaises(sqlalchemy.exc.IntegrityError):
            db.insert(
                bTable,
                convertRowForInsert(
                    {"id": len(bRows), "key": 1, TimespanReprClass.NAME: Timespan(None, timestamps[1])}
                ),
            )
        # Overlaps the interior bounded spans with key=1.
        with self.assertRaises(sqlalchemy.exc.IntegrityError):
            db.insert(
                bTable,
                convertRowForInsert(
                    {
                        "id": len(bRows),
                        "key": 1,
                        TimespanReprClass.NAME: Timespan(timestamps[0], timestamps[2]),
                    }
                ),
            )
        # Overlaps the final (timestamps[-1], ∞) span with key=1.
        with self.assertRaises(sqlalchemy.exc.IntegrityError):
            db.insert(
                bTable,
                convertRowForInsert(
                    {"id": len(bRows), "key": 1, TimespanReprClass.NAME: Timespan(timestamps[2], None)}
                ),
            )
    # Test NULL checks in SELECT queries, on both tables.
    aRepr = TimespanReprClass.from_columns(aTable.columns)
    self.assertEqual(
        [row[TimespanReprClass.NAME] is None for row in aRows],
        [
            row.f
            for row in self.query_list(
                db, sqlalchemy.sql.select(aRepr.isNull().label("f")).order_by(aTable.columns.id)
            )
        ],
    )
    # Table B's timespan is NOT NULL, so isNull() must be False everywhere.
    bRepr = TimespanReprClass.from_columns(bTable.columns)
    self.assertEqual(
        [False for row in bRows],
        [
            row.f
            for row in self.query_list(
                db, sqlalchemy.sql.select(bRepr.isNull().label("f")).order_by(bTable.columns.id)
            )
        ],
    )
    # Test relationships expressions that relate in-database timespans to
    # Python-literal timespans, all from the more complete 'a' set; check
    # that this is consistent with Python-only relationship tests.
    for rhsRow in aRows:
        if rhsRow[TimespanReprClass.NAME] is None:
            # Cannot build a literal from a NULL timespan; skip as RHS.
            continue
        with self.subTest(rhsRow=rhsRow):
            # Expected results computed in pure Python.
            expected = {}
            for lhsRow in aRows:
                if lhsRow[TimespanReprClass.NAME] is None:
                    # SQL comparisons against NULL yield NULL.
                    expected[lhsRow["id"]] = (None, None, None, None)
                else:
                    expected[lhsRow["id"]] = (
                        lhsRow[TimespanReprClass.NAME].overlaps(rhsRow[TimespanReprClass.NAME]),
                        lhsRow[TimespanReprClass.NAME].contains(rhsRow[TimespanReprClass.NAME]),
                        lhsRow[TimespanReprClass.NAME] < rhsRow[TimespanReprClass.NAME],
                        lhsRow[TimespanReprClass.NAME] > rhsRow[TimespanReprClass.NAME],
                    )
            rhsRepr = TimespanReprClass.fromLiteral(rhsRow[TimespanReprClass.NAME])
            sql = sqlalchemy.sql.select(
                aTable.columns.id.label("lhs"),
                aRepr.overlaps(rhsRepr).label("overlaps"),
                aRepr.contains(rhsRepr).label("contains"),
                (aRepr < rhsRepr).label("less_than"),
                (aRepr > rhsRepr).label("greater_than"),
            ).select_from(aTable)
            queried = {
                row.lhs: (row.overlaps, row.contains, row.less_than, row.greater_than)
                for row in self.query_list(db, sql)
            }
            self.assertEqual(expected, queried)
    # Test relationship expressions that relate in-database timespans to
    # each other, all from the more complete 'a' set; check that this is
    # consistent with Python-only relationship tests.
    expected = {}
    for lhs, rhs in itertools.product(aRows, aRows):
        lhsT = lhs[TimespanReprClass.NAME]
        rhsT = rhs[TimespanReprClass.NAME]
        if lhsT is not None and rhsT is not None:
            expected[lhs["id"], rhs["id"]] = (
                lhsT.overlaps(rhsT),
                lhsT.contains(rhsT),
                lhsT < rhsT,
                lhsT > rhsT,
            )
        else:
            # Any comparison involving a NULL timespan is NULL in SQL.
            expected[lhs["id"], rhs["id"]] = (None, None, None, None)
    # Self-join table A against itself (full cross join via a TRUE
    # onclause) to compare every pair of in-database timespans.
    lhsSubquery = aTable.alias("lhs")
    rhsSubquery = aTable.alias("rhs")
    lhsRepr = TimespanReprClass.from_columns(lhsSubquery.columns)
    rhsRepr = TimespanReprClass.from_columns(rhsSubquery.columns)
    sql = sqlalchemy.sql.select(
        lhsSubquery.columns.id.label("lhs"),
        rhsSubquery.columns.id.label("rhs"),
        lhsRepr.overlaps(rhsRepr).label("overlaps"),
        lhsRepr.contains(rhsRepr).label("contains"),
        (lhsRepr < rhsRepr).label("less_than"),
        (lhsRepr > rhsRepr).label("greater_than"),
    ).select_from(lhsSubquery.join(rhsSubquery, onclause=sqlalchemy.sql.literal(True)))
    queried = {
        (row.lhs, row.rhs): (row.overlaps, row.contains, row.less_than, row.greater_than)
        for row in self.query_list(db, sql)
    }
    self.assertEqual(expected, queried)
    # Test relationship expressions between in-database timespans and
    # Python-literal instantaneous times.
    for t in timestamps:
        with self.subTest(t=t):
            expected = {}
            for lhsRow in aRows:
                if lhsRow[TimespanReprClass.NAME] is None:
                    expected[lhsRow["id"]] = (None, None, None, None)
                else:
                    expected[lhsRow["id"]] = (
                        lhsRow[TimespanReprClass.NAME].contains(t),
                        lhsRow[TimespanReprClass.NAME].overlaps(t),
                        lhsRow[TimespanReprClass.NAME] < t,
                        lhsRow[TimespanReprClass.NAME] > t,
                    )
            rhs = sqlalchemy.sql.literal(t, type_=ddl.AstropyTimeNsecTai)
            sql = sqlalchemy.sql.select(
                aTable.columns.id.label("lhs"),
                aRepr.contains(rhs).label("contains"),
                aRepr.overlaps(rhs).label("overlaps_point"),
                (aRepr < rhs).label("less_than"),
                (aRepr > rhs).label("greater_than"),
            ).select_from(aTable)
            queried = {
                row.lhs: (row.contains, row.overlaps_point, row.less_than, row.greater_than)
                for row in self.query_list(db, sql)
            }
            self.assertEqual(expected, queried)
def testConstantRows(self):
    """Test Database.constant_rows."""
    db = self.makeEmptyDatabase()
    with db.declareStaticTables(create=True) as ctx:
        static = ctx.addTableTuple(STATIC_TABLE_SPECS)
    # Populate the static 'b' table so the constant rows below have real
    # ids to reference.
    inserted_ids = db.insert(
        static.b,
        {"name": "b1", "value": 11},
        {"name": "b2", "value": 12},
        {"name": "b3", "value": 13},
        returnIds=True,
    )
    # Field specs describing the ad-hoc columns of the constant rows.
    spec = ddl.TableSpec(
        [
            ddl.FieldSpec(name="b", dtype=sqlalchemy.BigInteger),
            ddl.FieldSpec(name="s", dtype=sqlalchemy.String(8)),
            ddl.FieldSpec(name="r", dtype=ddl.Base64Region()),
        ],
    )
    rows = [
        {"b": inserted_ids[0], "s": "b1", "r": None},
        {"b": inserted_ids[2], "s": "b3", "r": Circle.empty()},
    ]
    constants = db.constant_rows(spec.fields, *rows)
    # Selecting the constant rows by themselves should round-trip the
    # literal values exactly.
    standalone = sqlalchemy.sql.select(
        constants.columns["b"], constants.columns["s"], constants.columns["r"]
    )
    self.assertCountEqual(
        [r._mapping for r in self.query_list(db, standalone)],
        rows,
    )
    # The constant rows should also be joinable against a real table.
    join_condition = static.b.columns["id"] == constants.columns["b"]
    joined = sqlalchemy.sql.select(
        constants.columns["s"].label("name"), static.b.columns["value"].label("value")
    ).select_from(constants.join(static.b, onclause=join_condition))
    self.assertCountEqual(
        [r._mapping for r in self.query_list(db, joined)],
        [{"value": 11, "name": "b1"}, {"value": 13, "name": "b3"}],
    )