Coverage for python/lsst/daf/butler/registry/tests/_database.py: 6%
493 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-01-05 10:36 +0000
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ["DatabaseTests"]
import asyncio
import itertools
import warnings
from abc import ABC, abstractmethod
from collections import namedtuple
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from typing import Any, ContextManager, Iterable, Iterator, Optional, Set, Tuple

import astropy.time
import sqlalchemy
from lsst.sphgeom import Circle, ConvexPolygon, UnitVector3d

from ...core import Timespan, ddl
from ..interfaces import Database, DatabaseConflictError, ReadOnlyDatabaseError, SchemaAlreadyDefinedError
41StaticTablesTuple = namedtuple("StaticTablesTuple", ["a", "b", "c"])
43STATIC_TABLE_SPECS = StaticTablesTuple(
44 a=ddl.TableSpec(
45 fields=[
46 ddl.FieldSpec("name", dtype=sqlalchemy.String, length=16, primaryKey=True),
47 ddl.FieldSpec("region", dtype=ddl.Base64Region, nbytes=128),
48 ]
49 ),
50 b=ddl.TableSpec(
51 fields=[
52 ddl.FieldSpec("id", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True),
53 ddl.FieldSpec("name", dtype=sqlalchemy.String, length=16, nullable=False),
54 ddl.FieldSpec("value", dtype=sqlalchemy.SmallInteger, nullable=True),
55 ],
56 unique=[("name",)],
57 ),
58 c=ddl.TableSpec(
59 fields=[
60 ddl.FieldSpec("id", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True),
61 ddl.FieldSpec("b_id", dtype=sqlalchemy.BigInteger, nullable=True),
62 ],
63 foreignKeys=[
64 ddl.ForeignKeySpec("b", source=("b_id",), target=("id",), onDelete="SET NULL"),
65 ],
66 ),
67)
69DYNAMIC_TABLE_SPEC = ddl.TableSpec(
70 fields=[
71 ddl.FieldSpec("c_id", dtype=sqlalchemy.BigInteger, primaryKey=True),
72 ddl.FieldSpec("a_name", dtype=sqlalchemy.String, length=16, nullable=False),
73 ],
74 foreignKeys=[
75 ddl.ForeignKeySpec("c", source=("c_id",), target=("id",), onDelete="CASCADE"),
76 ddl.ForeignKeySpec("a", source=("a_name",), target=("name",), onDelete="CASCADE"),
77 ],
78)
80TEMPORARY_TABLE_SPEC = ddl.TableSpec(
81 fields=[
82 ddl.FieldSpec("a_name", dtype=sqlalchemy.String, length=16, primaryKey=True),
83 ddl.FieldSpec("b_id", dtype=sqlalchemy.BigInteger, primaryKey=True),
84 ],
85)
88@contextmanager
89def _patch_getExistingTable(db: Database) -> Database:
90 """Patch getExistingTable method in a database instance to test concurrent
91 creation of tables. This patch obviously depends on knowning internals of
92 ``ensureTableExists()`` implementation.
93 """
94 original_method = db.getExistingTable
96 def _getExistingTable(name: str, spec: ddl.TableSpec) -> Optional[sqlalchemy.schema.Table]:
97 # Return None on first call, but forward to original method after that
98 db.getExistingTable = original_method
99 return None
101 db.getExistingTable = _getExistingTable
102 yield db
103 db.getExistingTable = original_method
106class DatabaseTests(ABC):
107 """Generic tests for the `Database` interface that can be subclassed to
108 generate tests for concrete implementations.
109 """
    @abstractmethod
    def makeEmptyDatabase(self, origin: int = 0) -> Database:
        """Return an empty `Database` with the given origin, or an
        automatically-generated one if ``origin`` is `None`.

        NOTE(review): the signature annotates ``origin`` as `int` with a
        default of 0, which contradicts the `None` case described above —
        confirm the intended contract with concrete implementations.
        """
        raise NotImplementedError()
    @abstractmethod
    def asReadOnly(self, database: Database) -> ContextManager[Database]:
        """Return a context manager for a read-only connection into the given
        database.

        The original database should be considered unusable within the context
        but safe to use again afterwards (this allows the context manager to
        block write access by temporarily changing user permissions to really
        guarantee that write operations are not performed).

        Parameters
        ----------
        database : `Database`
            Writeable database to re-open read-only.

        Returns
        -------
        context : `ContextManager` [ `Database` ]
            Context manager yielding the read-only connection.
        """
        raise NotImplementedError()
    @abstractmethod
    def getNewConnection(self, database: Database, *, writeable: bool) -> Database:
        """Return a new `Database` instance that points to the same underlying
        storage as the given one.

        Parameters
        ----------
        database : `Database`
            Existing instance whose storage the new connection should share.
        writeable : `bool`
            Whether the new connection permits writes.
        """
        raise NotImplementedError()
