# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

from ... import ddl

__all__ = ["DatabaseTests"]

import asyncio
import itertools
import warnings
from abc import ABC, abstractmethod
from collections import namedtuple
from collections.abc import Iterable, Iterator
from concurrent.futures import ThreadPoolExecutor
from contextlib import AbstractContextManager, contextmanager
from typing import Any

import astropy.time
import sqlalchemy
from lsst.sphgeom import Circle, ConvexPolygon, Mq3cPixelization, UnionRegion, UnitVector3d

from ..._timespan import Timespan
from ..interfaces import Database, DatabaseConflictError, ReadOnlyDatabaseError, SchemaAlreadyDefinedError

StaticTablesTuple = namedtuple("StaticTablesTuple", ["a", "b", "c"])

STATIC_TABLE_SPECS = StaticTablesTuple(
    a=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("name", dtype=sqlalchemy.String, length=16, primaryKey=True),
            ddl.FieldSpec("region", dtype=ddl.Base64Region, nbytes=128),
        ]
    ),
    b=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("id", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True),
            ddl.FieldSpec("name", dtype=sqlalchemy.String, length=16, nullable=False),
            ddl.FieldSpec("value", dtype=sqlalchemy.SmallInteger, nullable=True),
        ],
        unique=[("name",)],
    ),
    c=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("id", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True),
            ddl.FieldSpec("b_id", dtype=sqlalchemy.BigInteger, nullable=True),
        ],
        foreignKeys=[
            ddl.ForeignKeySpec("b", source=("b_id",), target=("id",), onDelete="SET NULL"),
        ],
    ),
)

DYNAMIC_TABLE_SPEC = ddl.TableSpec(
    fields=[
        ddl.FieldSpec("c_id", dtype=sqlalchemy.BigInteger, primaryKey=True),
        ddl.FieldSpec("a_name", dtype=sqlalchemy.String, length=16, nullable=False),
    ],
    foreignKeys=[
        ddl.ForeignKeySpec("c", source=("c_id",), target=("id",), onDelete="CASCADE"),
        ddl.ForeignKeySpec("a", source=("a_name",), target=("name",), onDelete="CASCADE"),
    ],
)

TEMPORARY_TABLE_SPEC = ddl.TableSpec(
    fields=[
        ddl.FieldSpec("a_name", dtype=sqlalchemy.String, length=16, primaryKey=True),
        ddl.FieldSpec("b_id", dtype=sqlalchemy.BigInteger, primaryKey=True),
    ],
)
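
# Foreign-key relationships implied by the specs above (arrows point from
# foreign key to target; "d" is the table created from DYNAMIC_TABLE_SPEC):
#
#     c.b_id   -> b.id    (onDelete="SET NULL")
#     d.c_id   -> c.id    (onDelete="CASCADE")
#     d.a_name -> a.name  (onDelete="CASCADE")
#
# The temporary-table spec has no foreign keys; its columns merely mirror
# a.name and b.id.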


@contextmanager
def _patch_getExistingTable(db: Database) -> Iterator[Database]:
    """Patch the getExistingTable method in a database instance to test
    concurrent creation of tables.  This patch obviously depends on knowing
    the internals of the ``ensureTableExists()`` implementation.
    """
    original_method = db.getExistingTable

    def _getExistingTable(name: str, spec: ddl.TableSpec) -> sqlalchemy.schema.Table | None:
        # Return None on the first call, but forward to the original method
        # after that.
        db.getExistingTable = original_method
        return None

    db.getExistingTable = _getExistingTable
    yield db
    db.getExistingTable = original_method
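
# Typical use (see testDynamicTablesConcurrency below): temporarily make the
# first getExistingTable() call report that the table is missing, so that
# ensureTableExists() follows its "create" path even though another
# connection has already created the table.
#
#     with _patch_getExistingTable(db):
#         table = db.ensureTableExists("d", DYNAMIC_TABLE_SPEC)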


class DatabaseTests(ABC):
    """Generic tests for the `Database` interface that can be subclassed to
    generate tests for concrete implementations.
    """

    @abstractmethod
    def makeEmptyDatabase(self, origin: int = 0) -> Database:
        """Return an empty `Database` with the given origin, or an
        automatically-generated one if ``origin`` is `None`.

        Parameters
        ----------
        origin : `int` or `None`
            Origin to use for the database.

        Returns
        -------
        db : `Database`
            Empty database with given origin or auto-generated origin.
        """
        raise NotImplementedError()
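
    # A minimal sketch of how a concrete test case might implement the
    # factory above (hypothetical subclass name and constructor; the real
    # concrete implementations live alongside each Database backend):
    #
    #     class SqliteDatabaseTests(DatabaseTests, unittest.TestCase):
    #         def makeEmptyDatabase(self, origin: int = 0) -> Database:
    #             # An in-memory SQLite database keeps each test isolated.
    #             return SqliteDatabase.fromUri("sqlite://", origin=origin)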

    @abstractmethod
    def asReadOnly(self, database: Database) -> AbstractContextManager[Database]:
        """Return a context manager for a read-only connection into the given
        database.

        Parameters
        ----------
        database : `Database`
            The database to use.

        Yields
        ------
        `Database`
            The new database connection.

        Notes
        -----
        The original database should be considered unusable within the context
        but safe to use again afterwards (this allows the context manager to
        block write access by temporarily changing user permissions to really
        guarantee that write operations are not performed).
        """
        raise NotImplementedError()
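
    # One plausible implementation sketch, assuming getNewConnection() can
    # produce a read-only connection (stronger implementations may instead
    # revoke write permissions for the duration of the context, as the notes
    # above suggest):
    #
    #     @contextmanager
    #     def asReadOnly(self, database: Database) -> Iterator[Database]:
    #         yield self.getNewConnection(database, writeable=False)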

    @abstractmethod
    def getNewConnection(self, database: Database, *, writeable: bool) -> Database:
        """Return a new `Database` instance that points to the same underlying
        storage as the given one.

        Parameters
        ----------
        database : `Database`
            The current database.
        writeable : `bool`
            Whether the connection should be writeable or not.

        Returns
        -------
        db : `Database`
            The new database connection.
        """
        raise NotImplementedError()
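
    # The concurrency tests below (testDynamicTablesConcurrency,
    # testTransactionLocking) use this factory to simulate a second client
    # talking to the same database.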

    def query_list(
        self, database: Database, executable: sqlalchemy.sql.expression.SelectBase
    ) -> list[sqlalchemy.engine.Row]:
        """Run a SELECT or other read-only query against the database and
        return the results as a list.

        Parameters
        ----------
        database : `Database`
            The database to use.
        executable : `sqlalchemy.sql.expression.SelectBase`
            Expression to execute.

        Returns
        -------
        results : `list` of `sqlalchemy.engine.Row`
            The results.

        Notes
        -----
        This is a thin wrapper around ``database.query()`` that avoids
        context-manager boilerplate that is usefully verbose in production
        code but just noise in tests.
        """
        with database.transaction(), database.query(executable) as result:
            return result.fetchall()
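
    # Example, mirroring usage throughout these tests:
    #
    #     rows = [r._asdict() for r in self.query_list(db, table.select())]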

    def query_scalar(self, database: Database, executable: sqlalchemy.sql.expression.SelectBase) -> Any:
        """Run a SELECT query that yields a single column and row against the
        database and return its value.

        Parameters
        ----------
        database : `Database`
            The database to use.
        executable : `sqlalchemy.sql.expression.SelectBase`
            Expression to execute.

        Returns
        -------
        result : `~typing.Any`
            The result.

        Notes
        -----
        This is a thin wrapper around ``database.query()`` that avoids
        context-manager boilerplate that is usefully verbose in production
        code but just noise in tests.
        """
        with database.query(executable) as result:
            return result.scalar()
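
    # Example: evaluating a scalar aggregate, as several tests below do:
    #
    #     count = sqlalchemy.sql.select(sqlalchemy.sql.func.count())
    #     n = self.query_scalar(db, count.select_from(table))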

    def checkTable(self, spec: ddl.TableSpec, table: sqlalchemy.schema.Table):
        self.assertCountEqual(spec.fields.names, table.columns.keys())
        # Checking more than this currently seems fragile, as it might
        # restrict what Database implementations do; we don't care if the
        # spec is actually preserved in terms of types and constraints as
        # long as we can use the returned table as if it were.

    def checkStaticSchema(self, tables: StaticTablesTuple):
        self.checkTable(STATIC_TABLE_SPECS.a, tables.a)
        self.checkTable(STATIC_TABLE_SPECS.b, tables.b)
        self.checkTable(STATIC_TABLE_SPECS.c, tables.c)

    def testDeclareStaticTables(self):
        """Tests for `Database.declareStaticSchema` and the methods it
        delegates to.
        """
        # Create the static schema in a new, empty database.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        self.checkStaticSchema(tables)
        # Check that we can load that schema even from a read-only
        # connection.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                tables = context.addTableTuple(STATIC_TABLE_SPECS)
            self.checkStaticSchema(tables)

    def testDeclareStaticTablesTwice(self):
        """Tests for `Database.declareStaticSchema` being called twice."""
        # Create the static schema in a new, empty database.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        self.checkStaticSchema(tables)
        # The second time it should raise.
        with self.assertRaises(SchemaAlreadyDefinedError):
            with newDatabase.declareStaticTables(create=True) as context:
                tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Check the schema; it should still contain all tables, and maybe
        # some extra.
        with newDatabase.declareStaticTables(create=False) as context:
            self.assertLessEqual(frozenset(STATIC_TABLE_SPECS._fields), context._tableNames)

    def testRepr(self):
        """Test that repr does not return a generic thing."""
        newDatabase = self.makeEmptyDatabase()
        rep = repr(newDatabase)
        # Check that stringification works and gives us something different.
        self.assertNotEqual(rep, str(newDatabase))
        self.assertNotIn("object at 0x", rep, "Check default repr was not used")
        self.assertIn("://", rep)

    def testDynamicTables(self):
        """Tests for `Database.ensureTableExists` and
        `Database.getExistingTable`.
        """
        # Need to start with the static schema.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            context.addTableTuple(STATIC_TABLE_SPECS)
        # Try to ensure the dynamic table exists in a read-only version of
        # that database, which should fail because we can't create it.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                context.addTableTuple(STATIC_TABLE_SPECS)
            with self.assertRaises(ReadOnlyDatabaseError):
                existingReadOnlyDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        # Just getting the dynamic table before it exists should return None.
        self.assertIsNone(newDatabase.getExistingTable("d", DYNAMIC_TABLE_SPEC))
        # Ensure the new table exists back in the original database, which
        # should create it.
        table = newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        self.checkTable(DYNAMIC_TABLE_SPEC, table)
        # Ensuring that it exists should just return the exact same table
        # instance again.
        self.assertIs(newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC), table)
        # Try again from the read-only database.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                context.addTableTuple(STATIC_TABLE_SPECS)
            # Just getting the dynamic table should now work...
            self.assertIsNotNone(existingReadOnlyDatabase.getExistingTable("d", DYNAMIC_TABLE_SPEC))
            # ...as should ensuring that it exists, since it now does.
            existingReadOnlyDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
            self.checkTable(DYNAMIC_TABLE_SPEC, table)
        # Trying to get the table with a different specification (at least
        # in terms of what columns are present) should raise.
        with self.assertRaises(DatabaseConflictError):
            newDatabase.ensureTableExists(
                "d",
                ddl.TableSpec(
                    fields=[ddl.FieldSpec("name", dtype=sqlalchemy.String, length=4, primaryKey=True)]
                ),
            )
        # Calling ensureTableExists inside a transaction block is an error,
        # even if it would do nothing.
        with newDatabase.transaction():
            with self.assertRaises(AssertionError):
                newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)

    def testDynamicTablesConcurrency(self):
        """Tests for concurrent use of `Database.ensureTableExists`."""
        # We cannot really run things concurrently in a deterministic way;
        # here we just simulate a situation in which the table is created by
        # another process between the call to getExistingTable() and the
        # actual table creation.
        db1 = self.makeEmptyDatabase()
        with db1.declareStaticTables(create=True) as context:
            context.addTableTuple(STATIC_TABLE_SPECS)
        self.assertIsNone(db1.getExistingTable("d", DYNAMIC_TABLE_SPEC))

        # Make a dynamic table using a separate connection.
        db2 = self.getNewConnection(db1, writeable=True)
        with db2.declareStaticTables(create=False) as context:
            context.addTableTuple(STATIC_TABLE_SPECS)
        table = db2.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        self.checkTable(DYNAMIC_TABLE_SPEC, table)

        # Call it again but trick it into thinking that the table is not
        # there.  This test depends on knowing the implementation of
        # ensureTableExists(), which initially calls getExistingTable() to
        # check whether the table already exists; the patch intercepts that
        # call and returns None.
        with _patch_getExistingTable(db1):
            table = db1.ensureTableExists("d", DYNAMIC_TABLE_SPEC)

    def testTemporaryTables(self):
        """Tests for `Database.temporary_table`, and `Database.insert` with
        the ``select`` argument.
        """
        # Need to start with the static schema; also insert some test data.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            static = context.addTableTuple(STATIC_TABLE_SPECS)
        newDatabase.insert(static.a, {"name": "a1", "region": None}, {"name": "a2", "region": None})
        bIds = newDatabase.insert(
            static.b,
            {"name": "b1", "value": 11},
            {"name": "b2", "value": 12},
            {"name": "b3", "value": 13},
            returnIds=True,
        )
        # Create the table.
        with newDatabase.session():
            with newDatabase.temporary_table(TEMPORARY_TABLE_SPEC, "e1") as table1:
                self.checkTable(TEMPORARY_TABLE_SPEC, table1)
                # Insert via an INSERT INTO ... SELECT query.
                newDatabase.insert(
                    table1,
                    select=sqlalchemy.sql.select(
                        static.a.columns.name.label("a_name"), static.b.columns.id.label("b_id")
                    )
                    .select_from(static.a.join(static.b, onclause=sqlalchemy.sql.literal(True)))
                    .where(
                        sqlalchemy.sql.and_(
                            static.a.columns.name == "a1",
                            static.b.columns.value <= 12,
                        )
                    ),
                )
                # Check that the inserted rows are present.
                self.assertCountEqual(
                    [{"a_name": "a1", "b_id": bId} for bId in bIds[:2]],
                    [row._asdict() for row in self.query_list(newDatabase, table1.select())],
                )
                # Create another one via a read-only connection to the
                # database.  We _do_ allow temporary table modifications in
                # read-only databases.
                with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
                    with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                        context.addTableTuple(STATIC_TABLE_SPECS)
                    with existingReadOnlyDatabase.temporary_table(TEMPORARY_TABLE_SPEC) as table2:
                        self.checkTable(TEMPORARY_TABLE_SPEC, table2)
                        # Those tables should not be the same, despite having
                        # the same ddl.
                        self.assertIsNot(table1, table2)
                        # Do a slightly different insert into this table, to
                        # check that it works in a read-only database.  This
                        # time we pass column names as a kwarg to insert
                        # instead of by labeling the columns in the select.
                        existingReadOnlyDatabase.insert(
                            table2,
                            select=sqlalchemy.sql.select(static.a.columns.name, static.b.columns.id)
                            .select_from(static.a.join(static.b, onclause=sqlalchemy.sql.literal(True)))
                            .where(
                                sqlalchemy.sql.and_(
                                    static.a.columns.name == "a2",
                                    static.b.columns.value >= 12,
                                )
                            ),
                            names=["a_name", "b_id"],
                        )
                        # Check that the inserted rows are present.
                        self.assertCountEqual(
                            [{"a_name": "a2", "b_id": bId} for bId in bIds[1:]],
                            [
                                row._asdict()
                                for row in self.query_list(existingReadOnlyDatabase, table2.select())
                            ],
                        )
                # Exiting the context managers will drop the temporary tables
                # from the read-only DB.  It's unspecified whether attempting
                # to use it after this point is an error or just never
                # returns any results, so we can't test what it does, only
                # that it's not an error.

    def testSchemaSeparation(self):
        """Test that creating two different `Database` instances allows us
        to create different tables with the same name in each.
        """
        db1 = self.makeEmptyDatabase(origin=1)
        with db1.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        self.checkStaticSchema(tables)

        db2 = self.makeEmptyDatabase(origin=2)
        # Make the DDL here intentionally different so we'll definitely
        # notice if db1 and db2 are pointing at the same schema.
        spec = ddl.TableSpec(fields=[ddl.FieldSpec("id", dtype=sqlalchemy.Integer, primaryKey=True)])
        with db2.declareStaticTables(create=True) as context:
            table = context.addTable("a", spec)
        self.checkTable(spec, table)

    def testInsertQueryDelete(self):
        """Test the `Database.insert`, `Database.query`, and `Database.delete`
        methods, as well as the `Base64Region` type and the ``onDelete``
        argument to `ddl.ForeignKeySpec`.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert a single, non-autoincrement row that contains a region and
        # query to get it back.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row = {"name": "a1", "region": region}
        db.insert(tables.a, row)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row])
        # Insert multiple autoincrement rows but do not try to get the IDs
        # back immediately.
        db.insert(tables.b, {"name": "b1", "value": 10}, {"name": "b2", "value": 20})
        results = [r._asdict() for r in self.query_list(db, tables.b.select().order_by("id"))]
        self.assertEqual(len(results), 2)
        for row in results:
            self.assertIn(row["name"], ("b1", "b2"))
            self.assertIsInstance(row["id"], int)
        self.assertGreater(results[1]["id"], results[0]["id"])
        # Insert multiple autoincrement rows and get the IDs back from insert.
        rows = [{"name": "b3", "value": 30}, {"name": "b4", "value": 40}]
        ids = db.insert(tables.b, *rows, returnIds=True)
        results = [
            r._asdict()
            for r in self.query_list(db, tables.b.select().where(tables.b.columns.id > results[1]["id"]))
        ]
        expected = [dict(row, id=id) for row, id in zip(rows, ids, strict=True)]
        self.assertCountEqual(results, expected)
        self.assertTrue(all(result["id"] is not None for result in results))
        # Insert multiple rows into a table with an autoincrement primary
        # key, then use the returned IDs to insert into a dynamic table.
        rows = [{"b_id": results[0]["id"]}, {"b_id": None}]
        ids = db.insert(tables.c, *rows, returnIds=True)
        results = [r._asdict() for r in self.query_list(db, tables.c.select())]
        expected = [dict(row, id=id) for row, id in zip(rows, ids, strict=True)]
        self.assertCountEqual(results, expected)
        self.assertTrue(all(result["id"] is not None for result in results))
        # Add the dynamic table.
        d = db.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        # Insert into it.
        rows = [{"c_id": id, "a_name": "a1"} for id in ids]
        db.insert(d, *rows)
        results = [r._asdict() for r in self.query_list(db, d.select())]
        self.assertCountEqual(rows, results)
        # Insert multiple rows into a table with an autoincrement primary
        # key, but pass in a value for the autoincrement key.
        rows2 = [
            {"id": 700, "b_id": None},
            {"id": 701, "b_id": None},
        ]
        db.insert(tables.c, *rows2)
        results = [r._asdict() for r in self.query_list(db, tables.c.select())]
        self.assertCountEqual(results, expected + rows2)
        self.assertTrue(all(result["id"] is not None for result in results))

        # Define a 'SELECT COUNT(*)' query for later use.
        count = sqlalchemy.sql.select(sqlalchemy.sql.func.count())
        # Get the values we inserted into table b.
        bValues = [r._asdict() for r in self.query_list(db, tables.b.select())]
        # Remove two rows from table b by ID.
        n = db.delete(tables.b, ["id"], {"id": bValues[0]["id"]}, {"id": bValues[1]["id"]})
        self.assertEqual(n, 2)
        # Remove the other two rows from table b by name.
        n = db.delete(tables.b, ["name"], {"name": bValues[2]["name"]}, {"name": bValues[3]["name"]})
        self.assertEqual(n, 2)
        # There should now be no rows in table b.
        self.assertEqual(self.query_scalar(db, count.select_from(tables.b)), 0)
        # All b_id values in table c should now be NULL, because there's an
        # onDelete='SET NULL' foreign key.
        self.assertEqual(
            self.query_scalar(
                db,
                count.select_from(tables.c).where(tables.c.columns.b_id != None),  # noqa:E711
            ),
            0,
        )
        # Remove all rows in table a (there's only one); this should remove
        # all rows in d due to onDelete='CASCADE'.
        n = db.delete(tables.a, [])
        self.assertEqual(n, 1)
        self.assertEqual(self.query_scalar(db, count.select_from(tables.a)), 0)
        self.assertEqual(self.query_scalar(db, count.select_from(d)), 0)

    def testDeleteWhere(self):
        """Tests for `Database.deleteWhere`."""
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        db.insert(tables.b, *[{"id": i, "name": f"b{i}"} for i in range(10)])
        count = sqlalchemy.sql.select(sqlalchemy.sql.func.count())

        n = db.deleteWhere(tables.b, tables.b.columns.id.in_([0, 1, 2]))
        self.assertEqual(n, 3)
        self.assertEqual(self.query_scalar(db, count.select_from(tables.b)), 7)

        n = db.deleteWhere(
            tables.b,
            tables.b.columns.id.in_(
                sqlalchemy.sql.select(tables.b.columns.id).where(tables.b.columns.id > 5)
            ),
        )
        self.assertEqual(n, 4)
        self.assertEqual(self.query_scalar(db, count.select_from(tables.b)), 3)

        n = db.deleteWhere(tables.b, tables.b.columns.name == "b5")
        self.assertEqual(n, 1)
        self.assertEqual(self.query_scalar(db, count.select_from(tables.b)), 2)

        n = db.deleteWhere(tables.b, sqlalchemy.sql.literal(True))
        self.assertEqual(n, 2)
        self.assertEqual(self.query_scalar(db, count.select_from(tables.b)), 0)

    def testUpdate(self):
        """Tests for `Database.update`."""
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert two rows into table a, both without regions.
        db.insert(tables.a, {"name": "a1"}, {"name": "a2"})
        # Update one of the rows with a region.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        n = db.update(tables.a, {"name": "k"}, {"k": "a2", "region": region})
        self.assertEqual(n, 1)
        sql = sqlalchemy.sql.select(tables.a.columns.name, tables.a.columns.region).select_from(tables.a)
        self.assertCountEqual(
            [r._asdict() for r in self.query_list(db, sql)],
            [{"name": "a1", "region": None}, {"name": "a2", "region": region}],
        )

    def testSync(self):
        """Tests for `Database.sync`."""
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert a row with sync, because it doesn't exist yet.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10}, returning=["id"])
        self.assertTrue(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )
        # Repeat that operation, which should do nothing but return the
        # requested values.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10}, returning=["id"])
        self.assertFalse(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )
        # Repeat the operation without the 'extra' arg, which should also
        # just return the existing row.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, returning=["id"])
        self.assertFalse(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )
        # Repeat the operation with a different value in 'extra'.  That
        # still shouldn't be an error, because 'extra' is only used if we
        # really do insert.  Also drop the 'returning' argument.
        _, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 20})
        self.assertFalse(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )
        # Repeat the operation with the correct value in 'compared' instead
        # of 'extra'.
        _, inserted = db.sync(tables.b, keys={"name": "b1"}, compared={"value": 10})
        self.assertFalse(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )
        # Repeat the operation with an incorrect value in 'compared'; this
        # should raise.
        with self.assertRaises(DatabaseConflictError):
            db.sync(tables.b, keys={"name": "b1"}, compared={"value": 20})
        # Try to sync in a read-only database.  This should work if and
        # only if the matching row already exists.
        with self.asReadOnly(db) as rodb:
            with rodb.declareStaticTables(create=False) as context:
                tables = context.addTableTuple(STATIC_TABLE_SPECS)
            _, inserted = rodb.sync(tables.b, keys={"name": "b1"})
            self.assertFalse(inserted)
            self.assertEqual(
                [{"id": values["id"], "name": "b1", "value": 10}],
                [r._asdict() for r in self.query_list(rodb, tables.b.select())],
            )
            with self.assertRaises(ReadOnlyDatabaseError):
                rodb.sync(tables.b, keys={"name": "b2"}, extra={"value": 20})
        # Repeat the operation with a different value in 'compared' and ask
        # to update.
        _, updated = db.sync(tables.b, keys={"name": "b1"}, compared={"value": 20}, update=True)
        self.assertEqual(updated, {"value": 10})
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 20}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )

    def testReplace(self):
        """Tests for `Database.replace`."""
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Use 'replace' to insert a single row that contains a region and
        # query to get it back.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row1 = {"name": "a1", "region": region}
        db.replace(tables.a, row1)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1])
        # Insert another row without a region.
        row2 = {"name": "a2", "region": None}
        db.replace(tables.a, row2)
        self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2])
        # Use replace to re-insert both of those rows again, which should do
        # nothing.
        db.replace(tables.a, row1, row2)
        self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2])
        # Replace row1 with a row with no region, while reinserting row2.
        row1a = {"name": "a1", "region": None}
        db.replace(tables.a, row1a, row2)
        self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1a, row2])
        # Replace both rows, returning row1 to its original state, while
        # adding a new one.  Pass them in in a different order.
        row2a = {"name": "a2", "region": region}
        row3 = {"name": "a3", "region": None}
        db.replace(tables.a, row3, row2a, row1)
        self.assertCountEqual(
            [r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2a, row3]
        )

    def testEnsure(self):
        """Tests for `Database.ensure`."""
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Use 'ensure' to insert a single row that contains a region and
        # query to get it back.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row1 = {"name": "a1", "region": region}
        self.assertEqual(db.ensure(tables.a, row1), 1)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1])
        # Insert another row without a region.
        row2 = {"name": "a2", "region": None}
        self.assertEqual(db.ensure(tables.a, row2), 1)
        self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2])
        # Use ensure to re-insert both of those rows again, which should do
        # nothing.
        self.assertEqual(db.ensure(tables.a, row1, row2), 0)
        self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2])
        # Attempt to insert row1's key with no region, while reinserting
        # row2.  This should also do nothing.
        row1a = {"name": "a1", "region": None}
        self.assertEqual(db.ensure(tables.a, row1a, row2), 0)
        self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2])
        # Attempt to insert new rows for both existing keys, this time also
        # adding a new row.  Pass them in in a different order.  Only the
        # new row should be added.
        row2a = {"name": "a2", "region": region}
        row3 = {"name": "a3", "region": None}
        self.assertEqual(db.ensure(tables.a, row3, row2a, row1a), 1)
        self.assertCountEqual(
            [r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2, row3]
        )
        # Add some data to a table with both a primary key and a different
        # unique constraint.
        row_b = {"id": 5, "name": "five", "value": 50}
        db.insert(tables.b, row_b)
        # Attempt ensure with primary_key_only=False and a conflict for the
        # non-PK constraint.  This should do nothing.
        db.ensure(tables.b, {"id": 10, "name": "five", "value": 200})
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.b.select())], [row_b])
        # Now use primary_key_only=True with a conflict in only the non-PK
        # field.  This should be an integrity error and nothing should
        # change.
        with self.assertRaises(sqlalchemy.exc.IntegrityError):
            db.ensure(tables.b, {"id": 10, "name": "five", "value": 200}, primary_key_only=True)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.b.select())], [row_b])
        # With primary_key_only=True a conflict in the primary key is
        # ignored regardless of whether there is a conflict elsewhere.
        db.ensure(tables.b, {"id": 5, "name": "ten", "value": 100}, primary_key_only=True)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.b.select())], [row_b])
        db.ensure(tables.b, {"id": 5, "name": "five", "value": 100}, primary_key_only=True)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.b.select())], [row_b])

    def testTransactionNesting(self):
        """Test that transactions can be nested with the behavior in the
        presence of exceptions working as documented.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert one row so we can trigger integrity errors by trying to
        # insert a duplicate of it below.
        db.insert(tables.a, {"name": "a1"})
        # First test: error recovery via explicit savepoint=True in the
        # inner transaction.
        with db.transaction():
            # This insert should succeed, and should not be rolled back
            # because the assertRaises context should catch any exception
            # before it propagates up to the outer transaction.
            db.insert(tables.a, {"name": "a2"})
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                with db.transaction(savepoint=True):
                    # This insert should succeed, but should be rolled back.
                    db.insert(tables.a, {"name": "a4"})
                    # This insert should fail (duplicate primary key),
                    # raising an exception.
                    db.insert(tables.a, {"name": "a1"})
        self.assertCountEqual(
            [r._asdict() for r in self.query_list(db, tables.a.select())],
            [{"name": "a1", "region": None}, {"name": "a2", "region": None}],
        )
        # Second test: error recovery via implicit savepoint=True, when the
        # innermost transaction is inside a savepoint=True transaction.
        with db.transaction():
            # This insert should succeed, and should not be rolled back
            # because the assertRaises context should catch any exception
            # before it propagates up to the outer transaction.
            db.insert(tables.a, {"name": "a3"})
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                with db.transaction(savepoint=True):
                    # This insert should succeed, but should be rolled back.
                    db.insert(tables.a, {"name": "a4"})
                    with db.transaction():
                        # This insert should succeed, but should be rolled
                        # back.
                        db.insert(tables.a, {"name": "a5"})
                        # This insert should fail (duplicate primary key),
                        # raising an exception.
                        db.insert(tables.a, {"name": "a1"})
        self.assertCountEqual(
            [r._asdict() for r in self.query_list(db, tables.a.select())],
            [{"name": "a1", "region": None}, {"name": "a2", "region": None}, {"name": "a3", "region": None}],
        )

    def testTransactionLocking(self):
        """Test that `Database.transaction` can be used to acquire a lock
        that prohibits concurrent writes.
        """
        db1 = self.makeEmptyDatabase(origin=1)
        with db1.declareStaticTables(create=True) as context:
            tables1 = context.addTableTuple(STATIC_TABLE_SPECS)

        async def _side1(lock: Iterable[str] = ()) -> tuple[set[str], set[str]]:
            """One side of the concurrent locking test.

            Parameters
            ----------
            lock : `~collections.abc.Iterable` of `str`
                Tables or table names to lock for the duration of the
                transaction.

            Notes
            -----
            This optionally locks the table (and maybe the whole database),
            does a select for its contents, inserts a new row, and then
            selects again, with some waiting in between to make sure the
            other side has a chance to _attempt_ to insert in between.  If
            the locking is enabled and works, the difference between the
            selects should just be the insert done on this thread.
            """
            # Give Side2 a chance to create a connection.
            await asyncio.sleep(1.0)
            with db1.transaction(lock=lock):
                names1 = {row.name for row in self.query_list(db1, tables1.a.select())}
                # Give Side2 a chance to insert (which will be blocked if
                # we've acquired a lock).
                await asyncio.sleep(2.0)
                db1.insert(tables1.a, {"name": "a1"})
                names2 = {row.name for row in self.query_list(db1, tables1.a.select())}
            return names1, names2

        async def _side2() -> None:
            """Other side of the concurrent locking test.

            Notes
            -----
            This side just waits a bit and then tries to insert a row into
            the table that the other side is trying to lock.  Hopefully that
            waiting is enough to give the other side a chance to acquire the
            lock and thus make this side block until the lock is released.
            If this side manages to do the insert before side1 acquires the
            lock, we'll just warn about not succeeding at testing the
            locking, because we can only make that unlikely, not impossible.
            """

            def _toRunInThread():
                """Create a new SQLite connection for use in a thread.

                SQLite locking isn't asyncio-friendly unless we actually run
                it in another thread.  And SQLite gets very unhappy if we
                try to use a connection from multiple threads, so we have to
                create the new connection here instead of out in the main
                body of the test function.
                """
                db2 = self.getNewConnection(db1, writeable=True)
                with db2.declareStaticTables(create=False) as context:
                    tables2 = context.addTableTuple(STATIC_TABLE_SPECS)
                with db2.transaction():
                    db2.insert(tables2.a, {"name": "a2"})

            await asyncio.sleep(2.0)
            loop = asyncio.get_running_loop()
            with ThreadPoolExecutor() as pool:
                await loop.run_in_executor(pool, _toRunInThread)

        async def _testProblemsWithNoLocking() -> None:
            """Run side1 and side2 with no locking, attempting to demonstrate
            the problem that locking is supposed to solve.  If we get unlucky
            with scheduling, side2 will just happen to insert after side1 is
            done, and we won't have anything definitive.  We just warn in
            that case because we really don't want spurious test failures.
            """
            task1 = asyncio.create_task(_side1())
            task2 = asyncio.create_task(_side2())

            names1, names2 = await task1
            await task2
            if "a2" in names1:
                warnings.warn(
                    "Unlucky scheduling in no-locking test: concurrent INSERT happened before first SELECT.",
                    stacklevel=1,
                )
                self.assertEqual(names1, {"a2"})
                self.assertEqual(names2, {"a1", "a2"})
            elif "a2" not in names2:
                warnings.warn(
                    "Unlucky scheduling in no-locking test: concurrent INSERT "
                    "happened after second SELECT even without locking.",
                    stacklevel=1,
                )
                self.assertEqual(names1, set())
                self.assertEqual(names2, {"a1"})
            else:
                # This is the expected case: both INSERTs happen between the
                # two SELECTs.  If we don't get this almost all of the time,
                # we should adjust the sleep amounts.
                self.assertEqual(names1, set())
                self.assertEqual(names2, {"a1", "a2"})

        asyncio.run(_testProblemsWithNoLocking())

        # Clean up after the first test.
        db1.delete(tables1.a, ["name"], {"name": "a1"}, {"name": "a2"})

        async def _testSolutionWithLocking() -> None:
            """Run side1 and side2 with locking, which should make side2
            block its insert until side1 releases its lock.
            """
            task1 = asyncio.create_task(_side1(lock=[tables1.a]))
            task2 = asyncio.create_task(_side2())

            names1, names2 = await task1
            await task2
            if "a2" in names1:
                warnings.warn(
                    "Unlucky scheduling in locking test: concurrent INSERT happened before first SELECT.",
                    stacklevel=1,
                )
                self.assertEqual(names1, {"a2"})
                self.assertEqual(names2, {"a1", "a2"})
            else:
                # This is the expected case: the side2 INSERT happens after
                # the last SELECT on side1.  This can also happen due to
                # unlucky scheduling, and we have no way to detect that here,
                # but the similar "no-locking" test has at least some chance
                # of being affected by the same problem and warning about it.
                self.assertEqual(names1, set())
                self.assertEqual(names2, {"a1"})

        asyncio.run(_testSolutionWithLocking())

    def testTimespanDatabaseRepresentation(self):
        """Tests for `TimespanDatabaseRepresentation` and the `Database`
        methods that interact with it.
        """
        # Make some test timespans to play with, with the full suite of
        # topological relationships.
        start = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
        offset = astropy.time.TimeDelta(60, format="sec")
        timestamps = [start + offset * n for n in range(3)]
        aTimespans = [Timespan(begin=None, end=None)]
        aTimespans.extend(Timespan(begin=None, end=t) for t in timestamps)
        aTimespans.extend(Timespan(begin=t, end=None) for t in timestamps)
        aTimespans.extend(Timespan.fromInstant(t) for t in timestamps)
        aTimespans.append(Timespan.makeEmpty())
        aTimespans.extend(Timespan(begin=t1, end=t2) for t1, t2 in itertools.combinations(timestamps, 2))
        # Make another list of timespans that span the full range but don't
        # overlap.  This is a subset of the previous list.
        bTimespans = [Timespan(begin=None, end=timestamps[0])]
        bTimespans.extend(
            Timespan(begin=t1, end=t2) for t1, t2 in zip(timestamps[:-1], timestamps[1:], strict=True)
        )
        bTimespans.append(Timespan(begin=timestamps[-1], end=None))
        # Make a database and create a table with that database's timespan
        # representation.  This one will have no exclusion constraint and
        # a nullable timespan.
        db = self.makeEmptyDatabase(origin=1)
        TimespanReprClass = db.getTimespanRepresentation()
        aSpec = ddl.TableSpec(
            fields=[
                ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
            ],
        )
        for fieldSpec in TimespanReprClass.makeFieldSpecs(nullable=True):
            aSpec.fields.add(fieldSpec)
        with db.declareStaticTables(create=True) as context:
            aTable = context.addTable("a", aSpec)
        self.maxDiff = None

        def _convertRowForInsert(row: dict) -> dict:
            """Convert a row containing a Timespan instance into one suitable
            for insertion into the database.
            """
            result = row.copy()
            ts = result.pop(TimespanReprClass.NAME)
            return TimespanReprClass.update(ts, result=result)

        def _convertRowFromSelect(row: dict) -> dict:
            """Convert a row from the database into one containing a Timespan.

            Parameters
            ----------
            row : `dict`
                Original row.

            Returns
            -------
            row : `dict`
                The updated row.
            """
            result = row.copy()
            timespan = TimespanReprClass.extract(result)
            for name in TimespanReprClass.getFieldNames():
                del result[name]
            result[TimespanReprClass.NAME] = timespan
            return result

        # Insert rows into table A, in chunks just to make things
        # interesting.  Include one with a NULL timespan.
        aRows = [{"id": n, TimespanReprClass.NAME: t} for n, t in enumerate(aTimespans)]
        aRows.append({"id": len(aRows), TimespanReprClass.NAME: None})
        db.insert(aTable, _convertRowForInsert(aRows[0]))
        db.insert(aTable, *[_convertRowForInsert(r) for r in aRows[1:3]])
        db.insert(aTable, *[_convertRowForInsert(r) for r in aRows[3:]])
        # Add another one with a NULL timespan, but this time by invoking
        # the server-side default.
        aRows.append({"id": len(aRows)})
        db.insert(aTable, aRows[-1])
        aRows[-1][TimespanReprClass.NAME] = None
        # Test basic round-trip through the database.
        self.assertEqual(
            aRows,
            [
                _convertRowFromSelect(row._asdict())
                for row in self.query_list(db, aTable.select().order_by(aTable.columns.id))
            ],
        )
        # Create another table B with a not-null timespan and (if the
        # database supports it) an exclusion constraint.  Use
        # ensureTableExists this time to check that mode of table creation
        # vs. timespans.
        bSpec = ddl.TableSpec(
            fields=[
                ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
                ddl.FieldSpec(name="key", dtype=sqlalchemy.Integer, nullable=False),
            ],
        )
        for fieldSpec in TimespanReprClass.makeFieldSpecs(nullable=False):
            bSpec.fields.add(fieldSpec)
        if TimespanReprClass.hasExclusionConstraint():
            bSpec.exclusion.add(("key", TimespanReprClass))
        bTable = db.ensureTableExists("b", bSpec)
        # Insert rows into table B, again in chunks.  Each Timespan appears
        # twice, but with different values for the 'key' field (which should
        # still be okay for any exclusion constraint we may have defined).
        bRows = [{"id": n, "key": 1, TimespanReprClass.NAME: t} for n, t in enumerate(bTimespans)]
        offset = len(bRows)
        bRows.extend(
            {"id": n + offset, "key": 2, TimespanReprClass.NAME: t} for n, t in enumerate(bTimespans)
        )
        db.insert(bTable, *[_convertRowForInsert(r) for r in bRows[:2]])
        db.insert(bTable, _convertRowForInsert(bRows[2]))
        db.insert(bTable, *[_convertRowForInsert(r) for r in bRows[3:]])
        # Insert a row with no timespan into table B.  This should invoke
        # the server-side default, which is a timespan over (-∞, ∞).  We set
        # key=3 to avoid upsetting an exclusion constraint that might exist.
        bRows.append({"id": len(bRows), "key": 3})
        db.insert(bTable, bRows[-1])
        bRows[-1][TimespanReprClass.NAME] = Timespan(None, None)
        # Test basic round-trip through the database.
        self.assertEqual(
            bRows,
            [
                _convertRowFromSelect(row._asdict())
                for row in self.query_list(db, bTable.select().order_by(bTable.columns.id))
            ],
        )
        # Test that we can't insert timespan=None into this table.
        with self.assertRaises(sqlalchemy.exc.IntegrityError):
            db.insert(
                bTable, _convertRowForInsert({"id": len(bRows), "key": 4, TimespanReprClass.NAME: None})
            )
        # IFF this database supports exclusion constraints, test that they
        # also prevent inserts.
        if TimespanReprClass.hasExclusionConstraint():
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                db.insert(
                    bTable,
                    _convertRowForInsert(
                        {"id": len(bRows), "key": 1, TimespanReprClass.NAME: Timespan(None, timestamps[1])}
                    ),
                )
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                db.insert(
                    bTable,
                    _convertRowForInsert(
                        {
                            "id": len(bRows),
                            "key": 1,
                            TimespanReprClass.NAME: Timespan(timestamps[0], timestamps[2]),
                        }
                    ),
                )
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                db.insert(
                    bTable,
                    _convertRowForInsert(
                        {"id": len(bRows), "key": 1, TimespanReprClass.NAME: Timespan(timestamps[2], None)}
                    ),
                )
        # Test NULL checks in SELECT queries, on both tables.
        aRepr = TimespanReprClass.from_columns(aTable.columns)
        self.assertEqual(
            [row[TimespanReprClass.NAME] is None for row in aRows],
            [
                row.f
                for row in self.query_list(
                    db, sqlalchemy.sql.select(aRepr.isNull().label("f")).order_by(aTable.columns.id)
                )
            ],
        )
        bRepr = TimespanReprClass.from_columns(bTable.columns)
        self.assertEqual(
            [False for row in bRows],
            [
                row.f
                for row in self.query_list(
                    db, sqlalchemy.sql.select(bRepr.isNull().label("f")).order_by(bTable.columns.id)
                )
            ],
        )
        # Test relationship expressions that relate in-database timespans to
        # Python-literal timespans, all from the more complete 'a' set; check
        # that this is consistent with Python-only relationship tests.
        for rhsRow in aRows:
            if rhsRow[TimespanReprClass.NAME] is None:
                continue
            with self.subTest(rhsRow=rhsRow):
                expected = {}
                for lhsRow in aRows:
                    if lhsRow[TimespanReprClass.NAME] is None:
                        expected[lhsRow["id"]] = (None, None, None, None)
                    else:
                        expected[lhsRow["id"]] = (
                            lhsRow[TimespanReprClass.NAME].overlaps(rhsRow[TimespanReprClass.NAME]),
                            lhsRow[TimespanReprClass.NAME].contains(rhsRow[TimespanReprClass.NAME]),
                            lhsRow[TimespanReprClass.NAME] < rhsRow[TimespanReprClass.NAME],
                            lhsRow[TimespanReprClass.NAME] > rhsRow[TimespanReprClass.NAME],
                        )
                rhsRepr = TimespanReprClass.fromLiteral(rhsRow[TimespanReprClass.NAME])
                sql = sqlalchemy.sql.select(
                    aTable.columns.id.label("lhs"),
                    aRepr.overlaps(rhsRepr).label("overlaps"),
                    aRepr.contains(rhsRepr).label("contains"),
                    (aRepr < rhsRepr).label("less_than"),
                    (aRepr > rhsRepr).label("greater_than"),
                ).select_from(aTable)
                queried = {
                    row.lhs: (row.overlaps, row.contains, row.less_than, row.greater_than)
                    for row in self.query_list(db, sql)
                }
                self.assertEqual(expected, queried)
        # Test relationship expressions that relate in-database timespans to
        # each other, all from the more complete 'a' set; check that this is
        # consistent with Python-only relationship tests.
        expected = {}
        for lhs, rhs in itertools.product(aRows, aRows):
            lhsT = lhs[TimespanReprClass.NAME]
            rhsT = rhs[TimespanReprClass.NAME]
            if lhsT is not None and rhsT is not None:
                expected[lhs["id"], rhs["id"]] = (
                    lhsT.overlaps(rhsT),
                    lhsT.contains(rhsT),
                    lhsT < rhsT,
                    lhsT > rhsT,
                )
            else:
                expected[lhs["id"], rhs["id"]] = (None, None, None, None)
        lhsSubquery = aTable.alias("lhs")
        rhsSubquery = aTable.alias("rhs")
        lhsRepr = TimespanReprClass.from_columns(lhsSubquery.columns)
        rhsRepr = TimespanReprClass.from_columns(rhsSubquery.columns)
        sql = sqlalchemy.sql.select(
            lhsSubquery.columns.id.label("lhs"),
            rhsSubquery.columns.id.label("rhs"),
            lhsRepr.overlaps(rhsRepr).label("overlaps"),
            lhsRepr.contains(rhsRepr).label("contains"),
            (lhsRepr < rhsRepr).label("less_than"),
            (lhsRepr > rhsRepr).label("greater_than"),
        ).select_from(lhsSubquery.join(rhsSubquery, onclause=sqlalchemy.sql.literal(True)))
        queried = {
            (row.lhs, row.rhs): (row.overlaps, row.contains, row.less_than, row.greater_than)
            for row in self.query_list(db, sql)
        }
        self.assertEqual(expected, queried)
        # Test relationship expressions between in-database timespans and
        # Python-literal instantaneous times.
        for t in timestamps:
            with self.subTest(t=t):
                expected = {}
                for lhsRow in aRows:
                    if lhsRow[TimespanReprClass.NAME] is None:
                        expected[lhsRow["id"]] = (None, None, None, None)
                    else:
                        expected[lhsRow["id"]] = (
                            lhsRow[TimespanReprClass.NAME].contains(t),
                            lhsRow[TimespanReprClass.NAME].overlaps(t),
                            lhsRow[TimespanReprClass.NAME] < t,
                            lhsRow[TimespanReprClass.NAME] > t,
                        )
                rhs = sqlalchemy.sql.literal(t, type_=ddl.AstropyTimeNsecTai)
                sql = sqlalchemy.sql.select(
                    aTable.columns.id.label("lhs"),
                    aRepr.contains(rhs).label("contains"),
                    aRepr.overlaps(rhs).label("overlaps_point"),
                    (aRepr < rhs).label("less_than"),
                    (aRepr > rhs).label("greater_than"),
                ).select_from(aTable)
                queried = {
                    row.lhs: (row.contains, row.overlaps_point, row.less_than, row.greater_than)
                    for row in self.query_list(db, sql)
                }
                self.assertEqual(expected, queried)

    def testConstantRows(self):
        """Test `Database.constant_rows`."""
        new_db = self.makeEmptyDatabase()
        with new_db.declareStaticTables(create=True) as context:
            static = context.addTableTuple(STATIC_TABLE_SPECS)
        b_ids = new_db.insert(
            static.b,
            {"name": "b1", "value": 11},
            {"name": "b2", "value": 12},
            {"name": "b3", "value": 13},
            returnIds=True,
        )
        values_spec = ddl.TableSpec(
            [
                ddl.FieldSpec(name="b", dtype=sqlalchemy.BigInteger),
                ddl.FieldSpec(name="s", dtype=sqlalchemy.String(8)),
                ddl.FieldSpec(name="r", dtype=ddl.Base64Region()),
            ],
        )
        values_data = [
            {"b": b_ids[0], "s": "b1", "r": None},
            {"b": b_ids[2], "s": "b3", "r": Circle.empty()},
        ]
        values = new_db.constant_rows(values_spec.fields, *values_data)
        select_values_alone = sqlalchemy.sql.select(
            values.columns["b"], values.columns["s"], values.columns["r"]
        )
        self.assertCountEqual(
            [row._mapping for row in self.query_list(new_db, select_values_alone)],
            values_data,
        )
        select_values_joined = sqlalchemy.sql.select(
            values.columns["s"].label("name"), static.b.columns["value"].label("value")
        ).select_from(values.join(static.b, onclause=static.b.columns["id"] == values.columns["b"]))
        self.assertCountEqual(
            [row._mapping for row in self.query_list(new_db, select_values_joined)],
            [{"value": 11, "name": "b1"}, {"value": 13, "name": "b3"}],
        )

    def test_aggregate(self) -> None:
        """Test `Database.apply_any_aggregate`,
        `ddl.Base64Region.union_aggregate`, and
        `TimespanDatabaseRepresentation.apply_any_aggregate`.
        """
        db = self.makeEmptyDatabase()
        with db.declareStaticTables(create=True) as context:
            t = context.addTable(
                "t",
                ddl.TableSpec(
                    [
                        ddl.FieldSpec("id", sqlalchemy.BigInteger(), nullable=False),
                        ddl.FieldSpec("name", sqlalchemy.String(16), nullable=False),
                        ddl.FieldSpec.for_region(),
                    ]
                    + list(db.getTimespanRepresentation().makeFieldSpecs(nullable=True)),
                ),
            )
        pixelization = Mq3cPixelization(10)
        start = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
        offset = astropy.time.TimeDelta(60, format="sec")
        timespans = [Timespan(begin=start + offset * n, end=start + offset * (n + 1)) for n in range(3)]
        ts_cls = db.getTimespanRepresentation()
        ts_col = ts_cls.from_columns(t.columns)
        db.insert(
            t,
            ts_cls.update(timespans[0], result={"id": 1, "name": "a", "region": pixelization.quad(12058870)}),
            ts_cls.update(timespans[1], result={"id": 2, "name": "a", "region": pixelization.quad(12058871)}),
            ts_cls.update(timespans[2], result={"id": 3, "name": "b", "region": pixelization.quad(12058872)}),
            ts_cls.update(timespans[2], result={"id": 3, "name": "b", "region": pixelization.quad(12058873)}),
        )
        # This should use DISTINCT ON in PostgreSQL and GROUP BY in SQLite.
        if db.has_distinct_on:
            sql = (
                sqlalchemy.select(
                    t.c.id.label("i"),
                    t.c.name.label("n"),
                    *ts_col.flatten("t"),
                )
                .select_from(t)
                .distinct(t.c.id)
            )
        elif db.has_any_aggregate:
            sql = (
                sqlalchemy.select(
                    t.c.id.label("i"),
                    db.apply_any_aggregate(t.c.name).label("n"),
                    *ts_col.apply_any_aggregate(db.apply_any_aggregate).flatten("t"),
                )
                .select_from(t)
                .group_by(t.c.id)
            )
        else:
            raise AssertionError(
                "PostgreSQL should support DISTINCT ON and SQLite should support no-op any aggregates."
            )
        self.assertCountEqual(
            [(row.i, row.n, ts_cls.extract(row._mapping, "t")) for row in self.query_list(db, sql)],
            [(1, "a", timespans[0]), (2, "a", timespans[1]), (3, "b", timespans[2])],
        )
        # Test union_aggregate in all versions of both databases, with a
        # GROUP BY that does not need apply_any_aggregate.
        self.assertCountEqual(
            [
                (row.i, row.r.encode())
                for row in self.query_list(
                    db,
                    sqlalchemy.select(
                        t.c.id.label("i"), ddl.Base64Region.union_aggregate(t.c.region).label("r")
                    )
                    .select_from(t)
                    .group_by(t.c.id),
                )
            ],
            [
                (1, pixelization.quad(12058870).encode()),
                (2, pixelization.quad(12058871).encode()),
                (3, UnionRegion(pixelization.quad(12058872), pixelization.quad(12058873)).encode()),
            ],
        )
        if db.has_any_aggregate:
            # This should run in SQLite and PostgreSQL 16+.
            self.assertCountEqual(
                [
                    (row.i, row.n, row.r.encode())
                    for row in self.query_list(
                        db,
                        sqlalchemy.select(
                            t.c.id.label("i"),
                            db.apply_any_aggregate(t.c.name).label("n"),
                            ddl.Base64Region.union_aggregate(t.c.region).label("r"),
                        )
                        .select_from(t)
                        .group_by(t.c.id),
                    )
                ],
                [
                    (1, "a", pixelization.quad(12058870).encode()),
                    (2, "a", pixelization.quad(12058871).encode()),
                    (3, "b", UnionRegion(pixelization.quad(12058872), pixelization.quad(12058873)).encode()),
                ],
            )