# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

from ... import ddl

__all__ = ["DatabaseTests"]

import asyncio
import itertools
import warnings
from abc import ABC, abstractmethod
from collections import namedtuple
from collections.abc import Iterable, Iterator
from concurrent.futures import ThreadPoolExecutor
from contextlib import AbstractContextManager, contextmanager
from typing import Any

import astropy.time
import sqlalchemy

from lsst.sphgeom import Circle, ConvexPolygon, Mq3cPixelization, UnionRegion, UnitVector3d

from ..._timespan import Timespan
from ..interfaces import Database, DatabaseConflictError, ReadOnlyDatabaseError, SchemaAlreadyDefinedError

StaticTablesTuple = namedtuple("StaticTablesTuple", ["a", "b", "c"])

STATIC_TABLE_SPECS = StaticTablesTuple(
    a=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("name", dtype=sqlalchemy.String, length=16, primaryKey=True),
            ddl.FieldSpec("region", dtype=ddl.Base64Region, nbytes=128),
        ]
    ),
    b=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("id", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True),
            ddl.FieldSpec("name", dtype=sqlalchemy.String, length=16, nullable=False),
            ddl.FieldSpec("value", dtype=sqlalchemy.SmallInteger, nullable=True),
        ],
        unique=[("name",)],
    ),
    c=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("id", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True),
            ddl.FieldSpec("b_id", dtype=sqlalchemy.BigInteger, nullable=True),
        ],
        foreignKeys=[
            ddl.ForeignKeySpec("b", source=("b_id",), target=("id",), onDelete="SET NULL"),
        ],
    ),
)

DYNAMIC_TABLE_SPEC = ddl.TableSpec(
    fields=[
        ddl.FieldSpec("c_id", dtype=sqlalchemy.BigInteger, primaryKey=True),
        ddl.FieldSpec("a_name", dtype=sqlalchemy.String, length=16, nullable=False),
    ],
    foreignKeys=[
        ddl.ForeignKeySpec("c", source=("c_id",), target=("id",), onDelete="CASCADE"),
        ddl.ForeignKeySpec("a", source=("a_name",), target=("name",), onDelete="CASCADE"),
    ],
)

TEMPORARY_TABLE_SPEC = ddl.TableSpec(
    fields=[
        ddl.FieldSpec("a_name", dtype=sqlalchemy.String, length=16, primaryKey=True),
        ddl.FieldSpec("b_id", dtype=sqlalchemy.BigInteger, primaryKey=True),
    ],
)
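
# A sketch of how these specs relate, derived from the foreign keys above
# (the dynamic table is created under the name "d" in the tests below):
#
#   a.name <--CASCADE--- d.a_name
#   c.id   <--CASCADE--- d.c_id
#   b.id   <--SET NULL-- c.b_id
#
# The delete tests rely on exactly these onDelete behaviors.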


@contextmanager
def _patch_getExistingTable(db: Database) -> Iterator[Database]:
    """Patch the getExistingTable method in a database instance to test
    concurrent creation of tables. This patch obviously depends on knowing
    the internals of the ``ensureTableExists()`` implementation.
    """
    original_method = db.getExistingTable

    def _getExistingTable(name: str, spec: ddl.TableSpec) -> sqlalchemy.schema.Table | None:
        # Return None on the first call, but forward to the original method
        # after that.
        db.getExistingTable = original_method
        return None

    db.getExistingTable = _getExistingTable
    yield db
    db.getExistingTable = original_method
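
# Typical use, as in testDynamicTablesConcurrency below: wrapping a database
# in this context manager makes the next ensureTableExists() call behave as
# if another process had created the table between the initial existence
# check and the actual CREATE TABLE attempt.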


class DatabaseTests(ABC):
    """Generic tests for the `Database` interface that can be subclassed to
    generate tests for concrete implementations.
    """

    @abstractmethod
    def makeEmptyDatabase(self, origin: int = 0) -> Database:
        """Return an empty `Database` with the given origin, or an
        automatically-generated one if ``origin`` is `None`.

        Parameters
        ----------
        origin : `int` or `None`
            Origin to use for the database.

        Returns
        -------
        db : `Database`
            Empty database with given origin or auto-generated origin.
        """
        raise NotImplementedError()

    @abstractmethod
    def asReadOnly(self, database: Database) -> AbstractContextManager[Database]:
        """Return a context manager for a read-only connection into the given
        database.

        Parameters
        ----------
        database : `Database`
            The database to use.

        Yields
        ------
        `Database`
            The new database connection.

        Notes
        -----
        The original database should be considered unusable within the context
        but safe to use again afterwards (this allows the context manager to
        block write access by temporarily changing user permissions to really
        guarantee that write operations are not performed).
        """
        raise NotImplementedError()

    @abstractmethod
    def getNewConnection(self, database: Database, *, writeable: bool) -> Database:
        """Return a new `Database` instance that points to the same underlying
        storage as the given one.

        Parameters
        ----------
        database : `Database`
            The current database.
        writeable : `bool`
            Whether the connection should be writeable or not.

        Returns
        -------
        db : `Database`
            The new database connection.
        """
        raise NotImplementedError()

    def query_list(
        self, database: Database, executable: sqlalchemy.sql.expression.SelectBase
    ) -> list[sqlalchemy.engine.Row]:
        """Run a SELECT or other read-only query against the database and
        return the results as a list.

        Parameters
        ----------
        database : `Database`
            The database to use.
        executable : `sqlalchemy.sql.expression.SelectBase`
            Expression to execute.

        Returns
        -------
        results : `list` of `sqlalchemy.engine.Row`
            The results.

        Notes
        -----
        This is a thin wrapper around database.query() that avoids
        context-manager boilerplate that is usefully verbose in production
        code but just noise in tests.
        """
        with database.transaction(), database.query(executable) as result:
            return result.fetchall()

    def query_scalar(self, database: Database, executable: sqlalchemy.sql.expression.SelectBase) -> Any:
        """Run a SELECT query that yields a single column and row against the
        database and return its value.

        Parameters
        ----------
        database : `Database`
            The database to use.
        executable : `sqlalchemy.sql.expression.SelectBase`
            Expression to execute.

        Returns
        -------
        results : `~typing.Any`
            The results.

        Notes
        -----
        This is a thin wrapper around database.query() that avoids
        context-manager boilerplate that is usefully verbose in production
        code but just noise in tests.
        """
        with database.query(executable) as result:
            return result.scalar()
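
    # For example, the row-count checks below build
    #     count = sqlalchemy.sql.select(sqlalchemy.sql.func.count())
    # and then call
    #     self.query_scalar(db, count.select_from(table))
    # which runs SELECT COUNT(*) FROM table and returns the single value.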

    def checkTable(self, spec: ddl.TableSpec, table: sqlalchemy.schema.Table):
        self.assertCountEqual(spec.fields.names, table.columns.keys())
        # Checking more than this currently seems fragile, as it might
        # restrict what Database implementations do; we don't care if the
        # spec is actually preserved in terms of types and constraints as
        # long as we can use the returned table as if it were.

    def checkStaticSchema(self, tables: StaticTablesTuple):
        self.checkTable(STATIC_TABLE_SPECS.a, tables.a)
        self.checkTable(STATIC_TABLE_SPECS.b, tables.b)
        self.checkTable(STATIC_TABLE_SPECS.c, tables.c)

    def testDeclareStaticTables(self):
        """Tests for `Database.declareStaticSchema` and the methods it
        delegates to.
        """
        # Create the static schema in a new, empty database.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        self.checkStaticSchema(tables)
        # Check that we can load that schema even from a read-only connection.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                tables = context.addTableTuple(STATIC_TABLE_SPECS)
            self.checkStaticSchema(tables)

    def testDeclareStaticTablesTwice(self):
        """Tests for `Database.declareStaticSchema` being called twice."""
        # Create the static schema in a new, empty database.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        self.checkStaticSchema(tables)
        # A second call should raise.
        with self.assertRaises(SchemaAlreadyDefinedError):
            with newDatabase.declareStaticTables(create=True) as context:
                tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Check the schema; it should still contain all tables, and maybe
        # some extras.
        with newDatabase.declareStaticTables(create=False) as context:
            self.assertLessEqual(frozenset(STATIC_TABLE_SPECS._fields), context._tableNames)

    def testRepr(self):
        """Test that repr does not return a generic thing."""
        newDatabase = self.makeEmptyDatabase()
        rep = repr(newDatabase)
        # Check that stringification works and gives us something different
        self.assertNotEqual(rep, str(newDatabase))
        self.assertNotIn("object at 0x", rep, "Check default repr was not used")
        self.assertIn("://", rep)

    def testDynamicTables(self):
        """Tests for `Database.ensureTableExists` and
        `Database.getExistingTable`.
        """
        # Need to start with the static schema.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            context.addTableTuple(STATIC_TABLE_SPECS)
        # Try to ensure the dynamic table exists in a read-only version of
        # that database, which should fail because we can't create it.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                context.addTableTuple(STATIC_TABLE_SPECS)
            with self.assertRaises(ReadOnlyDatabaseError):
                existingReadOnlyDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        # Just getting the dynamic table before it exists should return None.
        self.assertIsNone(newDatabase.getExistingTable("d", DYNAMIC_TABLE_SPEC))
        # Ensure the new table exists back in the original database, which
        # should create it.
        table = newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        self.checkTable(DYNAMIC_TABLE_SPEC, table)
        # Ensuring that it exists should just return the exact same table
        # instance again.
        self.assertIs(newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC), table)
        # Try again from the read-only database.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                context.addTableTuple(STATIC_TABLE_SPECS)
            # Just getting the dynamic table should now work...
            self.assertIsNotNone(existingReadOnlyDatabase.getExistingTable("d", DYNAMIC_TABLE_SPEC))
            # ...as should ensuring that it exists, since it now does.
            existingReadOnlyDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
            self.checkTable(DYNAMIC_TABLE_SPEC, table)
        # Trying to get the table with a different specification (at least
        # in terms of what columns are present) should raise.
        with self.assertRaises(DatabaseConflictError):
            newDatabase.ensureTableExists(
                "d",
                ddl.TableSpec(
                    fields=[ddl.FieldSpec("name", dtype=sqlalchemy.String, length=4, primaryKey=True)]
                ),
            )
        # Calling ensureTableExists inside a transaction block is an error,
        # even if it would do nothing.
        with newDatabase.transaction():
            with self.assertRaises(AssertionError):
                newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)

    def testDynamicTablesConcurrency(self):
        """Tests for `Database.ensureTableExists` concurrent use."""
        # We cannot really run things concurrently in a deterministic way;
        # here we just simulate a situation in which the table is created by
        # another process between the call to getExistingTable() and the
        # actual table creation.
        db1 = self.makeEmptyDatabase()
        with db1.declareStaticTables(create=True) as context:
            context.addTableTuple(STATIC_TABLE_SPECS)
        self.assertIsNone(db1.getExistingTable("d", DYNAMIC_TABLE_SPEC))

        # Make a dynamic table using a separate connection.
        db2 = self.getNewConnection(db1, writeable=True)
        with db2.declareStaticTables(create=False) as context:
            context.addTableTuple(STATIC_TABLE_SPECS)
        table = db2.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        self.checkTable(DYNAMIC_TABLE_SPEC, table)

        # Call it again, but trick it into thinking that the table is not
        # there. This test depends on knowing the implementation of
        # ensureTableExists(), which initially calls getExistingTable() to
        # check whether the table may already exist; the patch intercepts
        # that call and returns None.
        with _patch_getExistingTable(db1):
            table = db1.ensureTableExists("d", DYNAMIC_TABLE_SPEC)

    def testTemporaryTables(self):
        """Tests for `Database.temporary_table` and `Database.insert` with
        the ``select`` argument.
        """
        # Need to start with the static schema; also insert some test data.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            static = context.addTableTuple(STATIC_TABLE_SPECS)
        newDatabase.insert(static.a, {"name": "a1", "region": None}, {"name": "a2", "region": None})
        bIds = newDatabase.insert(
            static.b,
            {"name": "b1", "value": 11},
            {"name": "b2", "value": 12},
            {"name": "b3", "value": 13},
            returnIds=True,
        )
        # Create the table.
        with newDatabase.session():
            with newDatabase.temporary_table(TEMPORARY_TABLE_SPEC, "e1") as table1:
                self.checkTable(TEMPORARY_TABLE_SPEC, table1)
                # Insert via an INSERT INTO ... SELECT query.
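                # (Joining with onclause=literal(True) below makes the join
                # unconditional, i.e. a cross join pairing every "a" row
                # with every "b" row; the WHERE clause then filters the
                # pairs.)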
                newDatabase.insert(
                    table1,
                    select=sqlalchemy.sql.select(
                        static.a.columns.name.label("a_name"), static.b.columns.id.label("b_id")
                    )
                    .select_from(static.a.join(static.b, onclause=sqlalchemy.sql.literal(True)))
                    .where(
                        sqlalchemy.sql.and_(
                            static.a.columns.name == "a1",
                            static.b.columns.value <= 12,
                        )
                    ),
                )
                # Check that the inserted rows are present.
                self.assertCountEqual(
                    [{"a_name": "a1", "b_id": bId} for bId in bIds[:2]],
                    [row._asdict() for row in self.query_list(newDatabase, table1.select())],
                )
                # Create another one via a read-only connection to the
                # database. We _do_ allow temporary table modifications in
                # read-only databases.
                with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
                    with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                        context.addTableTuple(STATIC_TABLE_SPECS)
                    with existingReadOnlyDatabase.temporary_table(TEMPORARY_TABLE_SPEC) as table2:
                        self.checkTable(TEMPORARY_TABLE_SPEC, table2)
                        # Those tables should not be the same, despite having
                        # the same DDL.
                        self.assertIsNot(table1, table2)
                        # Do a slightly different insert into this table, to
                        # check that it works in a read-only database. This
                        # time we pass column names as a kwarg to insert
                        # instead of by labeling the columns in the select.
                        existingReadOnlyDatabase.insert(
                            table2,
                            select=sqlalchemy.sql.select(static.a.columns.name, static.b.columns.id)
                            .select_from(static.a.join(static.b, onclause=sqlalchemy.sql.literal(True)))
                            .where(
                                sqlalchemy.sql.and_(
                                    static.a.columns.name == "a2",
                                    static.b.columns.value >= 12,
                                )
                            ),
                            names=["a_name", "b_id"],
                        )
                        # Check that the inserted rows are present.
                        self.assertCountEqual(
                            [{"a_name": "a2", "b_id": bId} for bId in bIds[1:]],
                            [row._asdict() for row in self.query_list(existingReadOnlyDatabase, table2.select())],
                        )
        # Exiting the context managers will drop the temporary tables from
        # the read-only DB. It's unspecified whether attempting to use a
        # dropped table after this point is an error or just never returns
        # any results, so we can't test what that does, only that exiting
        # itself is not an error.

    def testSchemaSeparation(self):
        """Test that creating two different `Database` instances allows us
        to create different tables with the same name in each.
        """
        db1 = self.makeEmptyDatabase(origin=1)
        with db1.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        self.checkStaticSchema(tables)

        db2 = self.makeEmptyDatabase(origin=2)
        # Make the DDL here intentionally different so we'll definitely
        # notice if db1 and db2 are pointing at the same schema.
        spec = ddl.TableSpec(fields=[ddl.FieldSpec("id", dtype=sqlalchemy.Integer, primaryKey=True)])
        with db2.declareStaticTables(create=True) as context:
            table = context.addTable("a", spec)
        self.checkTable(spec, table)

    def testInsertQueryDelete(self):
        """Test the `Database.insert`, `Database.query`, and `Database.delete`
        methods, as well as the `Base64Region` type and the ``onDelete``
        argument to `ddl.ForeignKeySpec`.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert a single, non-autoincrement row that contains a region and
        # query to get it back.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row = {"name": "a1", "region": region}
        db.insert(tables.a, row)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row])
        # Insert multiple autoincrement rows but do not try to get the IDs
        # back immediately.
        db.insert(tables.b, {"name": "b1", "value": 10}, {"name": "b2", "value": 20})
        results = [r._asdict() for r in self.query_list(db, tables.b.select().order_by("id"))]
        self.assertEqual(len(results), 2)
        for row in results:
            self.assertIn(row["name"], ("b1", "b2"))
            self.assertIsInstance(row["id"], int)
        self.assertGreater(results[1]["id"], results[0]["id"])
        # Insert multiple autoincrement rows and get the IDs back from insert.
        rows = [{"name": "b3", "value": 30}, {"name": "b4", "value": 40}]
        ids = db.insert(tables.b, *rows, returnIds=True)
        results = [
            r._asdict()
            for r in self.query_list(db, tables.b.select().where(tables.b.columns.id > results[1]["id"]))
        ]
        expected = [dict(row, id=id) for row, id in zip(rows, ids, strict=True)]
        self.assertCountEqual(results, expected)
        self.assertTrue(all(result["id"] is not None for result in results))
        # Insert multiple rows into a table with an autoincrement primary key,
        # then use the returned IDs to insert into a dynamic table.
        rows = [{"b_id": results[0]["id"]}, {"b_id": None}]
        ids = db.insert(tables.c, *rows, returnIds=True)
        results = [r._asdict() for r in self.query_list(db, tables.c.select())]
        expected = [dict(row, id=id) for row, id in zip(rows, ids, strict=True)]
        self.assertCountEqual(results, expected)
        self.assertTrue(all(result["id"] is not None for result in results))
        # Add the dynamic table.
        d = db.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        # Insert into it.
        rows = [{"c_id": id, "a_name": "a1"} for id in ids]
        db.insert(d, *rows)
        results = [r._asdict() for r in self.query_list(db, d.select())]
        self.assertCountEqual(rows, results)
        # Insert multiple rows into a table with an autoincrement primary key,
        # but pass in a value for the autoincrement key.
        rows2 = [
            {"id": 700, "b_id": None},
            {"id": 701, "b_id": None},
        ]
        db.insert(tables.c, *rows2)
        results = [r._asdict() for r in self.query_list(db, tables.c.select())]
        self.assertCountEqual(results, expected + rows2)
        self.assertTrue(all(result["id"] is not None for result in results))

        # Define a 'SELECT COUNT(*)' query for later use.
        count = sqlalchemy.sql.select(sqlalchemy.sql.func.count())
        # Get the values we inserted into table b.
        bValues = [r._asdict() for r in self.query_list(db, tables.b.select())]
        # Remove two rows from table b by ID.
        n = db.delete(tables.b, ["id"], {"id": bValues[0]["id"]}, {"id": bValues[1]["id"]})
        self.assertEqual(n, 2)
        # Remove the other two rows from table b by name.
        n = db.delete(tables.b, ["name"], {"name": bValues[2]["name"]}, {"name": bValues[3]["name"]})
        self.assertEqual(n, 2)
        # There should now be no rows in table b.
        self.assertEqual(self.query_scalar(db, count.select_from(tables.b)), 0)
        # All b_id values in table c should now be NULL, because there's an
        # onDelete='SET NULL' foreign key.
        self.assertEqual(
            self.query_scalar(
                db,
                count.select_from(tables.c).where(tables.c.columns.b_id != None),  # noqa:E711
            ),
            0,
        )
        # Remove all rows in table a (there's only one); this should remove all
        # rows in d due to onDelete='CASCADE'.
        n = db.delete(tables.a, [])
        self.assertEqual(n, 1)
        self.assertEqual(self.query_scalar(db, count.select_from(tables.a)), 0)
        self.assertEqual(self.query_scalar(db, count.select_from(d)), 0)

    def testDeleteWhere(self):
        """Tests for `Database.deleteWhere`."""
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        db.insert(tables.b, *[{"id": i, "name": f"b{i}"} for i in range(10)])
        count = sqlalchemy.sql.select(sqlalchemy.sql.func.count())

        n = db.deleteWhere(tables.b, tables.b.columns.id.in_([0, 1, 2]))
        self.assertEqual(n, 3)
        self.assertEqual(self.query_scalar(db, count.select_from(tables.b)), 7)

        n = db.deleteWhere(
            tables.b,
            tables.b.columns.id.in_(
                sqlalchemy.sql.select(tables.b.columns.id).where(tables.b.columns.id > 5)
            ),
        )
        self.assertEqual(n, 4)
        self.assertEqual(self.query_scalar(db, count.select_from(tables.b)), 3)

        n = db.deleteWhere(tables.b, tables.b.columns.name == "b5")
        self.assertEqual(n, 1)
        self.assertEqual(self.query_scalar(db, count.select_from(tables.b)), 2)

        n = db.deleteWhere(tables.b, sqlalchemy.sql.literal(True))
        self.assertEqual(n, 2)
        self.assertEqual(self.query_scalar(db, count.select_from(tables.b)), 0)

    def testUpdate(self):
        """Tests for `Database.update`."""
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert two rows into table a, both without regions.
        db.insert(tables.a, {"name": "a1"}, {"name": "a2"})
        # Update one of the rows with a region.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
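        # In Database.update, the second argument maps the column names used
        # in the WHERE clause to the keys holding their values in each row
        # dict: here rows whose "name" equals the value under "k" ("a2") are
        # matched, and the remaining key ("region") becomes the SET clause.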
        n = db.update(tables.a, {"name": "k"}, {"k": "a2", "region": region})
        self.assertEqual(n, 1)
        sql = sqlalchemy.sql.select(tables.a.columns.name, tables.a.columns.region).select_from(tables.a)
        self.assertCountEqual(
            [r._asdict() for r in self.query_list(db, sql)],
            [{"name": "a1", "region": None}, {"name": "a2", "region": region}],
        )

    def testSync(self):
        """Tests for `Database.sync`."""
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert a row with sync, because it doesn't exist yet.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10}, returning=["id"])
        self.assertTrue(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )
        # Repeat that operation, which should do nothing but return the
        # requested values.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10}, returning=["id"])
        self.assertFalse(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )
        # Repeat the operation without the 'extra' arg, which should also just
        # return the existing row.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, returning=["id"])
        self.assertFalse(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )
        # Repeat the operation with a different value in 'extra'. That still
        # shouldn't be an error, because 'extra' is only used if we really do
        # insert. Also drop the 'returning' argument.
        _, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 20})
        self.assertFalse(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )
        # Repeat the operation with the correct value in 'compared' instead of
        # 'extra'.
        _, inserted = db.sync(tables.b, keys={"name": "b1"}, compared={"value": 10})
        self.assertFalse(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )
        # Repeat the operation with an incorrect value in 'compared'; this
        # should raise.
        with self.assertRaises(DatabaseConflictError):
            db.sync(tables.b, keys={"name": "b1"}, compared={"value": 20})
        # Try to sync in a read-only database. This should work if and only
        # if the matching row already exists.
        with self.asReadOnly(db) as rodb:
            with rodb.declareStaticTables(create=False) as context:
                tables = context.addTableTuple(STATIC_TABLE_SPECS)
            _, inserted = rodb.sync(tables.b, keys={"name": "b1"})
            self.assertFalse(inserted)
            self.assertEqual(
                [{"id": values["id"], "name": "b1", "value": 10}],
                [r._asdict() for r in self.query_list(rodb, tables.b.select())],
            )
            with self.assertRaises(ReadOnlyDatabaseError):
                rodb.sync(tables.b, keys={"name": "b2"}, extra={"value": 20})
        # Repeat the operation with a different value in 'compared' and ask to
        # update.
        _, updated = db.sync(tables.b, keys={"name": "b1"}, compared={"value": 20}, update=True)
        self.assertEqual(updated, {"value": 10})
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 20}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )

    def testReplace(self):
        """Tests for `Database.replace`."""
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Use 'replace' to insert a single row that contains a region and
        # query to get it back.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row1 = {"name": "a1", "region": region}
        db.replace(tables.a, row1)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1])
        # Insert another row without a region.
        row2 = {"name": "a2", "region": None}
        db.replace(tables.a, row2)
        self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2])
        # Use replace to re-insert both of those rows again, which should do
        # nothing.
        db.replace(tables.a, row1, row2)
        self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2])
        # Replace row1 with a row with no region, while reinserting row2.
        row1a = {"name": "a1", "region": None}
        db.replace(tables.a, row1a, row2)
        self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1a, row2])
        # Replace both rows, returning row1 to its original state, while
        # adding a new one. Pass them in a different order.
        row2a = {"name": "a2", "region": region}
        row3 = {"name": "a3", "region": None}
        db.replace(tables.a, row3, row2a, row1)
        self.assertCountEqual(
            [r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2a, row3]
        )

    def test_replace_pkey_only(self):
        """Test `Database.replace` on a table that has only a primary key."""
        spec = ddl.TableSpec(
            [
                ddl.FieldSpec("a1", dtype=sqlalchemy.BigInteger, primaryKey=True),
                ddl.FieldSpec("a2", dtype=sqlalchemy.BigInteger, primaryKey=True),
            ]
        )
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            table = context.addTable("a", spec)
        row1 = {"a1": 1, "a2": 2}
        row2 = {"a1": 1, "a2": 3}
        db.replace(table, row1)
        db.replace(table, row2)
        db.replace(table, row1)
        self.assertCountEqual([r._asdict() for r in self.query_list(db, table.select())], [row1, row2])

    def testEnsure(self):
        """Tests for `Database.ensure`."""
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Use 'ensure' to insert a single row that contains a region and
        # query to get it back.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row1 = {"name": "a1", "region": region}
        self.assertEqual(db.ensure(tables.a, row1), 1)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1])
        # Insert another row without a region.
        row2 = {"name": "a2", "region": None}
        self.assertEqual(db.ensure(tables.a, row2), 1)
        self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2])
        # Use ensure to re-insert both of those rows again, which should do
        # nothing.
        self.assertEqual(db.ensure(tables.a, row1, row2), 0)
        self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2])
        # Attempt to insert row1's key with no region, while
        # reinserting row2. This should also do nothing.
        row1a = {"name": "a1", "region": None}
        self.assertEqual(db.ensure(tables.a, row1a, row2), 0)
        self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2])
        # Attempt to insert new rows for both existing keys, this time also
        # adding a new row. Pass them in a different order. Only the new
        # row should be added.
        row2a = {"name": "a2", "region": region}
        row3 = {"name": "a3", "region": None}
        self.assertEqual(db.ensure(tables.a, row3, row2a, row1a), 1)
        self.assertCountEqual(
            [r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2, row3]
        )
        # Add some data to a table with both a primary key and a different
        # unique constraint.
        row_b = {"id": 5, "name": "five", "value": 50}
        db.insert(tables.b, row_b)
        # Attempt ensure with primary_key_only=False and a conflict for the
        # non-PK constraint. This should do nothing.
        db.ensure(tables.b, {"id": 10, "name": "five", "value": 200})
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.b.select())], [row_b])
        # Now use primary_key_only=True with a conflict in only the non-PK
        # field. This should be an integrity error and nothing should change.
        with self.assertRaises(sqlalchemy.exc.IntegrityError):
            db.ensure(tables.b, {"id": 10, "name": "five", "value": 200}, primary_key_only=True)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.b.select())], [row_b])
        # With primary_key_only=True a conflict in the primary key is ignored
        # regardless of whether there is a conflict elsewhere.
        db.ensure(tables.b, {"id": 5, "name": "ten", "value": 100}, primary_key_only=True)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.b.select())], [row_b])
        db.ensure(tables.b, {"id": 5, "name": "five", "value": 100}, primary_key_only=True)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.b.select())], [row_b])

    def testTransactionNesting(self):
        """Test that transactions can be nested with the behavior in the
        presence of exceptions working as documented.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert one row so we can trigger integrity errors by trying to insert
        # a duplicate of it below.
        db.insert(tables.a, {"name": "a1"})
        # First test: error recovery via explicit savepoint=True in the inner
        # transaction.
        with db.transaction():
            # This insert should succeed, and should not be rolled back because
            # the assertRaises context should catch any exception before it
            # propagates up to the outer transaction.
            db.insert(tables.a, {"name": "a2"})
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                with db.transaction(savepoint=True):
                    # This insert should succeed, but should be rolled back.
                    db.insert(tables.a, {"name": "a4"})
                    # This insert should fail (duplicate primary key), raising
                    # an exception.
                    db.insert(tables.a, {"name": "a1"})
        self.assertCountEqual(
            [r._asdict() for r in self.query_list(db, tables.a.select())],
            [{"name": "a1", "region": None}, {"name": "a2", "region": None}],
        )
        # Second test: error recovery via implicit savepoint=True, when the
        # innermost transaction is inside a savepoint=True transaction.
        with db.transaction():
            # This insert should succeed, and should not be rolled back
            # because the assertRaises context should catch any
            # exception before it propagates up to the outer
            # transaction.
            db.insert(tables.a, {"name": "a3"})
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                with db.transaction(savepoint=True):
                    # This insert should succeed, but should be rolled back.
                    db.insert(tables.a, {"name": "a4"})
                    with db.transaction():
                        # This insert should succeed, but should be rolled
                        # back.
                        db.insert(tables.a, {"name": "a5"})
                        # This insert should fail (duplicate primary key),
                        # raising an exception.
                        db.insert(tables.a, {"name": "a1"})
        self.assertCountEqual(
            [r._asdict() for r in self.query_list(db, tables.a.select())],
            [{"name": "a1", "region": None}, {"name": "a2", "region": None}, {"name": "a3", "region": None}],
        )

    def testTransactionLocking(self):
        """Test that `Database.transaction` can be used to acquire a lock
        that prohibits concurrent writes.
        """
        db1 = self.makeEmptyDatabase(origin=1)
        with db1.declareStaticTables(create=True) as context:
            tables1 = context.addTableTuple(STATIC_TABLE_SPECS)

        async def _side1(lock: Iterable[str] = ()) -> tuple[set[str], set[str]]:
            """One side of the concurrent locking test.

            Parameters
            ----------
            lock : `~collections.abc.Iterable` of `str`
                Locks to acquire, passed to `Database.transaction`.

            Notes
            -----
            This optionally locks the table (and maybe the whole database),
            does a select for its contents, inserts a new row, and then selects
            again, with some waiting in between to make sure the other side has
            a chance to _attempt_ to insert in between. If the locking is
            enabled and works, the difference between the selects should just
            be the insert done on this thread.
            """
            # Give Side2 a chance to create a connection.
            await asyncio.sleep(1.0)
            with db1.transaction(lock=lock):
                names1 = {row.name for row in self.query_list(db1, tables1.a.select())}
                # Give Side2 a chance to insert (which will be blocked if
                # we've acquired a lock).
                await asyncio.sleep(2.0)
                db1.insert(tables1.a, {"name": "a1"})
                names2 = {row.name for row in self.query_list(db1, tables1.a.select())}
            return names1, names2

        async def _side2() -> None:
            """Other side of the concurrent locking test.

            Notes
            -----
            This side just waits a bit and then tries to insert a row into the
            table that the other side is trying to lock. Hopefully that
            waiting is enough to give the other side a chance to acquire the
            lock and thus make this side block until the lock is released. If
            this side manages to do the insert before side1 acquires the lock,
            we'll just warn about not succeeding at testing the locking,
            because we can only make that unlikely, not impossible.
            """

            def _toRunInThread():
                """Create a new SQLite connection for use in a thread.

                SQLite locking isn't asyncio-friendly unless we actually
                run it in another thread. And SQLite gets very unhappy if
                we try to use a connection from multiple threads, so we have
                to create the new connection here instead of out in the main
                body of the test function.
                """
                db2 = self.getNewConnection(db1, writeable=True)
                with db2.declareStaticTables(create=False) as context:
                    tables2 = context.addTableTuple(STATIC_TABLE_SPECS)
                with db2.transaction():
                    db2.insert(tables2.a, {"name": "a2"})

            await asyncio.sleep(2.0)
            loop = asyncio.get_running_loop()
            with ThreadPoolExecutor() as pool:
                await loop.run_in_executor(pool, _toRunInThread)

        async def _testProblemsWithNoLocking() -> None:
            """Run side1 and side2 with no locking, attempting to demonstrate
            the problem that locking is supposed to solve. If we get unlucky
            with scheduling, side2 will just happen to insert after side1 is
            done, and we won't have anything definitive. We just warn in that
            case because we really don't want spurious test failures.
            """
            task1 = asyncio.create_task(_side1())
            task2 = asyncio.create_task(_side2())

            names1, names2 = await task1
            await task2
            if "a2" in names1:
                warnings.warn(
                    "Unlucky scheduling in no-locking test: concurrent INSERT happened before first SELECT.",
                    stacklevel=1,
                )
                self.assertEqual(names1, {"a2"})
                self.assertEqual(names2, {"a1", "a2"})
            elif "a2" not in names2:
                warnings.warn(
                    "Unlucky scheduling in no-locking test: concurrent INSERT "
                    "happened after second SELECT even without locking.",
                    stacklevel=1,
                )
                self.assertEqual(names1, set())
                self.assertEqual(names2, {"a1"})
            else:
                # This is the expected case: both INSERTs happen between the
                # two SELECTs. If we don't get this almost all of the time we
                # should adjust the sleep amounts.
                self.assertEqual(names1, set())
                self.assertEqual(names2, {"a1", "a2"})

        asyncio.run(_testProblemsWithNoLocking())

        # Clean up after the first test.
        db1.delete(tables1.a, ["name"], {"name": "a1"}, {"name": "a2"})

        async def _testSolutionWithLocking() -> None:
            """Run side1 and side2 with locking, which should make side2 block
            its insert until side1 releases its lock.
            """
            task1 = asyncio.create_task(_side1(lock=[tables1.a]))
            task2 = asyncio.create_task(_side2())

            names1, names2 = await task1
            await task2
            if "a2" in names1:
                warnings.warn(
                    "Unlucky scheduling in locking test: concurrent INSERT happened before first SELECT.",
                    stacklevel=1,
                )
                self.assertEqual(names1, {"a2"})
                self.assertEqual(names2, {"a1", "a2"})
            else:
                # This is the expected case: the side2 INSERT happens after the
                # last SELECT on side1. This can also happen due to unlucky
                # scheduling, and we have no way to detect that here, but the
                # similar "no-locking" test has at least some chance of being
                # affected by the same problem and warning about it.
                self.assertEqual(names1, set())
                self.assertEqual(names2, {"a1"})

        asyncio.run(_testSolutionWithLocking())

    def testTimespanDatabaseRepresentation(self):
        """Tests for `TimespanDatabaseRepresentation` and the `Database`
        methods that interact with it.
        """
        # Make some test timespans to play with, with the full suite of
        # topological relationships.
        start = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
        offset = astropy.time.TimeDelta(60, format="sec")
        timestamps = [start + offset * n for n in range(3)]
        aTimespans = [Timespan(begin=None, end=None)]
        aTimespans.extend(Timespan(begin=None, end=t) for t in timestamps)
        aTimespans.extend(Timespan(begin=t, end=None) for t in timestamps)
        aTimespans.extend(Timespan.fromInstant(t) for t in timestamps)
        aTimespans.append(Timespan.makeEmpty())
        aTimespans.extend(Timespan(begin=t1, end=t2) for t1, t2 in itertools.combinations(timestamps, 2))
        # Make another list of timespans that span the full range but don't
        # overlap. This is a subset of the previous list.
        bTimespans = [Timespan(begin=None, end=timestamps[0])]
        bTimespans.extend(
            Timespan(begin=t1, end=t2) for t1, t2 in zip(timestamps[:-1], timestamps[1:], strict=True)
        )
        bTimespans.append(Timespan(begin=timestamps[-1], end=None))
        # Make a database and create a table with that database's timespan
        # representation. This one will have no exclusion constraint and
        # a nullable timespan.
        db = self.makeEmptyDatabase(origin=1)
        TimespanReprClass = db.getTimespanRepresentation()
        aSpec = ddl.TableSpec(
            fields=[
                ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
            ],
        )
        for fieldSpec in TimespanReprClass.makeFieldSpecs(nullable=True):
            aSpec.fields.add(fieldSpec)
        with db.declareStaticTables(create=True) as context:
            aTable = context.addTable("a", aSpec)
        self.maxDiff = None

        def _convertRowForInsert(row: dict) -> dict:
            """Convert a row containing a Timespan instance into one suitable
            for insertion into the database.
            """
            result = row.copy()
            ts = result.pop(TimespanReprClass.NAME)
            return TimespanReprClass.update(ts, result=result)
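
        # For example (assuming a compound representation whose field names
        # are typically something like "timespan_begin"/"timespan_end"),
        # _convertRowForInsert turns
        #     {"id": 0, "timespan": Timespan(t0, t1)}
        # into something like
        #     {"id": 0, "timespan_begin": <t0>, "timespan_end": <t1>}
        # with the values in the database's internal encoding.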

        def _convertRowFromSelect(row: dict) -> dict:
            """Convert a row from the database into one containing a Timespan.

            Parameters
            ----------
            row : `dict`
                Original row.

            Returns
            -------
            row : `dict`
                The updated row.
            """
            result = row.copy()
            timespan = TimespanReprClass.extract(result)
            for name in TimespanReprClass.getFieldNames():
                del result[name]
            result[TimespanReprClass.NAME] = timespan
            return result

        # Insert rows into table A, in chunks just to make things interesting.
        # Include one with a NULL timespan.
        aRows = [{"id": n, TimespanReprClass.NAME: t} for n, t in enumerate(aTimespans)]
        aRows.append({"id": len(aRows), TimespanReprClass.NAME: None})
        db.insert(aTable, _convertRowForInsert(aRows[0]))
        db.insert(aTable, *[_convertRowForInsert(r) for r in aRows[1:3]])
        db.insert(aTable, *[_convertRowForInsert(r) for r in aRows[3:]])
        # Add another one with a NULL timespan, but this time by invoking
        # the server-side default.
        aRows.append({"id": len(aRows)})
        db.insert(aTable, aRows[-1])
        aRows[-1][TimespanReprClass.NAME] = None
        # Test basic round-trip through database.
        self.assertEqual(
            aRows,
            [
                _convertRowFromSelect(row._asdict())
                for row in self.query_list(db, aTable.select().order_by(aTable.columns.id))
            ],
        )
        # Create another table B with a not-null timespan and (if the database
        # supports it) an exclusion constraint. Use ensureTableExists this
        # time to check that mode of table creation vs. timespans.
        bSpec = ddl.TableSpec(
            fields=[
                ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
                ddl.FieldSpec(name="key", dtype=sqlalchemy.Integer, nullable=False),
            ],
        )
        for fieldSpec in TimespanReprClass.makeFieldSpecs(nullable=False):
            bSpec.fields.add(fieldSpec)
        if TimespanReprClass.hasExclusionConstraint():
            bSpec.exclusion.add(("key", TimespanReprClass))
        bTable = db.ensureTableExists("b", bSpec)
        # Insert rows into table B, again in chunks. Each Timespan appears
        # twice, but with different values for the 'key' field (which should
        # still be okay for any exclusion constraint we may have defined).
        bRows = [{"id": n, "key": 1, TimespanReprClass.NAME: t} for n, t in enumerate(bTimespans)]
        offset = len(bRows)
        bRows.extend(
            {"id": n + offset, "key": 2, TimespanReprClass.NAME: t} for n, t in enumerate(bTimespans)
        )
        db.insert(bTable, *[_convertRowForInsert(r) for r in bRows[:2]])
        db.insert(bTable, _convertRowForInsert(bRows[2]))
        db.insert(bTable, *[_convertRowForInsert(r) for r in bRows[3:]])
        # Insert a row with no timespan into table B. This should invoke the
        # server-side default, which is a timespan over (-∞, ∞). We set
        # key=3 to avoid upsetting an exclusion constraint that might exist.
        bRows.append({"id": len(bRows), "key": 3})
        db.insert(bTable, bRows[-1])
        bRows[-1][TimespanReprClass.NAME] = Timespan(None, None)
        # Test basic round-trip through database.
        self.assertEqual(
            bRows,
            [
                _convertRowFromSelect(row._asdict())
                for row in self.query_list(db, bTable.select().order_by(bTable.columns.id))
            ],
        )
        # Test that we can't insert timespan=None into this table.
        with self.assertRaises(sqlalchemy.exc.IntegrityError):
            db.insert(
                bTable, _convertRowForInsert({"id": len(bRows), "key": 4, TimespanReprClass.NAME: None})
            )
        # Iff this database supports exclusion constraints, test that they
        # also prevent inserts.
        if TimespanReprClass.hasExclusionConstraint():
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                db.insert(
                    bTable,
                    _convertRowForInsert(
                        {"id": len(bRows), "key": 1, TimespanReprClass.NAME: Timespan(None, timestamps[1])}
                    ),
                )
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                db.insert(
                    bTable,
                    _convertRowForInsert(
                        {
                            "id": len(bRows),
                            "key": 1,
                            TimespanReprClass.NAME: Timespan(timestamps[0], timestamps[2]),
                        }
                    ),
                )
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                db.insert(
                    bTable,
                    _convertRowForInsert(
                        {"id": len(bRows), "key": 1, TimespanReprClass.NAME: Timespan(timestamps[2], None)}
                    ),
                )
        # Test NULL checks in SELECT queries, on both tables.
        aRepr = TimespanReprClass.from_columns(aTable.columns)
        self.assertEqual(
            [row[TimespanReprClass.NAME] is None for row in aRows],
            [
                row.f
                for row in self.query_list(
                    db, sqlalchemy.sql.select(aRepr.isNull().label("f")).order_by(aTable.columns.id)
                )
            ],
        )
        bRepr = TimespanReprClass.from_columns(bTable.columns)
        self.assertEqual(
            [False for row in bRows],
            [
                row.f
                for row in self.query_list(
                    db, sqlalchemy.sql.select(bRepr.isNull().label("f")).order_by(bTable.columns.id)
                )
            ],
        )
        # Test relationship expressions that relate in-database timespans to
        # Python-literal timespans, all from the more complete 'a' set; check
        # that this is consistent with Python-only relationship tests.
        for rhsRow in aRows:
            if rhsRow[TimespanReprClass.NAME] is None:
                continue
            with self.subTest(rhsRow=repr(rhsRow)):
                expected = {}
                for lhsRow in aRows:
                    if lhsRow[TimespanReprClass.NAME] is None:
                        expected[lhsRow["id"]] = (None, None, None, None)
                    else:
                        expected[lhsRow["id"]] = (
                            lhsRow[TimespanReprClass.NAME].overlaps(rhsRow[TimespanReprClass.NAME]),
                            lhsRow[TimespanReprClass.NAME].contains(rhsRow[TimespanReprClass.NAME]),
                            lhsRow[TimespanReprClass.NAME] < rhsRow[TimespanReprClass.NAME],
                            lhsRow[TimespanReprClass.NAME] > rhsRow[TimespanReprClass.NAME],
                        )
                rhsRepr = TimespanReprClass.fromLiteral(rhsRow[TimespanReprClass.NAME])
                sql = sqlalchemy.sql.select(
                    aTable.columns.id.label("lhs"),
                    aRepr.overlaps(rhsRepr).label("overlaps"),
                    aRepr.contains(rhsRepr).label("contains"),
                    (aRepr < rhsRepr).label("less_than"),
                    (aRepr > rhsRepr).label("greater_than"),
                ).select_from(aTable)
                queried = {
                    row.lhs: (row.overlaps, row.contains, row.less_than, row.greater_than)
                    for row in self.query_list(db, sql)
                }
                self.assertEqual(expected, queried)
        # Test relationship expressions that relate in-database timespans to
        # each other, all from the more complete 'a' set; check that this is
        # consistent with Python-only relationship tests.
        expected = {}
        for lhs, rhs in itertools.product(aRows, aRows):
            lhsT = lhs[TimespanReprClass.NAME]
            rhsT = rhs[TimespanReprClass.NAME]
            if lhsT is not None and rhsT is not None:
                expected[lhs["id"], rhs["id"]] = (
                    lhsT.overlaps(rhsT),
                    lhsT.contains(rhsT),
                    lhsT < rhsT,
                    lhsT > rhsT,
                )
            else:
                expected[lhs["id"], rhs["id"]] = (None, None, None, None)
        lhsSubquery = aTable.alias("lhs")
        rhsSubquery = aTable.alias("rhs")
        lhsRepr = TimespanReprClass.from_columns(lhsSubquery.columns)
        rhsRepr = TimespanReprClass.from_columns(rhsSubquery.columns)
        sql = sqlalchemy.sql.select(
            lhsSubquery.columns.id.label("lhs"),
            rhsSubquery.columns.id.label("rhs"),
            lhsRepr.overlaps(rhsRepr).label("overlaps"),
            lhsRepr.contains(rhsRepr).label("contains"),
            (lhsRepr < rhsRepr).label("less_than"),
            (lhsRepr > rhsRepr).label("greater_than"),
        ).select_from(lhsSubquery.join(rhsSubquery, onclause=sqlalchemy.sql.literal(True)))
        queried = {
            (row.lhs, row.rhs): (row.overlaps, row.contains, row.less_than, row.greater_than)
            for row in self.query_list(db, sql)
        }
        self.assertEqual(expected, queried)
        # Test relationship expressions between in-database timespans and
        # Python-literal instantaneous times.
        for t in timestamps:
            with self.subTest(t=repr(t)):
                expected = {}
                for lhsRow in aRows:
                    if lhsRow[TimespanReprClass.NAME] is None:
                        expected[lhsRow["id"]] = (None, None, None, None)
                    else:
                        expected[lhsRow["id"]] = (
                            lhsRow[TimespanReprClass.NAME].contains(t),
                            lhsRow[TimespanReprClass.NAME].overlaps(t),
                            lhsRow[TimespanReprClass.NAME] < t,
                            lhsRow[TimespanReprClass.NAME] > t,
                        )
                rhs = sqlalchemy.sql.literal(t, type_=ddl.AstropyTimeNsecTai)
                sql = sqlalchemy.sql.select(
                    aTable.columns.id.label("lhs"),
                    aRepr.contains(rhs).label("contains"),
                    aRepr.overlaps(rhs).label("overlaps_point"),
                    (aRepr < rhs).label("less_than"),
                    (aRepr > rhs).label("greater_than"),
                ).select_from(aTable)
                queried = {
                    row.lhs: (row.contains, row.overlaps_point, row.less_than, row.greater_than)
                    for row in self.query_list(db, sql)
                }
                self.assertEqual(expected, queried)

    def testConstantRows(self):
        """Test Database.constant_rows."""
        new_db = self.makeEmptyDatabase()
        with new_db.declareStaticTables(create=True) as context:
            static = context.addTableTuple(STATIC_TABLE_SPECS)
        b_ids = new_db.insert(
            static.b,
            {"name": "b1", "value": 11},
            {"name": "b2", "value": 12},
            {"name": "b3", "value": 13},
            returnIds=True,
        )
        values_spec = ddl.TableSpec(
            [
                ddl.FieldSpec(name="b", dtype=sqlalchemy.BigInteger),
                ddl.FieldSpec(name="s", dtype=sqlalchemy.String(8)),
                ddl.FieldSpec(name="r", dtype=ddl.Base64Region()),
            ],
        )
        values_data = [
            {"b": b_ids[0], "s": "b1", "r": None},
            {"b": b_ids[2], "s": "b3", "r": Circle.empty()},
        ]
        values = new_db.constant_rows(values_spec.fields, *values_data)
        select_values_alone = sqlalchemy.sql.select(
            values.columns["b"], values.columns["s"], values.columns["r"]
        )
        self.assertCountEqual(
            [row._mapping for row in self.query_list(new_db, select_values_alone)],
            values_data,
        )
        select_values_joined = sqlalchemy.sql.select(
            values.columns["s"].label("name"), static.b.columns["value"].label("value")
        ).select_from(values.join(static.b, onclause=static.b.columns["id"] == values.columns["b"]))
        self.assertCountEqual(
            [row._mapping for row in self.query_list(new_db, select_values_joined)],
            [{"value": 11, "name": "b1"}, {"value": 13, "name": "b3"}],
        )

    def test_aggregate(self) -> None:
        """Test Database.apply_any_aggregate, ddl.Base64Region.union_aggregate,
        and TimespanDatabaseRepresentation.apply_any_aggregate.
        """
        db = self.makeEmptyDatabase()
        with db.declareStaticTables(create=True) as context:
            t = context.addTable(
                "t",
                ddl.TableSpec(
                    [
                        ddl.FieldSpec("id", sqlalchemy.BigInteger(), nullable=False),
                        ddl.FieldSpec("name", sqlalchemy.String(16), nullable=False),
                        ddl.FieldSpec.for_region(),
                    ]
                    + list(db.getTimespanRepresentation().makeFieldSpecs(nullable=True)),
                ),
            )
        pixelization = Mq3cPixelization(10)
        start = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
        offset = astropy.time.TimeDelta(60, format="sec")
        timespans = [Timespan(begin=start + offset * n, end=start + offset * (n + 1)) for n in range(3)]
        ts_cls = db.getTimespanRepresentation()
        ts_col = ts_cls.from_columns(t.columns)
        db.insert(
            t,
            ts_cls.update(timespans[0], result={"id": 1, "name": "a", "region": pixelization.quad(12058870)}),
            ts_cls.update(timespans[1], result={"id": 2, "name": "a", "region": pixelization.quad(12058871)}),
            ts_cls.update(timespans[2], result={"id": 3, "name": "b", "region": pixelization.quad(12058872)}),
            ts_cls.update(timespans[2], result={"id": 3, "name": "b", "region": pixelization.quad(12058873)}),
        )
        # This should use DISTINCT ON in PostgreSQL and GROUP BY in SQLite.
        if db.has_distinct_on:
            sql = (
                sqlalchemy.select(
                    t.c.id.label("i"),
                    t.c.name.label("n"),
                    *ts_col.flatten("t"),
                )
                .select_from(t)
                .distinct(t.c.id)
            )
        elif db.has_any_aggregate:
            sql = (
                sqlalchemy.select(
                    t.c.id.label("i"),
                    db.apply_any_aggregate(t.c.name).label("n"),
                    *ts_col.apply_any_aggregate(db.apply_any_aggregate).flatten("t"),
                )
                .select_from(t)
                .group_by(t.c.id)
            )
        else:
            raise AssertionError(
                "PostgreSQL should support DISTINCT ON and SQLite should support no-op any aggregates."
            )
        self.assertCountEqual(
            [(row.i, row.n, ts_cls.extract(row._mapping, "t")) for row in self.query_list(db, sql)],
            [(1, "a", timespans[0]), (2, "a", timespans[1]), (3, "b", timespans[2])],
        )
        # Test union_aggregate in all versions of both databases, with a GROUP
        # BY that does not need apply_any_aggregate.
        self.assertCountEqual(
            [
                (row.i, row.r.encode())
                for row in self.query_list(
                    db,
                    sqlalchemy.select(
                        t.c.id.label("i"), ddl.Base64Region.union_aggregate(t.c.region).label("r")
                    )
                    .select_from(t)
                    .group_by(t.c.id),
                )
            ],
            [
                (1, pixelization.quad(12058870).encode()),
                (2, pixelization.quad(12058871).encode()),
                (3, UnionRegion(pixelization.quad(12058872), pixelization.quad(12058873)).encode()),
            ],
        )
        if db.has_any_aggregate:
            # This should run in SQLite and in PostgreSQL 16+.
            self.assertCountEqual(
                [
                    (row.i, row.n, row.r.encode())
                    for row in self.query_list(
                        db,
                        sqlalchemy.select(
                            t.c.id.label("i"),
                            db.apply_any_aggregate(t.c.name).label("n"),
                            ddl.Base64Region.union_aggregate(t.c.region).label("r"),
                        )
                        .select_from(t)
                        .group_by(t.c.id),
                    )
                ],
                [
                    (1, "a", pixelization.quad(12058870).encode()),
                    (2, "a", pixelization.quad(12058871).encode()),
                    (3, "b", UnionRegion(pixelization.quad(12058872), pixelization.quad(12058873)).encode()),
                ],
            )