137 def query_list(
138 self, database: Database, executable: sqlalchemy.sql.expression.SelectBase
139 ) -> list[sqlalchemy.engine.Row]:
140 """Run a SELECT or other read-only query against the database and
141 return the results as a list.
143 This is a thin wrapper around database.query() that just avoids
144 context-manager boilerplate that is usefully verbose in production code
145 but just noise in tests.
146 """
147 with database.query(executable) as result:
148 return result.fetchall()
150 def query_scalar(self, database: Database, executable: sqlalchemy.sql.expression.SelectBase) -> Any:
151 """Run a SELECT query that yields a single column and row against the
152 database and return its value.
154 This is a thin wrapper around database.query() that just avoids
155 context-manager boilerplate that is usefully verbose in production code
156 but just noise in tests.
157 """
158 with database.query(executable) as result:
159 return result.scalar()
161 def checkTable(self, spec: ddl.TableSpec, table: sqlalchemy.schema.Table):
162 self.assertCountEqual(spec.fields.names, table.columns.keys())
163 # Checking more than this currently seems fragile, as it might restrict
164 # what Database implementations do; we don't care if the spec is
165 # actually preserved in terms of types and constraints as long as we
166 # can use the returned table as if it was.
168 def checkStaticSchema(self, tables: StaticTablesTuple):
169 self.checkTable(STATIC_TABLE_SPECS.a, tables.a)
170 self.checkTable(STATIC_TABLE_SPECS.b, tables.b)
171 self.checkTable(STATIC_TABLE_SPECS.c, tables.c)
    def testDeclareStaticTables(self):
        """Tests for `Database.declareStaticTables` and the methods it
        delegates to.
        """
        # Create the static schema in a new, empty database.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        self.checkStaticSchema(tables)
        # Check that we can load that schema even from a read-only connection.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                tables = context.addTableTuple(STATIC_TABLE_SPECS)
            self.checkStaticSchema(tables)
    def testDeclareStaticTablesTwice(self):
        """Tests for `Database.declareStaticTables` being called twice with
        ``create=True``.
        """
        # Create the static schema in a new, empty database.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        self.checkStaticSchema(tables)
        # Second time it should raise
        with self.assertRaises(SchemaAlreadyDefinedError):
            with newDatabase.declareStaticTables(create=True) as context:
                tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Check schema, it should still contain all tables, and maybe some
        # extra.
        with newDatabase.declareStaticTables(create=False) as context:
            self.assertLessEqual(frozenset(STATIC_TABLE_SPECS._fields), context._tableNames)
204 def testRepr(self):
205 """Test that repr does not return a generic thing."""
206 newDatabase = self.makeEmptyDatabase()
207 rep = repr(newDatabase)
208 # Check that stringification works and gives us something different
209 self.assertNotEqual(rep, str(newDatabase))
210 self.assertNotIn("object at 0x", rep, "Check default repr was not used")
211 self.assertIn("://", rep)
    def testDynamicTables(self):
        """Tests for `Database.ensureTableExists` and
        `Database.getExistingTable`, including read-only behavior and
        specification-conflict detection.
        """
        # Need to start with the static schema.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            context.addTableTuple(STATIC_TABLE_SPECS)
        # Try to ensure the dynamic table exists in a read-only version of
        # that database, which should fail because we can't create it.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                context.addTableTuple(STATIC_TABLE_SPECS)
            with self.assertRaises(ReadOnlyDatabaseError):
                existingReadOnlyDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        # Just getting the dynamic table before it exists should return None.
        self.assertIsNone(newDatabase.getExistingTable("d", DYNAMIC_TABLE_SPEC))
        # Ensure the new table exists back in the original database, which
        # should create it.
        table = newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        self.checkTable(DYNAMIC_TABLE_SPEC, table)
        # Ensuring that it exists should just return the exact same table
        # instance again.
        self.assertIs(newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC), table)
        # Try again from the read-only database.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                context.addTableTuple(STATIC_TABLE_SPECS)
            # Just getting the dynamic table should now work...
            self.assertIsNotNone(existingReadOnlyDatabase.getExistingTable("d", DYNAMIC_TABLE_SPEC))
            # ...as should ensuring that it exists, since it now does.
            existingReadOnlyDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
            self.checkTable(DYNAMIC_TABLE_SPEC, table)
        # Trying to get the table with a different specification (at least
        # in terms of what columns are present) should raise.
        with self.assertRaises(DatabaseConflictError):
            newDatabase.ensureTableExists(
                "d",
                ddl.TableSpec(
                    fields=[ddl.FieldSpec("name", dtype=sqlalchemy.String, length=4, primaryKey=True)]
                ),
            )
        # Calling ensureTableExists inside a transaction block is an error,
        # even if it would do nothing.
        with newDatabase.transaction():
            with self.assertRaises(AssertionError):
                newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
    def testDynamicTablesConcurrency(self):
        """Tests for `Database.ensureTableExists` concurrent use."""
        # We cannot really run things concurrently in a deterministic way,
        # here we just simulate a situation when the table is created by
        # other process between the call to getExistingTable() and actual
        # table creation.
        db1 = self.makeEmptyDatabase()
        with db1.declareStaticTables(create=True) as context:
            context.addTableTuple(STATIC_TABLE_SPECS)
        self.assertIsNone(db1.getExistingTable("d", DYNAMIC_TABLE_SPEC))

        # Make a dynamic table using separate connection
        db2 = self.getNewConnection(db1, writeable=True)
        with db2.declareStaticTables(create=False) as context:
            context.addTableTuple(STATIC_TABLE_SPECS)
        table = db2.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        self.checkTable(DYNAMIC_TABLE_SPEC, table)

        # Call it again but trick it into thinking that table is not there.
        # This test depends on knowing implementation of ensureTableExists()
        # which initially calls getExistingTable() to check that table may
        # exist, the patch intercepts that call and returns None.
        with _patch_getExistingTable(db1):
            # Must not raise even though the initial existence check claimed
            # the table was absent while it actually exists.
            table = db1.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
    def testTemporaryTables(self):
        """Tests for `Database.temporary_table`, and `Database.insert` with
        the ``select`` argument.
        """
        # Need to start with the static schema; also insert some test data.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            static = context.addTableTuple(STATIC_TABLE_SPECS)
        newDatabase.insert(static.a, {"name": "a1", "region": None}, {"name": "a2", "region": None})
        bIds = newDatabase.insert(
            static.b,
            {"name": "b1", "value": 11},
            {"name": "b2", "value": 12},
            {"name": "b3", "value": 13},
            returnIds=True,
        )
        # Create the temporary table inside an explicit session.
        with newDatabase.session():
            with newDatabase.temporary_table(TEMPORARY_TABLE_SPEC, "e1") as table1:
                self.checkTable(TEMPORARY_TABLE_SPEC, table1)
                # Insert via a INSERT INTO ... SELECT query.
                newDatabase.insert(
                    table1,
                    select=sqlalchemy.sql.select(
                        static.a.columns.name.label("a_name"), static.b.columns.id.label("b_id")
                    )
                    .select_from(static.a.join(static.b, onclause=sqlalchemy.sql.literal(True)))
                    .where(
                        sqlalchemy.sql.and_(
                            static.a.columns.name == "a1",
                            static.b.columns.value <= 12,
                        )
                    ),
                )
                # Check that the inserted rows are present.
                self.assertCountEqual(
                    [{"a_name": "a1", "b_id": bId} for bId in bIds[:2]],
                    [row._asdict() for row in self.query_list(newDatabase, table1.select())],
                )
                # Create another one via a read-only connection to the
                # database. We _do_ allow temporary table modifications in
                # read-only databases.
                with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
                    with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                        context.addTableTuple(STATIC_TABLE_SPECS)
                    with existingReadOnlyDatabase.temporary_table(TEMPORARY_TABLE_SPEC) as table2:
                        self.checkTable(TEMPORARY_TABLE_SPEC, table2)
                        # Those tables should not be the same, despite having
                        # the same ddl.
                        self.assertIsNot(table1, table2)
                        # Do a slightly different insert into this table, to
                        # check that it works in a read-only database. This
                        # time we pass column names as a kwarg to insert
                        # instead of by labeling the columns in the select.
                        existingReadOnlyDatabase.insert(
                            table2,
                            select=sqlalchemy.sql.select(static.a.columns.name, static.b.columns.id)
                            .select_from(static.a.join(static.b, onclause=sqlalchemy.sql.literal(True)))
                            .where(
                                sqlalchemy.sql.and_(
                                    static.a.columns.name == "a2",
                                    static.b.columns.value >= 12,
                                )
                            ),
                            names=["a_name", "b_id"],
                        )
                        # Check that the inserted rows are present.
                        self.assertCountEqual(
                            [{"a_name": "a2", "b_id": bId} for bId in bIds[1:]],
                            [
                                row._asdict()
                                for row in self.query_list(existingReadOnlyDatabase, table2.select())
                            ],
                        )
                # Exiting the context managers will drop the temporary tables
                # from the read-only DB. It's unspecified whether attempting
                # to use it after this point is an error or just never returns
                # any results, so we can't test what it does, only that it's
                # not an error.
    def testSchemaSeparation(self):
        """Test that creating two different `Database` instances allows us
        to create different tables with the same name in each.
        """
        db1 = self.makeEmptyDatabase(origin=1)
        with db1.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        self.checkStaticSchema(tables)

        db2 = self.makeEmptyDatabase(origin=2)
        # Make the DDL here intentionally different so we'll definitely
        # notice if db1 and db2 are pointing at the same schema.
        spec = ddl.TableSpec(fields=[ddl.FieldSpec("id", dtype=sqlalchemy.Integer, primaryKey=True)])
        with db2.declareStaticTables(create=True) as context:
            table = context.addTable("a", spec)
        self.checkTable(spec, table)
    def testInsertQueryDelete(self):
        """Test the `Database.insert`, `Database.query`, and `Database.delete`
        methods, as well as the `Base64Region` type and the ``onDelete``
        argument to `ddl.ForeignKeySpec`.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert a single, non-autoincrement row that contains a region and
        # query to get it back.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row = {"name": "a1", "region": region}
        db.insert(tables.a, row)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row])
        # Insert multiple autoincrement rows but do not try to get the IDs
        # back immediately.
        db.insert(tables.b, {"name": "b1", "value": 10}, {"name": "b2", "value": 20})
        results = [r._asdict() for r in self.query_list(db, tables.b.select().order_by("id"))]
        self.assertEqual(len(results), 2)
        for row in results:
            self.assertIn(row["name"], ("b1", "b2"))
            self.assertIsInstance(row["id"], int)
        self.assertGreater(results[1]["id"], results[0]["id"])
        # Insert multiple autoincrement rows and get the IDs back from insert.
        rows = [{"name": "b3", "value": 30}, {"name": "b4", "value": 40}]
        ids = db.insert(tables.b, *rows, returnIds=True)
        results = [
            r._asdict()
            for r in self.query_list(db, tables.b.select().where(tables.b.columns.id > results[1]["id"]))
        ]
        expected = [dict(row, id=id) for row, id in zip(rows, ids)]
        self.assertCountEqual(results, expected)
        self.assertTrue(all(result["id"] is not None for result in results))
        # Insert multiple rows into a table with an autoincrement primary key,
        # then use the returned IDs to insert into a dynamic table.
        rows = [{"b_id": results[0]["id"]}, {"b_id": None}]
        ids = db.insert(tables.c, *rows, returnIds=True)
        results = [r._asdict() for r in self.query_list(db, tables.c.select())]
        expected = [dict(row, id=id) for row, id in zip(rows, ids)]
        self.assertCountEqual(results, expected)
        self.assertTrue(all(result["id"] is not None for result in results))
        # Add the dynamic table.
        d = db.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        # Insert into it.
        rows = [{"c_id": id, "a_name": "a1"} for id in ids]
        db.insert(d, *rows)
        results = [r._asdict() for r in self.query_list(db, d.select())]
        self.assertCountEqual(rows, results)
        # Insert multiple rows into a table with an autoincrement primary key,
        # but pass in a value for the autoincrement key.
        rows2 = [
            {"id": 700, "b_id": None},
            {"id": 701, "b_id": None},
        ]
        db.insert(tables.c, *rows2)
        results = [r._asdict() for r in self.query_list(db, tables.c.select())]
        self.assertCountEqual(results, expected + rows2)
        self.assertTrue(all(result["id"] is not None for result in results))

        # Define 'SELECT COUNT(*)' query for later use.
        count = sqlalchemy.sql.select(sqlalchemy.sql.func.count())
        # Get the values we inserted into table b.
        bValues = [r._asdict() for r in self.query_list(db, tables.b.select())]
        # Remove two rows from table b by ID.
        n = db.delete(tables.b, ["id"], {"id": bValues[0]["id"]}, {"id": bValues[1]["id"]})
        self.assertEqual(n, 2)
        # Remove the other two rows from table b by name.
        n = db.delete(tables.b, ["name"], {"name": bValues[2]["name"]}, {"name": bValues[3]["name"]})
        self.assertEqual(n, 2)
        # There should now be no rows in table b.
        self.assertEqual(self.query_scalar(db, count.select_from(tables.b)), 0)
        # All b_id values in table c should now be NULL, because there's an
        # onDelete='SET NULL' foreign key.
        self.assertEqual(
            self.query_scalar(
                db,
                count.select_from(tables.c).where(tables.c.columns.b_id != None),  # noqa:E711
            ),
            0,
        )
        # Remove all rows in table a (there's only one); this should remove
        # all rows in d due to onDelete='CASCADE'.
        n = db.delete(tables.a, [])
        self.assertEqual(n, 1)
        self.assertEqual(self.query_scalar(db, count.select_from(tables.a)), 0)
        self.assertEqual(self.query_scalar(db, count.select_from(d)), 0)
    def testDeleteWhere(self):
        """Tests for `Database.deleteWhere`."""
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        db.insert(tables.b, *[{"id": i, "name": f"b{i}"} for i in range(10)])
        count = sqlalchemy.sql.select(sqlalchemy.sql.func.count())

        # Delete rows matched by an explicit IN list.
        n = db.deleteWhere(tables.b, tables.b.columns.id.in_([0, 1, 2]))
        self.assertEqual(n, 3)
        self.assertEqual(self.query_scalar(db, count.select_from(tables.b)), 7)

        # Delete rows matched by an IN over a subquery.
        n = db.deleteWhere(
            tables.b,
            tables.b.columns.id.in_(
                sqlalchemy.sql.select(tables.b.columns.id).where(tables.b.columns.id > 5)
            ),
        )
        self.assertEqual(n, 4)
        self.assertEqual(self.query_scalar(db, count.select_from(tables.b)), 3)

        # Delete a row matched by a non-key column.
        n = db.deleteWhere(tables.b, tables.b.columns.name == "b5")
        self.assertEqual(n, 1)
        self.assertEqual(self.query_scalar(db, count.select_from(tables.b)), 2)

        # A literal-true WHERE clause removes everything that remains.
        n = db.deleteWhere(tables.b, sqlalchemy.sql.literal(True))
        self.assertEqual(n, 2)
        self.assertEqual(self.query_scalar(db, count.select_from(tables.b)), 0)
    def testUpdate(self):
        """Tests for `Database.update`."""
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert two rows into table a, both without regions.
        db.insert(tables.a, {"name": "a1"}, {"name": "a2"})
        # Update one of the rows with a region.  ({"name": "k"} presumably
        # maps the "k" key in the value dict to the "name" column used to
        # identify the row — confirm against Database.update docs.)
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        n = db.update(tables.a, {"name": "k"}, {"k": "a2", "region": region})
        self.assertEqual(n, 1)
        sql = sqlalchemy.sql.select(tables.a.columns.name, tables.a.columns.region).select_from(tables.a)
        self.assertCountEqual(
            [r._asdict() for r in self.query_list(db, sql)],
            [{"name": "a1", "region": None}, {"name": "a2", "region": region}],
        )
    def testSync(self):
        """Tests for `Database.sync`."""
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert a row with sync, because it doesn't exist yet.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10}, returning=["id"])
        self.assertTrue(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )
        # Repeat that operation, which should do nothing but return the
        # requested values.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10}, returning=["id"])
        self.assertFalse(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )
        # Repeat the operation without the 'extra' arg, which should also just
        # return the existing row.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, returning=["id"])
        self.assertFalse(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )
        # Repeat the operation with a different value in 'extra'.  That still
        # shouldn't be an error, because 'extra' is only used if we really do
        # insert.  Also drop the 'returning' argument.
        _, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 20})
        self.assertFalse(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )
        # Repeat the operation with the correct value in 'compared' instead of
        # 'extra'.
        _, inserted = db.sync(tables.b, keys={"name": "b1"}, compared={"value": 10})
        self.assertFalse(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )
        # Repeat the operation with an incorrect value in 'compared'; this
        # should raise.
        with self.assertRaises(DatabaseConflictError):
            db.sync(tables.b, keys={"name": "b1"}, compared={"value": 20})
        # Try to sync in a read-only database.  This should work if and only
        # if the matching row already exists.
        with self.asReadOnly(db) as rodb:
            with rodb.declareStaticTables(create=False) as context:
                tables = context.addTableTuple(STATIC_TABLE_SPECS)
            _, inserted = rodb.sync(tables.b, keys={"name": "b1"})
            self.assertFalse(inserted)
            self.assertEqual(
                [{"id": values["id"], "name": "b1", "value": 10}],
                [r._asdict() for r in self.query_list(rodb, tables.b.select())],
            )
            with self.assertRaises(ReadOnlyDatabaseError):
                rodb.sync(tables.b, keys={"name": "b2"}, extra={"value": 20})
        # Repeat the operation with a different value in 'compared' and ask to
        # update.  The returned mapping reports the previous values of the
        # columns that were changed.
        _, updated = db.sync(tables.b, keys={"name": "b1"}, compared={"value": 20}, update=True)
        self.assertEqual(updated, {"value": 10})
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 20}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )
    def testReplace(self):
        """Tests for `Database.replace` (insert-or-overwrite semantics)."""
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Use 'replace' to insert a single row that contains a region and
        # query to get it back.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row1 = {"name": "a1", "region": region}
        db.replace(tables.a, row1)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1])
        # Insert another row without a region.
        row2 = {"name": "a2", "region": None}
        db.replace(tables.a, row2)
        self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2])
        # Use replace to re-insert both of those rows again, which should do
        # nothing.
        db.replace(tables.a, row1, row2)
        self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2])
        # Replace row1 with a row with no region, while reinserting row2.
        row1a = {"name": "a1", "region": None}
        db.replace(tables.a, row1a, row2)
        self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1a, row2])
        # Replace both rows, returning row1 to its original state, while
        # adding a new one.  Pass them in in a different order.
        row2a = {"name": "a2", "region": region}
        row3 = {"name": "a3", "region": None}
        db.replace(tables.a, row3, row2a, row1)
        self.assertCountEqual(
            [r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2a, row3]
        )
    def testEnsure(self):
        """Tests for `Database.ensure` (insert-if-absent semantics, returning
        the number of rows actually inserted).
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Use 'ensure' to insert a single row that contains a region and
        # query to get it back.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row1 = {"name": "a1", "region": region}
        self.assertEqual(db.ensure(tables.a, row1), 1)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1])
        # Insert another row without a region.
        row2 = {"name": "a2", "region": None}
        self.assertEqual(db.ensure(tables.a, row2), 1)
        self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2])
        # Use ensure to re-insert both of those rows again, which should do
        # nothing.
        self.assertEqual(db.ensure(tables.a, row1, row2), 0)
        self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2])
        # Attempt to insert row1's key with no region, while
        # reinserting row2.  This should also do nothing.
        row1a = {"name": "a1", "region": None}
        self.assertEqual(db.ensure(tables.a, row1a, row2), 0)
        self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2])
        # Attempt to insert new rows for both existing keys, this time also
        # adding a new row.  Pass them in in a different order.  Only the new
        # row should be added.
        row2a = {"name": "a2", "region": region}
        row3 = {"name": "a3", "region": None}
        self.assertEqual(db.ensure(tables.a, row3, row2a, row1a), 1)
        self.assertCountEqual(
            [r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2, row3]
        )
        # Add some data to a table with both a primary key and a different
        # unique constraint.
        row_b = {"id": 5, "name": "five", "value": 50}
        db.insert(tables.b, row_b)
        # Attempt ensure with primary_key_only=False and a conflict for the
        # non-PK constraint.  This should do nothing.
        db.ensure(tables.b, {"id": 10, "name": "five", "value": 200})
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.b.select())], [row_b])
        # Now use primary_key_only=True with conflict in only the non-PK
        # field.  This should be an integrity error and nothing should change.
        with self.assertRaises(sqlalchemy.exc.IntegrityError):
            db.ensure(tables.b, {"id": 10, "name": "five", "value": 200}, primary_key_only=True)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.b.select())], [row_b])
        # With primary_key_only=True a conflict in the primary key is ignored
        # regardless of whether there is a conflict elsewhere.
        db.ensure(tables.b, {"id": 5, "name": "ten", "value": 100}, primary_key_only=True)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.b.select())], [row_b])
        db.ensure(tables.b, {"id": 5, "name": "five", "value": 100}, primary_key_only=True)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.b.select())], [row_b])
    def testTransactionNesting(self):
        """Test that transactions can be nested with the behavior in the
        presence of exceptions working as documented.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert one row so we can trigger integrity errors by trying to
        # insert a duplicate of it below.
        db.insert(tables.a, {"name": "a1"})
        # First test: error recovery via explicit savepoint=True in the inner
        # transaction.
        with db.transaction():
            # This insert should succeed, and should not be rolled back
            # because the assertRaises context should catch any exception
            # before it propagates up to the outer transaction.
            db.insert(tables.a, {"name": "a2"})
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                with db.transaction(savepoint=True):
                    # This insert should succeed, but should be rolled back.
                    db.insert(tables.a, {"name": "a4"})
                    # This insert should fail (duplicate primary key),
                    # raising an exception.
                    db.insert(tables.a, {"name": "a1"})
        self.assertCountEqual(
            [r._asdict() for r in self.query_list(db, tables.a.select())],
            [{"name": "a1", "region": None}, {"name": "a2", "region": None}],
        )
        # Second test: error recovery via implicit savepoint=True, when the
        # innermost transaction is inside a savepoint=True transaction.
        with db.transaction():
            # This insert should succeed, and should not be rolled back
            # because the assertRaises context should catch any
            # exception before it propagates up to the outer
            # transaction.
            db.insert(tables.a, {"name": "a3"})
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                with db.transaction(savepoint=True):
                    # This insert should succeed, but should be rolled back.
                    db.insert(tables.a, {"name": "a4"})
                    with db.transaction():
                        # This insert should succeed, but should be rolled
                        # back.
                        db.insert(tables.a, {"name": "a5"})
                        # This insert should fail (duplicate primary key),
                        # raising an exception.
                        db.insert(tables.a, {"name": "a1"})
        self.assertCountEqual(
            [r._asdict() for r in self.query_list(db, tables.a.select())],
            [{"name": "a1", "region": None}, {"name": "a2", "region": None}, {"name": "a3", "region": None}],
        )
    def testTransactionLocking(self):
        """Test that `Database.transaction` can be used to acquire a lock
        that prohibits concurrent writes.

        Uses two asyncio tasks — one holding (or not holding) a table lock,
        the other inserting from a second connection in a worker thread — to
        first demonstrate the interleaving problem without a lock, and then
        show that passing ``lock=`` to `Database.transaction` prevents it.
        Because the outcome depends on scheduling, inconclusive interleavings
        only produce warnings, never test failures.
        """
        db1 = self.makeEmptyDatabase(origin=1)
        with db1.declareStaticTables(create=True) as context:
            tables1 = context.addTableTuple(STATIC_TABLE_SPECS)

        async def side1(lock: Iterable[str] = ()) -> Tuple[Set[str], Set[str]]:
            """One side of the concurrent locking test.

            This optionally locks the table (and maybe the whole database),
            does a select for its contents, inserts a new row, and then selects
            again, with some waiting in between to make sure the other side has
            a chance to _attempt_ to insert in between.  If the locking is
            enabled and works, the difference between the selects should just
            be the insert done on this thread.
            """
            # Give Side2 a chance to create a connection
            await asyncio.sleep(1.0)
            with db1.transaction(lock=lock):
                names1 = {row.name for row in self.query_list(db1, tables1.a.select())}
                # Give Side2 a chance to insert (which will be blocked if
                # we've acquired a lock).
                await asyncio.sleep(2.0)
                db1.insert(tables1.a, {"name": "a1"})
                names2 = {row.name for row in self.query_list(db1, tables1.a.select())}
            return names1, names2

        async def side2() -> None:
            """The other side of the concurrent locking test.

            This side just waits a bit and then tries to insert a row into the
            table that the other side is trying to lock.  Hopefully that
            waiting is enough to give the other side a chance to acquire the
            lock and thus make this side block until the lock is released.  If
            this side manages to do the insert before side1 acquires the lock,
            we'll just warn about not succeeding at testing the locking,
            because we can only make that unlikely, not impossible.
            """

            def toRunInThread():
                """SQLite locking isn't asyncio-friendly unless we actually
                run it in another thread.  And SQLite gets very unhappy if
                we try to use a connection from multiple threads, so we have
                to create the new connection here instead of out in the main
                body of the test function.
                """
                db2 = self.getNewConnection(db1, writeable=True)
                with db2.declareStaticTables(create=False) as context:
                    tables2 = context.addTableTuple(STATIC_TABLE_SPECS)
                with db2.transaction():
                    db2.insert(tables2.a, {"name": "a2"})

            await asyncio.sleep(2.0)
            loop = asyncio.get_running_loop()
            # Run the blocking insert in a worker thread so a database-level
            # lock cannot deadlock the event loop itself.
            with ThreadPoolExecutor() as pool:
                await loop.run_in_executor(pool, toRunInThread)

        async def testProblemsWithNoLocking() -> None:
            """Run side1 and side2 with no locking, attempting to demonstrate
            the problem that locking is supposed to solve.  If we get unlucky
            with scheduling, side2 will just happen to insert after side1 is
            done, and we won't have anything definitive.  We just warn in that
            case because we really don't want spurious test failures.
            """
            task1 = asyncio.create_task(side1())
            task2 = asyncio.create_task(side2())
            names1, names2 = await task1
            await task2
            if "a2" in names1:
                warnings.warn(
                    "Unlucky scheduling in no-locking test: concurrent INSERT "
                    "happened before first SELECT."
                )
                self.assertEqual(names1, {"a2"})
                self.assertEqual(names2, {"a1", "a2"})
            elif "a2" not in names2:
                warnings.warn(
                    "Unlucky scheduling in no-locking test: concurrent INSERT "
                    "happened after second SELECT even without locking."
                )
                self.assertEqual(names1, set())
                self.assertEqual(names2, {"a1"})
            else:
                # This is the expected case: both INSERTS happen between the
                # two SELECTS.  If we don't get this almost all of the time we
                # should adjust the sleep amounts.
                self.assertEqual(names1, set())
                self.assertEqual(names2, {"a1", "a2"})

        asyncio.run(testProblemsWithNoLocking())
        # Clean up after first test.
        db1.delete(tables1.a, ["name"], {"name": "a1"}, {"name": "a2"})

        async def testSolutionWithLocking() -> None:
            """Run side1 and side2 with locking, which should make side2 block
            its insert until side1 releases its lock.
            """
            task1 = asyncio.create_task(side1(lock=[tables1.a]))
            task2 = asyncio.create_task(side2())
            names1, names2 = await task1
            await task2
            if "a2" in names1:
                warnings.warn(
                    "Unlucky scheduling in locking test: concurrent INSERT happened before first SELECT."
                )
                self.assertEqual(names1, {"a2"})
                self.assertEqual(names2, {"a1", "a2"})
            else:
                # This is the expected case: the side2 INSERT happens after the
                # last SELECT on side1.  This can also happen due to unlucky
                # scheduling, and we have no way to detect that here, but the
                # similar "no-locking" test has at least some chance of being
                # affected by the same problem and warning about it.
                self.assertEqual(names1, set())
                self.assertEqual(names2, {"a1"})

        asyncio.run(testSolutionWithLocking())
    def testTimespanDatabaseRepresentation(self):
        """Tests for `TimespanDatabaseRepresentation` and the `Database`
        methods that interact with it.

        Covers round-tripping of timespans (including NULL and server-side
        defaults), NOT NULL and exclusion constraints, and the SQL
        relationship operators (overlaps/contains/</>) against both
        Python-literal and in-database operands, checking each against the
        pure-Python `Timespan` implementations of the same relationships.
        """
        # Make some test timespans to play with, with the full suite of
        # topological relationships.
        start = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
        offset = astropy.time.TimeDelta(60, format="sec")
        timestamps = [start + offset * n for n in range(3)]
        aTimespans = [Timespan(begin=None, end=None)]
        aTimespans.extend(Timespan(begin=None, end=t) for t in timestamps)
        aTimespans.extend(Timespan(begin=t, end=None) for t in timestamps)
        aTimespans.extend(Timespan.fromInstant(t) for t in timestamps)
        aTimespans.append(Timespan.makeEmpty())
        aTimespans.extend(Timespan(begin=t1, end=t2) for t1, t2 in itertools.combinations(timestamps, 2))
        # Make another list of timespans that span the full range but don't
        # overlap.  This is a subset of the previous list.
        bTimespans = [Timespan(begin=None, end=timestamps[0])]
        bTimespans.extend(Timespan(begin=t1, end=t2) for t1, t2 in zip(timestamps[:-1], timestamps[1:]))
        bTimespans.append(Timespan(begin=timestamps[-1], end=None))
        # Make a database and create a table with that database's timespan
        # representation.  This one will have no exclusion constraint and
        # a nullable timespan.
        db = self.makeEmptyDatabase(origin=1)
        TimespanReprClass = db.getTimespanRepresentation()
        aSpec = ddl.TableSpec(
            fields=[
                ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
            ],
        )
        for fieldSpec in TimespanReprClass.makeFieldSpecs(nullable=True):
            aSpec.fields.add(fieldSpec)
        with db.declareStaticTables(create=True) as context:
            aTable = context.addTable("a", aSpec)
        # Show full diffs on assertion failures; the row lists are long.
        self.maxDiff = None

        def convertRowForInsert(row: dict) -> dict:
            """Convert a row containing a Timespan instance into one suitable
            for insertion into the database.
            """
            result = row.copy()
            ts = result.pop(TimespanReprClass.NAME)
            return TimespanReprClass.update(ts, result=result)

        def convertRowFromSelect(row: dict) -> dict:
            """Convert a row from the database into one containing a Timespan.

            Parameters
            ----------
            row : `dict`
                Original row.

            Returns
            -------
            row : `dict`
                The updated row.
            """
            result = row.copy()
            timespan = TimespanReprClass.extract(result)
            for name in TimespanReprClass.getFieldNames():
                del result[name]
            result[TimespanReprClass.NAME] = timespan
            return result

        # Insert rows into table A, in chunks just to make things interesting.
        # Include one with a NULL timespan.
        aRows = [{"id": n, TimespanReprClass.NAME: t} for n, t in enumerate(aTimespans)]
        aRows.append({"id": len(aRows), TimespanReprClass.NAME: None})
        db.insert(aTable, convertRowForInsert(aRows[0]))
        db.insert(aTable, *[convertRowForInsert(r) for r in aRows[1:3]])
        db.insert(aTable, *[convertRowForInsert(r) for r in aRows[3:]])
        # Add another one with a NULL timespan, but this time by invoking
        # the server-side default.
        aRows.append({"id": len(aRows)})
        db.insert(aTable, aRows[-1])
        aRows[-1][TimespanReprClass.NAME] = None
        # Test basic round-trip through database.
        self.assertEqual(
            aRows,
            [
                convertRowFromSelect(row._asdict())
                for row in self.query_list(db, aTable.select().order_by(aTable.columns.id))
            ],
        )
        # Create another table B with a not-null timespan and (if the database
        # supports it), an exclusion constraint.  Use ensureTableExists this
        # time to check that mode of table creation vs. timespans.
        bSpec = ddl.TableSpec(
            fields=[
                ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
                ddl.FieldSpec(name="key", dtype=sqlalchemy.Integer, nullable=False),
            ],
        )
        for fieldSpec in TimespanReprClass.makeFieldSpecs(nullable=False):
            bSpec.fields.add(fieldSpec)
        if TimespanReprClass.hasExclusionConstraint():
            bSpec.exclusion.add(("key", TimespanReprClass))
        bTable = db.ensureTableExists("b", bSpec)
        # Insert rows into table B, again in chunks.  Each Timespan appears
        # twice, but with different values for the 'key' field (which should
        # still be okay for any exclusion constraint we may have defined).
        bRows = [{"id": n, "key": 1, TimespanReprClass.NAME: t} for n, t in enumerate(bTimespans)]
        # NOTE: rebinds 'offset' (previously a TimeDelta) as a row count.
        offset = len(bRows)
        bRows.extend(
            {"id": n + offset, "key": 2, TimespanReprClass.NAME: t} for n, t in enumerate(bTimespans)
        )
        db.insert(bTable, *[convertRowForInsert(r) for r in bRows[:2]])
        db.insert(bTable, convertRowForInsert(bRows[2]))
        db.insert(bTable, *[convertRowForInsert(r) for r in bRows[3:]])
        # Insert a row with no timespan into table B.  This should invoke the
        # server-side default, which is a timespan over (-∞, ∞).  We set
        # key=3 to avoid upsetting an exclusion constraint that might exist.
        bRows.append({"id": len(bRows), "key": 3})
        db.insert(bTable, bRows[-1])
        bRows[-1][TimespanReprClass.NAME] = Timespan(None, None)
        # Test basic round-trip through database.
        self.assertEqual(
            bRows,
            [
                convertRowFromSelect(row._asdict())
                for row in self.query_list(db, bTable.select().order_by(bTable.columns.id))
            ],
        )
        # Test that we can't insert timespan=None into this table.
        with self.assertRaises(sqlalchemy.exc.IntegrityError):
            db.insert(bTable, convertRowForInsert({"id": len(bRows), "key": 4, TimespanReprClass.NAME: None}))
        # IFF this database supports exclusion constraints, test that they
        # also prevent inserts.
        if TimespanReprClass.hasExclusionConstraint():
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                db.insert(
                    bTable,
                    convertRowForInsert(
                        {"id": len(bRows), "key": 1, TimespanReprClass.NAME: Timespan(None, timestamps[1])}
                    ),
                )
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                db.insert(
                    bTable,
                    convertRowForInsert(
                        {
                            "id": len(bRows),
                            "key": 1,
                            TimespanReprClass.NAME: Timespan(timestamps[0], timestamps[2]),
                        }
                    ),
                )
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                db.insert(
                    bTable,
                    convertRowForInsert(
                        {"id": len(bRows), "key": 1, TimespanReprClass.NAME: Timespan(timestamps[2], None)}
                    ),
                )
        # Test NULL checks in SELECT queries, on both tables.
        aRepr = TimespanReprClass.from_columns(aTable.columns)
        self.assertEqual(
            [row[TimespanReprClass.NAME] is None for row in aRows],
            [
                row.f
                for row in self.query_list(
                    db, sqlalchemy.sql.select(aRepr.isNull().label("f")).order_by(aTable.columns.id)
                )
            ],
        )
        bRepr = TimespanReprClass.from_columns(bTable.columns)
        self.assertEqual(
            [False for row in bRows],
            [
                row.f
                for row in self.query_list(
                    db, sqlalchemy.sql.select(bRepr.isNull().label("f")).order_by(bTable.columns.id)
                )
            ],
        )
        # Test relationships expressions that relate in-database timespans to
        # Python-literal timespans, all from the more complete 'a' set; check
        # that this is consistent with Python-only relationship tests.
        for rhsRow in aRows:
            if rhsRow[TimespanReprClass.NAME] is None:
                continue
            with self.subTest(rhsRow=rhsRow):
                expected = {}
                for lhsRow in aRows:
                    if lhsRow[TimespanReprClass.NAME] is None:
                        # NULL timespans compare as NULL (None) in SQL.
                        expected[lhsRow["id"]] = (None, None, None, None)
                    else:
                        expected[lhsRow["id"]] = (
                            lhsRow[TimespanReprClass.NAME].overlaps(rhsRow[TimespanReprClass.NAME]),
                            lhsRow[TimespanReprClass.NAME].contains(rhsRow[TimespanReprClass.NAME]),
                            lhsRow[TimespanReprClass.NAME] < rhsRow[TimespanReprClass.NAME],
                            lhsRow[TimespanReprClass.NAME] > rhsRow[TimespanReprClass.NAME],
                        )
                rhsRepr = TimespanReprClass.fromLiteral(rhsRow[TimespanReprClass.NAME])
                sql = sqlalchemy.sql.select(
                    aTable.columns.id.label("lhs"),
                    aRepr.overlaps(rhsRepr).label("overlaps"),
                    aRepr.contains(rhsRepr).label("contains"),
                    (aRepr < rhsRepr).label("less_than"),
                    (aRepr > rhsRepr).label("greater_than"),
                ).select_from(aTable)
                queried = {
                    row.lhs: (row.overlaps, row.contains, row.less_than, row.greater_than)
                    for row in self.query_list(db, sql)
                }
                self.assertEqual(expected, queried)
        # Test relationship expressions that relate in-database timespans to
        # each other, all from the more complete 'a' set; check that this is
        # consistent with Python-only relationship tests.
        expected = {}
        for lhs, rhs in itertools.product(aRows, aRows):
            lhsT = lhs[TimespanReprClass.NAME]
            rhsT = rhs[TimespanReprClass.NAME]
            if lhsT is not None and rhsT is not None:
                expected[lhs["id"], rhs["id"]] = (
                    lhsT.overlaps(rhsT),
                    lhsT.contains(rhsT),
                    lhsT < rhsT,
                    lhsT > rhsT,
                )
            else:
                expected[lhs["id"], rhs["id"]] = (None, None, None, None)
        # Self-join table 'a' against itself with a cross-join (always-true
        # ON clause) so every (lhs, rhs) pair is compared.
        lhsSubquery = aTable.alias("lhs")
        rhsSubquery = aTable.alias("rhs")
        lhsRepr = TimespanReprClass.from_columns(lhsSubquery.columns)
        rhsRepr = TimespanReprClass.from_columns(rhsSubquery.columns)
        sql = sqlalchemy.sql.select(
            lhsSubquery.columns.id.label("lhs"),
            rhsSubquery.columns.id.label("rhs"),
            lhsRepr.overlaps(rhsRepr).label("overlaps"),
            lhsRepr.contains(rhsRepr).label("contains"),
            (lhsRepr < rhsRepr).label("less_than"),
            (lhsRepr > rhsRepr).label("greater_than"),
        ).select_from(lhsSubquery.join(rhsSubquery, onclause=sqlalchemy.sql.literal(True)))
        queried = {
            (row.lhs, row.rhs): (row.overlaps, row.contains, row.less_than, row.greater_than)
            for row in self.query_list(db, sql)
        }
        self.assertEqual(expected, queried)
        # Test relationship expressions between in-database timespans and
        # Python-literal instantaneous times.
        for t in timestamps:
            with self.subTest(t=t):
                expected = {}
                for lhsRow in aRows:
                    if lhsRow[TimespanReprClass.NAME] is None:
                        expected[lhsRow["id"]] = (None, None, None, None)
                    else:
                        expected[lhsRow["id"]] = (
                            lhsRow[TimespanReprClass.NAME].contains(t),
                            lhsRow[TimespanReprClass.NAME].overlaps(t),
                            lhsRow[TimespanReprClass.NAME] < t,
                            lhsRow[TimespanReprClass.NAME] > t,
                        )
                rhs = sqlalchemy.sql.literal(t, type_=ddl.AstropyTimeNsecTai)
                sql = sqlalchemy.sql.select(
                    aTable.columns.id.label("lhs"),
                    aRepr.contains(rhs).label("contains"),
                    aRepr.overlaps(rhs).label("overlaps_point"),
                    (aRepr < rhs).label("less_than"),
                    (aRepr > rhs).label("greater_than"),
                ).select_from(aTable)
                queried = {
                    row.lhs: (row.contains, row.overlaps_point, row.less_than, row.greater_than)
                    for row in self.query_list(db, sql)
                }
                self.assertEqual(expected, queried)
1113 def testConstantRows(self):
1114 """Test Database.constant_rows."""
1115 new_db = self.makeEmptyDatabase()
1116 with new_db.declareStaticTables(create=True) as context:
1117 static = context.addTableTuple(STATIC_TABLE_SPECS)
1118 b_ids = new_db.insert(
1119 static.b,
1120 {"name": "b1", "value": 11},
1121 {"name": "b2", "value": 12},
1122 {"name": "b3", "value": 13},
1123 returnIds=True,
1124 )
1125 values_spec = ddl.TableSpec(
1126 [
1127 ddl.FieldSpec(name="b", dtype=sqlalchemy.BigInteger),
1128 ddl.FieldSpec(name="s", dtype=sqlalchemy.String(8)),
1129 ddl.FieldSpec(name="r", dtype=ddl.Base64Region()),
1130 ],
1131 )
1132 values_data = [
1133 {"b": b_ids[0], "s": "b1", "r": None},
1134 {"b": b_ids[2], "s": "b3", "r": Circle.empty()},
1135 ]
1136 values = new_db.constant_rows(values_spec.fields, *values_data)
1137 select_values_alone = sqlalchemy.sql.select(
1138 values.columns["b"], values.columns["s"], values.columns["r"]
1139 )
1140 self.assertCountEqual(
1141 [dict(row) for row in self.query_list(new_db, select_values_alone)],
1142 values_data,
1143 )
1144 select_values_joined = sqlalchemy.sql.select(
1145 values.columns["s"].label("name"), static.b.columns["value"].label("value")
1146 ).select_from(values.join(static.b, onclause=static.b.columns["id"] == values.columns["b"]))
1147 self.assertCountEqual(
1148 [dict(row) for row in self.query_list(new_db, select_values_joined)],
1149 [{"value": 11, "name": "b1"}, {"value": 13, "name": "b3"}],
1150 )