Coverage for python / lsst / daf / butler / registry / tests / _database.py: 9%

522 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-17 08:49 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

from __future__ import annotations

__all__ = ["DatabaseTests"]

import asyncio
import itertools
import warnings
from abc import ABC, abstractmethod
from collections import namedtuple
from collections.abc import Iterable, Iterator
from concurrent.futures import ThreadPoolExecutor
from contextlib import AbstractContextManager, contextmanager
from typing import Any

import astropy.time
import sqlalchemy

from lsst.sphgeom import Circle, ConvexPolygon, Mq3cPixelization, UnionRegion, UnitVector3d

from ... import ddl
from ..._timespan import Timespan
from ..interfaces import Database, DatabaseConflictError, ReadOnlyDatabaseError, SchemaAlreadyDefinedError

50 

# Struct-like container bundling the three static test tables; the field
# names match the table names used in the database.
StaticTablesTuple = namedtuple("StaticTablesTuple", ["a", "b", "c"])

# Specifications for the static test schema:
#   a: string primary key plus an optional spherical-region payload;
#   b: autoincrement primary key, unique "name", and a nullable "value";
#   c: autoincrement primary key plus a nullable foreign key into "b" that
#      is set to NULL when the referenced "b" row is deleted.
STATIC_TABLE_SPECS = StaticTablesTuple(
    a=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("name", dtype=sqlalchemy.String, length=16, primaryKey=True),
            ddl.FieldSpec("region", dtype=ddl.Base64Region, nbytes=128),
        ]
    ),
    b=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("id", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True),
            ddl.FieldSpec("name", dtype=sqlalchemy.String, length=16, nullable=False),
            ddl.FieldSpec("value", dtype=sqlalchemy.SmallInteger, nullable=True),
        ],
        unique=[("name",)],
    ),
    c=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("id", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True),
            ddl.FieldSpec("b_id", dtype=sqlalchemy.BigInteger, nullable=True),
        ],
        foreignKeys=[
            ddl.ForeignKeySpec("b", source=("b_id",), target=("id",), onDelete="SET NULL"),
        ],
    ),
)

# Specification for a table created at runtime (after the static schema).
# It references both "a" and "c"; its rows are deleted automatically when
# the referenced rows are deleted (onDelete="CASCADE").
DYNAMIC_TABLE_SPEC = ddl.TableSpec(
    fields=[
        ddl.FieldSpec("c_id", dtype=sqlalchemy.BigInteger, primaryKey=True),
        ddl.FieldSpec("a_name", dtype=sqlalchemy.String, length=16, nullable=False),
    ],
    foreignKeys=[
        ddl.ForeignKeySpec("c", source=("c_id",), target=("id",), onDelete="CASCADE"),
        ddl.ForeignKeySpec("a", source=("a_name",), target=("name",), onDelete="CASCADE"),
    ],
)

# Specification for session-scoped temporary tables: a simple compound
# primary key and no foreign keys.
TEMPORARY_TABLE_SPEC = ddl.TableSpec(
    fields=[
        ddl.FieldSpec("a_name", dtype=sqlalchemy.String, length=16, primaryKey=True),
        ddl.FieldSpec("b_id", dtype=sqlalchemy.BigInteger, primaryKey=True),
    ],
)

96 

97 

@contextmanager
def _patch_getExistingTable(db: Database) -> Iterator[Database]:
    """Patch the ``getExistingTable`` method of a database instance to test
    concurrent creation of tables.

    The patched method reports the table as missing exactly once, then
    forwards all later calls to the original method.  This patch obviously
    depends on knowing internals of the ``ensureTableExists()``
    implementation, which calls ``getExistingTable()`` before attempting to
    create the table.

    Parameters
    ----------
    db : `Database`
        Database instance to patch; yielded back to the caller.
    """
    original_method = db.getExistingTable

    def _getExistingTable(name: str, spec: ddl.TableSpec) -> sqlalchemy.schema.Table | None:
        # Return None on first call, but forward to original method after
        # that (un-patch ourselves before returning).
        db.getExistingTable = original_method
        return None

    db.getExistingTable = _getExistingTable
    try:
        yield db
    finally:
        # Restore the original method even if the managed block raised, so a
        # failing test cannot leave the shared instance patched.
        db.getExistingTable = original_method

114 

115 

116class DatabaseTests(ABC): 

117 """Generic tests for the `Database` interface that can be subclassed to 

118 generate tests for concrete implementations. 

119 """ 

120 

    @abstractmethod
    def makeEmptyDatabase(self, origin: int = 0) -> Database:
        """Return an empty `Database` with the given origin, or an
        automatically-generated one if ``origin`` is `None`.

        Parameters
        ----------
        origin : `int` or `None`
            Origin to use for the database.

        Returns
        -------
        db : `Database`
            Empty database with given origin or auto-generated origin.
        """
        raise NotImplementedError()

137 

    @abstractmethod
    def asReadOnly(self, database: Database) -> AbstractContextManager[Database]:
        """Return a context manager for a read-only connection into the given
        database.

        Parameters
        ----------
        database : `Database`
            The database to use.

        Yields
        ------
        `Database`
            The new read-only database connection.

        Notes
        -----
        The original database should be considered unusable within the context
        but safe to use again afterwards (this allows the context manager to
        block write access by temporarily changing user permissions to really
        guarantee that write operations are not performed).
        """
        raise NotImplementedError()

161 

    @abstractmethod
    def getNewConnection(self, database: Database, *, writeable: bool) -> Database:
        """Return a new `Database` instance that points to the same underlying
        storage as the given one.

        Parameters
        ----------
        database : `Database`
            The current database.
        writeable : `bool`
            Whether the connection should be writeable or not.

        Returns
        -------
        db : `Database`
            The new database connection.
        """
        raise NotImplementedError()

180 

181 def query_list( 

182 self, database: Database, executable: sqlalchemy.sql.expression.SelectBase 

183 ) -> list[sqlalchemy.engine.Row]: 

184 """Run a SELECT or other read-only query against the database and 

185 return the results as a list. 

186 

187 Parameters 

188 ---------- 

189 database : `Database` 

190 The database to use. 

191 executable : `sqlalchemy.sql.expression.SelectBase` 

192 Expression to execute. 

193 

194 Returns 

195 ------- 

196 results : `list` of `sqlalchemy.engine.Row` 

197 The results. 

198 

199 Notes 

200 ----- 

201 This is a thin wrapper around database.query() that just avoids 

202 context-manager boilerplate that is usefully verbose in production code 

203 but just noise in tests. 

204 """ 

205 with database.transaction(), database.query(executable) as result: 

206 return result.fetchall() 

207 

208 def query_scalar(self, database: Database, executable: sqlalchemy.sql.expression.SelectBase) -> Any: 

209 """Run a SELECT query that yields a single column and row against the 

210 database and return its value. 

211 

212 Parameters 

213 ---------- 

214 database : `Database` 

215 The database to use. 

216 executable : `sqlalchemy.sql.expression.SelectBase` 

217 Expression to execute. 

218 

219 Returns 

220 ------- 

221 results : `~typing.Any` 

222 The results. 

223 

224 Notes 

225 ----- 

226 This is a thin wrapper around database.query() that just avoids 

227 context-manager boilerplate that is usefully verbose in production code 

228 but just noise in tests. 

229 """ 

230 with database.query(executable) as result: 

231 return result.scalar() 

232 

233 def checkTable(self, spec: ddl.TableSpec, table: sqlalchemy.schema.Table): 

234 self.assertCountEqual(spec.fields.names, table.columns.keys()) 

235 # Checking more than this currently seems fragile, as it might restrict 

236 # what Database implementations do; we don't care if the spec is 

237 # actually preserved in terms of types and constraints as long as we 

238 # can use the returned table as if it was. 

239 

240 def checkStaticSchema(self, tables: StaticTablesTuple): 

241 self.checkTable(STATIC_TABLE_SPECS.a, tables.a) 

242 self.checkTable(STATIC_TABLE_SPECS.b, tables.b) 

243 self.checkTable(STATIC_TABLE_SPECS.c, tables.c) 

244 

245 def testDeclareStaticTables(self): 

246 """Tests for `Database.declareStaticSchema` and the methods it 

247 delegates to. 

248 """ 

249 # Create the static schema in a new, empty database. 

250 newDatabase = self.makeEmptyDatabase() 

251 with newDatabase.declareStaticTables(create=True) as context: 

252 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

253 self.checkStaticSchema(tables) 

254 # Check that we can load that schema even from a read-only connection. 

255 with self.asReadOnly(newDatabase) as existingReadOnlyDatabase: 

256 with existingReadOnlyDatabase.declareStaticTables(create=False) as context: 

257 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

258 self.checkStaticSchema(tables) 

259 

260 def testDeclareStaticTablesTwice(self): 

261 """Tests for `Database.declareStaticSchema` being called twice.""" 

262 # Create the static schema in a new, empty database. 

263 newDatabase = self.makeEmptyDatabase() 

264 with newDatabase.declareStaticTables(create=True) as context: 

265 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

266 self.checkStaticSchema(tables) 

267 # Second time it should raise 

268 with self.assertRaises(SchemaAlreadyDefinedError): 

269 with newDatabase.declareStaticTables(create=True) as context: 

270 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

271 # Check schema, it should still contain all tables, and maybe some 

272 # extra. 

273 with newDatabase.declareStaticTables(create=False) as context: 

274 self.assertLessEqual(frozenset(STATIC_TABLE_SPECS._fields), context._tableNames) 

275 

276 def testRepr(self): 

277 """Test that repr does not return a generic thing.""" 

278 newDatabase = self.makeEmptyDatabase() 

279 rep = repr(newDatabase) 

280 # Check that stringification works and gives us something different 

281 self.assertNotEqual(rep, str(newDatabase)) 

282 self.assertNotIn("object at 0x", rep, "Check default repr was not used") 

283 self.assertIn("://", rep) 

284 

    def testDynamicTables(self):
        """Tests for `Database.ensureTableExists` and
        `Database.getExistingTable`.
        """
        # Need to start with the static schema.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            context.addTableTuple(STATIC_TABLE_SPECS)
        # Try to ensure the dynamic table exists in a read-only version of
        # that database, which should fail because we can't create it.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                context.addTableTuple(STATIC_TABLE_SPECS)
            with self.assertRaises(ReadOnlyDatabaseError):
                existingReadOnlyDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        # Just getting the dynamic table before it exists should return None.
        self.assertIsNone(newDatabase.getExistingTable("d", DYNAMIC_TABLE_SPEC))
        # Ensure the new table exists back in the original database, which
        # should create it.
        table = newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        self.checkTable(DYNAMIC_TABLE_SPEC, table)
        # Ensuring that it exists should just return the exact same table
        # instance again.
        self.assertIs(newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC), table)
        # Try again from the read-only database.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                context.addTableTuple(STATIC_TABLE_SPECS)
            # Just getting the dynamic table should now work...
            self.assertIsNotNone(existingReadOnlyDatabase.getExistingTable("d", DYNAMIC_TABLE_SPEC))
            # ...as should ensuring that it exists, since it now does.
            existingReadOnlyDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
            self.checkTable(DYNAMIC_TABLE_SPEC, table)
        # Trying to get the table with a different specification (at least
        # in terms of what columns are present) should raise.
        with self.assertRaises(DatabaseConflictError):
            newDatabase.ensureTableExists(
                "d",
                ddl.TableSpec(
                    fields=[ddl.FieldSpec("name", dtype=sqlalchemy.String, length=4, primaryKey=True)]
                ),
            )
        # Calling ensureTableExists inside a transaction block is an error,
        # even if it would do nothing.
        with newDatabase.transaction():
            with self.assertRaises(AssertionError):
                newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)

332 

333 def testDynamicTablesConcurrency(self): 

334 """Tests for `Database.ensureTableExists` concurrent use.""" 

335 # We cannot really run things concurrently in a deterministic way, here 

336 # we just simulate a situation when the table is created by other 

337 # process between the call to getExistingTable() and actual table 

338 # creation. 

339 db1 = self.makeEmptyDatabase() 

340 with db1.declareStaticTables(create=True) as context: 

341 context.addTableTuple(STATIC_TABLE_SPECS) 

342 self.assertIsNone(db1.getExistingTable("d", DYNAMIC_TABLE_SPEC)) 

343 

344 # Make a dynamic table using separate connection 

345 db2 = self.getNewConnection(db1, writeable=True) 

346 with db2.declareStaticTables(create=False) as context: 

347 context.addTableTuple(STATIC_TABLE_SPECS) 

348 table = db2.ensureTableExists("d", DYNAMIC_TABLE_SPEC) 

349 self.checkTable(DYNAMIC_TABLE_SPEC, table) 

350 

351 # Call it again but trick it into thinking that table is not there. 

352 # This test depends on knowing implementation of ensureTableExists() 

353 # which initially calls getExistingTable() to check that table may 

354 # exist, the patch intercepts that call and returns None. 

355 with _patch_getExistingTable(db1): 

356 table = db1.ensureTableExists("d", DYNAMIC_TABLE_SPEC) 

357 

    def testTemporaryTables(self):
        """Tests for `Database.temporary_table`, and `Database.insert` with
        the ``select`` argument.
        """
        # Need to start with the static schema; also insert some test data.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            static = context.addTableTuple(STATIC_TABLE_SPECS)
        newDatabase.insert(static.a, {"name": "a1", "region": None}, {"name": "a2", "region": None})
        bIds = newDatabase.insert(
            static.b,
            {"name": "b1", "value": 11},
            {"name": "b2", "value": 12},
            {"name": "b3", "value": 13},
            returnIds=True,
        )
        # Create the temporary table inside a session.
        with newDatabase.session():
            with newDatabase.temporary_table(TEMPORARY_TABLE_SPEC, "e1") as table1:
                self.checkTable(TEMPORARY_TABLE_SPEC, table1)
                # Insert via an INSERT INTO ... SELECT query; the cross join
                # (ON TRUE) pairs every "a" row with every "b" row before the
                # WHERE clause filters.
                newDatabase.insert(
                    table1,
                    select=sqlalchemy.sql.select(
                        static.a.columns.name.label("a_name"), static.b.columns.id.label("b_id")
                    )
                    .select_from(static.a.join(static.b, onclause=sqlalchemy.sql.literal(True)))
                    .where(
                        sqlalchemy.sql.and_(
                            static.a.columns.name == "a1",
                            static.b.columns.value <= 12,
                        )
                    ),
                )
                # Check that the inserted rows are present.
                self.assertCountEqual(
                    [{"a_name": "a1", "b_id": bId} for bId in bIds[:2]],
                    [row._asdict() for row in self.query_list(newDatabase, table1.select())],
                )
                # Create another one via a read-only connection to the
                # database. We _do_ allow temporary table modifications in
                # read-only databases.
                with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
                    with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                        context.addTableTuple(STATIC_TABLE_SPECS)
                    with existingReadOnlyDatabase.temporary_table(TEMPORARY_TABLE_SPEC) as table2:
                        self.checkTable(TEMPORARY_TABLE_SPEC, table2)
                        # Those tables should not be the same, despite having
                        # the same ddl.
                        self.assertIsNot(table1, table2)
                        # Do a slightly different insert into this table, to
                        # check that it works in a read-only database. This
                        # time we pass column names as a kwarg to insert
                        # instead of by labeling the columns in the select.
                        existingReadOnlyDatabase.insert(
                            table2,
                            select=sqlalchemy.sql.select(static.a.columns.name, static.b.columns.id)
                            .select_from(static.a.join(static.b, onclause=sqlalchemy.sql.literal(True)))
                            .where(
                                sqlalchemy.sql.and_(
                                    static.a.columns.name == "a2",
                                    static.b.columns.value >= 12,
                                )
                            ),
                            names=["a_name", "b_id"],
                        )
                        # Check that the inserted rows are present.
                        self.assertCountEqual(
                            [{"a_name": "a2", "b_id": bId} for bId in bIds[1:]],
                            [row._asdict() for row in self.query_list(existingReadOnlyDatabase, table2.select())],
                        )
                # Exiting the context managers will drop the temporary tables
                # from the read-only DB. It's unspecified whether attempting
                # to use it after this point is an error or just never
                # returns any results, so we can't test what it does, only
                # that it's not an error.

433 

434 def testSchemaSeparation(self): 

435 """Test that creating two different `Database` instances allows us 

436 to create different tables with the same name in each. 

437 """ 

438 db1 = self.makeEmptyDatabase(origin=1) 

439 with db1.declareStaticTables(create=True) as context: 

440 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

441 self.checkStaticSchema(tables) 

442 

443 db2 = self.makeEmptyDatabase(origin=2) 

444 # Make the DDL here intentionally different so we'll definitely 

445 # notice if db1 and db2 are pointing at the same schema. 

446 spec = ddl.TableSpec(fields=[ddl.FieldSpec("id", dtype=sqlalchemy.Integer, primaryKey=True)]) 

447 with db2.declareStaticTables(create=True) as context: 

448 # Make the DDL here intentionally different so we'll definitely 

449 # notice if db1 and db2 are pointing at the same schema. 

450 table = context.addTable("a", spec) 

451 self.checkTable(spec, table) 

452 

    def testInsertQueryDelete(self):
        """Test the `Database.insert`, `Database.query`, and `Database.delete`
        methods, as well as the `Base64Region` type and the ``onDelete``
        argument to `ddl.ForeignKeySpec`.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert a single, non-autoincrement row that contains a region and
        # query to get it back.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row = {"name": "a1", "region": region}
        db.insert(tables.a, row)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row])
        # Insert multiple autoincrement rows but do not try to get the IDs
        # back immediately.
        db.insert(tables.b, {"name": "b1", "value": 10}, {"name": "b2", "value": 20})
        results = [r._asdict() for r in self.query_list(db, tables.b.select().order_by("id"))]
        self.assertEqual(len(results), 2)
        for row in results:
            self.assertIn(row["name"], ("b1", "b2"))
            self.assertIsInstance(row["id"], int)
        self.assertGreater(results[1]["id"], results[0]["id"])
        # Insert multiple autoincrement rows and get the IDs back from insert.
        rows = [{"name": "b3", "value": 30}, {"name": "b4", "value": 40}]
        ids = db.insert(tables.b, *rows, returnIds=True)
        results = [
            r._asdict()
            for r in self.query_list(db, tables.b.select().where(tables.b.columns.id > results[1]["id"]))
        ]
        expected = [dict(row, id=id) for row, id in zip(rows, ids, strict=True)]
        self.assertCountEqual(results, expected)
        self.assertTrue(all(result["id"] is not None for result in results))
        # Insert multiple rows into a table with an autoincrement primary key,
        # then use the returned IDs to insert into a dynamic table.
        rows = [{"b_id": results[0]["id"]}, {"b_id": None}]
        ids = db.insert(tables.c, *rows, returnIds=True)
        results = [r._asdict() for r in self.query_list(db, tables.c.select())]
        expected = [dict(row, id=id) for row, id in zip(rows, ids, strict=True)]
        self.assertCountEqual(results, expected)
        self.assertTrue(all(result["id"] is not None for result in results))
        # Add the dynamic table.
        d = db.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        # Insert into it.
        rows = [{"c_id": id, "a_name": "a1"} for id in ids]
        db.insert(d, *rows)
        results = [r._asdict() for r in self.query_list(db, d.select())]
        self.assertCountEqual(rows, results)
        # Insert multiple rows into a table with an autoincrement primary key,
        # but pass in a value for the autoincrement key.
        rows2 = [
            {"id": 700, "b_id": None},
            {"id": 701, "b_id": None},
        ]
        db.insert(tables.c, *rows2)
        results = [r._asdict() for r in self.query_list(db, tables.c.select())]
        self.assertCountEqual(results, expected + rows2)
        self.assertTrue(all(result["id"] is not None for result in results))

        # Define 'SELECT COUNT(*)' query for later use.
        count = sqlalchemy.sql.select(sqlalchemy.sql.func.count())
        # Get the values we inserted into table b.
        bValues = [r._asdict() for r in self.query_list(db, tables.b.select())]
        # Remove two rows from table b by ID.
        n = db.delete(tables.b, ["id"], {"id": bValues[0]["id"]}, {"id": bValues[1]["id"]})
        self.assertEqual(n, 2)
        # Remove the other two rows from table b by name.
        n = db.delete(tables.b, ["name"], {"name": bValues[2]["name"]}, {"name": bValues[3]["name"]})
        self.assertEqual(n, 2)
        # There should now be no rows in table b.
        self.assertEqual(self.query_scalar(db, count.select_from(tables.b)), 0)
        # All b_id values in table c should now be NULL, because there's an
        # onDelete='SET NULL' foreign key.
        self.assertEqual(
            self.query_scalar(
                db,
                count.select_from(tables.c).where(tables.c.columns.b_id != None),  # noqa:E711
            ),
            0,
        )
        # Remove all rows in table a (there's only one); this should remove
        # all rows in d due to onDelete='CASCADE'.
        n = db.delete(tables.a, [])
        self.assertEqual(n, 1)
        self.assertEqual(self.query_scalar(db, count.select_from(tables.a)), 0)
        self.assertEqual(self.query_scalar(db, count.select_from(d)), 0)

539 

540 def testDeleteWhere(self): 

541 """Tests for `Database.deleteWhere`.""" 

542 db = self.makeEmptyDatabase(origin=1) 

543 with db.declareStaticTables(create=True) as context: 

544 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

545 db.insert(tables.b, *[{"id": i, "name": f"b{i}"} for i in range(10)]) 

546 count = sqlalchemy.sql.select(sqlalchemy.sql.func.count()) 

547 

548 n = db.deleteWhere(tables.b, tables.b.columns.id.in_([0, 1, 2])) 

549 self.assertEqual(n, 3) 

550 self.assertEqual(self.query_scalar(db, count.select_from(tables.b)), 7) 

551 

552 n = db.deleteWhere( 

553 tables.b, 

554 tables.b.columns.id.in_( 

555 sqlalchemy.sql.select(tables.b.columns.id).where(tables.b.columns.id > 5) 

556 ), 

557 ) 

558 self.assertEqual(n, 4) 

559 self.assertEqual(self.query_scalar(db, count.select_from(tables.b)), 3) 

560 

561 n = db.deleteWhere(tables.b, tables.b.columns.name == "b5") 

562 self.assertEqual(n, 1) 

563 self.assertEqual(self.query_scalar(db, count.select_from(tables.b)), 2) 

564 

565 n = db.deleteWhere(tables.b, sqlalchemy.sql.literal(True)) 

566 self.assertEqual(n, 2) 

567 self.assertEqual(self.query_scalar(db, count.select_from(tables.b)), 0) 

568 

569 def testUpdate(self): 

570 """Tests for `Database.update`.""" 

571 db = self.makeEmptyDatabase(origin=1) 

572 with db.declareStaticTables(create=True) as context: 

573 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

574 # Insert two rows into table a, both without regions. 

575 db.insert(tables.a, {"name": "a1"}, {"name": "a2"}) 

576 # Update one of the rows with a region. 

577 region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1))) 

578 n = db.update(tables.a, {"name": "k"}, {"k": "a2", "region": region}) 

579 self.assertEqual(n, 1) 

580 sql = sqlalchemy.sql.select(tables.a.columns.name, tables.a.columns.region).select_from(tables.a) 

581 self.assertCountEqual( 

582 [r._asdict() for r in self.query_list(db, sql)], 

583 [{"name": "a1", "region": None}, {"name": "a2", "region": region}], 

584 ) 

585 

    def testSync(self):
        """Tests for `Database.sync`."""
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert a row with sync, because it doesn't exist yet.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10}, returning=["id"])
        self.assertTrue(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )
        # Repeat that operation, which should do nothing but return the
        # requested values.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10}, returning=["id"])
        self.assertFalse(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )
        # Repeat the operation without the 'extra' arg, which should also just
        # return the existing row.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, returning=["id"])
        self.assertFalse(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )
        # Repeat the operation with a different value in 'extra'.  That still
        # shouldn't be an error, because 'extra' is only used if we really do
        # insert.  Also drop the 'returning' argument.
        _, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 20})
        self.assertFalse(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )
        # Repeat the operation with the correct value in 'compared' instead of
        # 'extra'.
        _, inserted = db.sync(tables.b, keys={"name": "b1"}, compared={"value": 10})
        self.assertFalse(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )
        # Repeat the operation with an incorrect value in 'compared'; this
        # should raise.
        with self.assertRaises(DatabaseConflictError):
            db.sync(tables.b, keys={"name": "b1"}, compared={"value": 20})
        # Try to sync in a read-only database.  This should work if and only
        # if the matching row already exists.
        with self.asReadOnly(db) as rodb:
            with rodb.declareStaticTables(create=False) as context:
                tables = context.addTableTuple(STATIC_TABLE_SPECS)
            _, inserted = rodb.sync(tables.b, keys={"name": "b1"})
            self.assertFalse(inserted)
            self.assertEqual(
                [{"id": values["id"], "name": "b1", "value": 10}],
                [r._asdict() for r in self.query_list(rodb, tables.b.select())],
            )
            with self.assertRaises(ReadOnlyDatabaseError):
                rodb.sync(tables.b, keys={"name": "b2"}, extra={"value": 20})
        # Repeat the operation with a different value in 'compared' and ask to
        # update; 'updated' reports the values that were replaced.
        _, updated = db.sync(tables.b, keys={"name": "b1"}, compared={"value": 20}, update=True)
        self.assertEqual(updated, {"value": 10})
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 20}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )

656 

657 def testReplace(self): 

658 """Tests for `Database.replace`.""" 

659 db = self.makeEmptyDatabase(origin=1) 

660 with db.declareStaticTables(create=True) as context: 

661 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

662 # Use 'replace' to insert a single row that contains a region and 

663 # query to get it back. 

664 region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1))) 

665 row1 = {"name": "a1", "region": region} 

666 db.replace(tables.a, row1) 

667 self.assertEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1]) 

668 # Insert another row without a region. 

669 row2 = {"name": "a2", "region": None} 

670 db.replace(tables.a, row2) 

671 self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2]) 

672 # Use replace to re-insert both of those rows again, which should do 

673 # nothing. 

674 db.replace(tables.a, row1, row2) 

675 self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2]) 

676 # Replace row1 with a row with no region, while reinserting row2. 

677 row1a = {"name": "a1", "region": None} 

678 db.replace(tables.a, row1a, row2) 

679 self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1a, row2]) 

680 # Replace both rows, returning row1 to its original state, while adding 

681 # a new one. Pass them in in a different order. 

682 row2a = {"name": "a2", "region": region} 

683 row3 = {"name": "a3", "region": None} 

684 db.replace(tables.a, row3, row2a, row1) 

685 self.assertCountEqual( 

686 [r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2a, row3] 

687 ) 

688 

689 def test_replace_pkey_only(self): 

690 """Test `Database.replace` on a table that only has primary key.""" 

691 spec = ddl.TableSpec( 

692 [ 

693 ddl.FieldSpec("a1", dtype=sqlalchemy.BigInteger, primaryKey=True), 

694 ddl.FieldSpec("a2", dtype=sqlalchemy.BigInteger, primaryKey=True), 

695 ] 

696 ) 

697 db = self.makeEmptyDatabase(origin=1) 

698 with db.declareStaticTables(create=True) as context: 

699 table = context.addTable("a", spec) 

700 row1 = {"a1": 1, "a2": 2} 

701 row2 = {"a1": 1, "a2": 3} 

702 db.replace(table, row1) 

703 db.replace(table, row2) 

704 db.replace(table, row1) 

705 self.assertCountEqual([r._asdict() for r in self.query_list(db, table.select())], [row1, row2]) 

706 

    def testEnsure(self):
        """Tests for `Database.ensure`."""
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Use 'ensure' to insert a single row that contains a region and
        # query to get it back.  The return value is the number of rows
        # actually inserted.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row1 = {"name": "a1", "region": region}
        self.assertEqual(db.ensure(tables.a, row1), 1)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1])
        # Insert another row without a region.
        row2 = {"name": "a2", "region": None}
        self.assertEqual(db.ensure(tables.a, row2), 1)
        self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2])
        # Use ensure to re-insert both of those rows again, which should do
        # nothing (return value 0: no rows inserted).
        self.assertEqual(db.ensure(tables.a, row1, row2), 0)
        self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2])
        # Attempt to insert row1's key with no region, while
        # reinserting row2.  This should also do nothing: 'ensure' never
        # modifies existing rows.
        row1a = {"name": "a1", "region": None}
        self.assertEqual(db.ensure(tables.a, row1a, row2), 0)
        self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2])
        # Attempt to insert new rows for both existing keys, this time also
        # adding a new row.  Pass them in in a different order.  Only the new
        # row should be added.
        row2a = {"name": "a2", "region": region}
        row3 = {"name": "a3", "region": None}
        self.assertEqual(db.ensure(tables.a, row3, row2a, row1a), 1)
        self.assertCountEqual(
            [r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2, row3]
        )
        # Add some data to a table with both a primary key and a different
        # unique constraint.
        row_b = {"id": 5, "name": "five", "value": 50}
        db.insert(tables.b, row_b)
        # Attempt ensure with primary_key_only=False (the default) and a
        # conflict for the non-PK constraint.  This should do nothing.
        db.ensure(tables.b, {"id": 10, "name": "five", "value": 200})
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.b.select())], [row_b])
        # Now use primary_key_only=True with conflict in only the non-PK field.
        # This should be an integrity error and nothing should change.
        with self.assertRaises(sqlalchemy.exc.IntegrityError):
            db.ensure(tables.b, {"id": 10, "name": "five", "value": 200}, primary_key_only=True)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.b.select())], [row_b])
        # With primary_key_only=True a conflict in the primary key is ignored
        # regardless of whether there is a conflict elsewhere.
        db.ensure(tables.b, {"id": 5, "name": "ten", "value": 100}, primary_key_only=True)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.b.select())], [row_b])
        db.ensure(tables.b, {"id": 5, "name": "five", "value": 100}, primary_key_only=True)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.b.select())], [row_b])

    def testTransactionNesting(self):
        """Test that transactions can be nested with the behavior in the
        presence of exceptions working as documented.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert one row so we can trigger integrity errors by trying to insert
        # a duplicate of it below.
        db.insert(tables.a, {"name": "a1"})
        # First test: error recovery via explicit savepoint=True in the inner
        # transaction.
        with db.transaction():
            # This insert should succeed, and should not be rolled back because
            # the assertRaises context should catch any exception before it
            # propagates up to the outer transaction.
            db.insert(tables.a, {"name": "a2"})
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                with db.transaction(savepoint=True):
                    # This insert should succeed, but should be rolled back.
                    db.insert(tables.a, {"name": "a4"})
                    # This insert should fail (duplicate primary key), raising
                    # an exception.
                    db.insert(tables.a, {"name": "a1"})
        # Only "a1" (pre-existing) and "a2" (outer transaction) survive; the
        # savepoint rollback removed "a4".
        self.assertCountEqual(
            [r._asdict() for r in self.query_list(db, tables.a.select())],
            [{"name": "a1", "region": None}, {"name": "a2", "region": None}],
        )
        # Second test: error recovery via implicit savepoint=True, when the
        # innermost transaction is inside a savepoint=True transaction.
        with db.transaction():
            # This insert should succeed, and should not be rolled back
            # because the assertRaises context should catch any
            # exception before it propagates up to the outer
            # transaction.
            db.insert(tables.a, {"name": "a3"})
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                with db.transaction(savepoint=True):
                    # This insert should succeed, but should be rolled back.
                    db.insert(tables.a, {"name": "a4"})
                    with db.transaction():
                        # This insert should succeed, but should be rolled
                        # back.
                        db.insert(tables.a, {"name": "a5"})
                        # This insert should fail (duplicate primary key),
                        # raising an exception.
                        db.insert(tables.a, {"name": "a1"})
        # Both "a4" and "a5" were rolled back with the savepoint; "a3" stays.
        self.assertCountEqual(
            [r._asdict() for r in self.query_list(db, tables.a.select())],
            [{"name": "a1", "region": None}, {"name": "a2", "region": None}, {"name": "a3", "region": None}],
        )

    def testTransactionLocking(self):
        """Test that `Database.transaction` can be used to acquire a lock
        that prohibits concurrent writes.
        """
        db1 = self.makeEmptyDatabase(origin=1)
        with db1.declareStaticTables(create=True) as context:
            tables1 = context.addTableTuple(STATIC_TABLE_SPECS)

        async def _side1(lock: Iterable[str] = ()) -> tuple[set[str], set[str]]:
            """One side of the concurrent locking test.

            Parameters
            ----------
            lock : `~collections.abc.Iterable`
                Locks to acquire, passed directly to `Database.transaction`.
                Despite the annotation, the caller below passes table objects
                rather than names.

            Notes
            -----
            This optionally locks the table (and maybe the whole database),
            does a select for its contents, inserts a new row, and then selects
            again, with some waiting in between to make sure the other side has
            a chance to _attempt_ to insert in between.  If the locking is
            enabled and works, the difference between the selects should just
            be the insert done on this thread.
            """
            # Give Side2 a chance to create a connection
            await asyncio.sleep(1.0)
            with db1.transaction(lock=lock):
                names1 = {row.name for row in self.query_list(db1, tables1.a.select())}
                # Give Side2 a chance to insert (which will be blocked if
                # we've acquired a lock).
                await asyncio.sleep(2.0)
                db1.insert(tables1.a, {"name": "a1"})
                names2 = {row.name for row in self.query_list(db1, tables1.a.select())}
            return names1, names2

        async def _side2() -> None:
            """Other side of the concurrent locking test.

            Notes
            -----
            This side just waits a bit and then tries to insert a row into the
            table that the other side is trying to lock.  Hopefully that
            waiting is enough to give the other side a chance to acquire the
            lock and thus make this side block until the lock is released.  If
            this side manages to do the insert before side1 acquires the lock,
            we'll just warn about not succeeding at testing the locking,
            because we can only make that unlikely, not impossible.
            """

            def _toRunInThread():
                """Create new SQLite connection for use in thread.

                SQLite locking isn't asyncio-friendly unless we actually
                run it in another thread.  And SQLite gets very unhappy if
                we try to use a connection from multiple threads, so we have
                to create the new connection here instead of out in the main
                body of the test function.
                """
                db2 = self.getNewConnection(db1, writeable=True)
                with db2.declareStaticTables(create=False) as context:
                    tables2 = context.addTableTuple(STATIC_TABLE_SPECS)
                with db2.transaction():
                    db2.insert(tables2.a, {"name": "a2"})

            await asyncio.sleep(2.0)
            loop = asyncio.get_running_loop()
            # Run the blocking insert in a worker thread so this coroutine
            # does not stall the event loop while it waits on the lock.
            with ThreadPoolExecutor() as pool:
                await loop.run_in_executor(pool, _toRunInThread)

        async def _testProblemsWithNoLocking() -> None:
            """Run side1 and side2 with no locking, attempting to demonstrate
            the problem that locking is supposed to solve.  If we get unlucky
            with scheduling, side2 will just happen to insert after side1 is
            done, and we won't have anything definitive.  We just warn in that
            case because we really don't want spurious test failures.
            """
            task1 = asyncio.create_task(_side1())
            task2 = asyncio.create_task(_side2())

            names1, names2 = await task1
            await task2
            if "a2" in names1:
                warnings.warn(
                    "Unlucky scheduling in no-locking test: concurrent INSERT happened before first SELECT.",
                    stacklevel=1,
                )
                self.assertEqual(names1, {"a2"})
                self.assertEqual(names2, {"a1", "a2"})
            elif "a2" not in names2:
                warnings.warn(
                    "Unlucky scheduling in no-locking test: concurrent INSERT "
                    "happened after second SELECT even without locking.",
                    stacklevel=1,
                )
                self.assertEqual(names1, set())
                self.assertEqual(names2, {"a1"})
            else:
                # This is the expected case: both INSERTS happen between the
                # two SELECTS.  If we don't get this almost all of the time we
                # should adjust the sleep amounts.
                self.assertEqual(names1, set())
                self.assertEqual(names2, {"a1", "a2"})

        asyncio.run(_testProblemsWithNoLocking())

        # Clean up after first test.
        db1.delete(tables1.a, ["name"], {"name": "a1"}, {"name": "a2"})

        async def _testSolutionWithLocking() -> None:
            """Run side1 and side2 with locking, which should make side2 block
            its insert until side1 releases its lock.
            """
            task1 = asyncio.create_task(_side1(lock=[tables1.a]))
            task2 = asyncio.create_task(_side2())

            names1, names2 = await task1
            await task2
            if "a2" in names1:
                warnings.warn(
                    "Unlucky scheduling in locking test: concurrent INSERT happened before first SELECT.",
                    stacklevel=1,
                )
                self.assertEqual(names1, {"a2"})
                self.assertEqual(names2, {"a1", "a2"})
            else:
                # This is the expected case: the side2 INSERT happens after the
                # last SELECT on side1.  This can also happen due to unlucky
                # scheduling, and we have no way to detect that here, but the
                # similar "no-locking" test has at least some chance of being
                # affected by the same problem and warning about it.
                self.assertEqual(names1, set())
                self.assertEqual(names2, {"a1"})

        asyncio.run(_testSolutionWithLocking())

    def testTimespanDatabaseRepresentation(self):
        """Tests for `TimespanDatabaseRepresentation` and the `Database`
        methods that interact with it.
        """
        # Make some test timespans to play with, with the full suite of
        # topological relationships.
        start = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
        offset = astropy.time.TimeDelta(60, format="sec")
        timestamps = [start + offset * n for n in range(3)]
        aTimespans = [Timespan(begin=None, end=None)]
        aTimespans.extend(Timespan(begin=None, end=t) for t in timestamps)
        aTimespans.extend(Timespan(begin=t, end=None) for t in timestamps)
        aTimespans.extend(Timespan.fromInstant(t) for t in timestamps)
        aTimespans.append(Timespan.makeEmpty())
        aTimespans.extend(Timespan(begin=t1, end=t2) for t1, t2 in itertools.combinations(timestamps, 2))
        # Make another list of timespans that span the full range but don't
        # overlap.  This is a subset of the previous list.
        bTimespans = [Timespan(begin=None, end=timestamps[0])]
        bTimespans.extend(
            Timespan(begin=t1, end=t2) for t1, t2 in zip(timestamps[:-1], timestamps[1:], strict=True)
        )
        bTimespans.append(Timespan(begin=timestamps[-1], end=None))
        # Make a database and create a table with that database's timespan
        # representation.  This one will have no exclusion constraint and
        # a nullable timespan.
        db = self.makeEmptyDatabase(origin=1)
        TimespanReprClass = db.getTimespanRepresentation()
        aSpec = ddl.TableSpec(
            fields=[
                ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
            ],
        )
        for fieldSpec in TimespanReprClass.makeFieldSpecs(nullable=True):
            aSpec.fields.add(fieldSpec)
        with db.declareStaticTables(create=True) as context:
            aTable = context.addTable("a", aSpec)
        # Show full diffs on the large dict comparisons below.
        self.maxDiff = None

        def _convertRowForInsert(row: dict) -> dict:
            """Convert a row containing a Timespan instance into one suitable
            for insertion into the database (i.e. with the timespan expanded
            into its representation's columns).
            """
            result = row.copy()
            ts = result.pop(TimespanReprClass.NAME)
            return TimespanReprClass.update(ts, result=result)

        def _convertRowFromSelect(row: dict) -> dict:
            """Convert a row from the database into one containing a Timespan.

            Parameters
            ----------
            row : `dict`
                Original row.

            Returns
            -------
            row : `dict`
                The updated row.
            """
            result = row.copy()
            timespan = TimespanReprClass.extract(result)
            for name in TimespanReprClass.getFieldNames():
                del result[name]
            result[TimespanReprClass.NAME] = timespan
            return result

        # Insert rows into table A, in chunks just to make things interesting.
        # Include one with a NULL timespan.
        aRows = [{"id": n, TimespanReprClass.NAME: t} for n, t in enumerate(aTimespans)]
        aRows.append({"id": len(aRows), TimespanReprClass.NAME: None})
        db.insert(aTable, _convertRowForInsert(aRows[0]))
        db.insert(aTable, *[_convertRowForInsert(r) for r in aRows[1:3]])
        db.insert(aTable, *[_convertRowForInsert(r) for r in aRows[3:]])
        # Add another one with a NULL timespan, but this time by invoking
        # the server-side default (by not providing the timespan columns).
        aRows.append({"id": len(aRows)})
        db.insert(aTable, aRows[-1])
        aRows[-1][TimespanReprClass.NAME] = None
        # Test basic round-trip through database.
        self.assertEqual(
            aRows,
            [
                _convertRowFromSelect(row._asdict())
                for row in self.query_list(db, aTable.select().order_by(aTable.columns.id))
            ],
        )
        # Create another table B with a not-null timespan and (if the database
        # supports it), an exclusion constraint.  Use ensureTableExists this
        # time to check that mode of table creation vs. timespans.
        bSpec = ddl.TableSpec(
            fields=[
                ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
                ddl.FieldSpec(name="key", dtype=sqlalchemy.Integer, nullable=False),
            ],
        )
        for fieldSpec in TimespanReprClass.makeFieldSpecs(nullable=False):
            bSpec.fields.add(fieldSpec)
        if TimespanReprClass.hasExclusionConstraint():
            bSpec.exclusion.add(("key", TimespanReprClass))
        bTable = db.ensureTableExists("b", bSpec)
        # Insert rows into table B, again in chunks.  Each Timespan appears
        # twice, but with different values for the 'key' field (which should
        # still be okay for any exclusion constraint we may have defined).
        bRows = [{"id": n, "key": 1, TimespanReprClass.NAME: t} for n, t in enumerate(bTimespans)]
        offset = len(bRows)
        bRows.extend(
            {"id": n + offset, "key": 2, TimespanReprClass.NAME: t} for n, t in enumerate(bTimespans)
        )
        db.insert(bTable, *[_convertRowForInsert(r) for r in bRows[:2]])
        db.insert(bTable, _convertRowForInsert(bRows[2]))
        db.insert(bTable, *[_convertRowForInsert(r) for r in bRows[3:]])
        # Insert a row with no timespan into table B.  This should invoke the
        # server-side default, which is a timespan over (-∞, ∞).  We set
        # key=3 to avoid upsetting an exclusion constraint that might exist.
        bRows.append({"id": len(bRows), "key": 3})
        db.insert(bTable, bRows[-1])
        bRows[-1][TimespanReprClass.NAME] = Timespan(None, None)
        # Test basic round-trip through database.
        self.assertEqual(
            bRows,
            [
                _convertRowFromSelect(row._asdict())
                for row in self.query_list(db, bTable.select().order_by(bTable.columns.id))
            ],
        )
        # Test that we can't insert timespan=None into this table (the
        # timespan columns are not nullable here).
        with self.assertRaises(sqlalchemy.exc.IntegrityError):
            db.insert(
                bTable, _convertRowForInsert({"id": len(bRows), "key": 4, TimespanReprClass.NAME: None})
            )
        # IFF this database supports exclusion constraints, test that they
        # also prevent inserts.
        if TimespanReprClass.hasExclusionConstraint():
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                db.insert(
                    bTable,
                    _convertRowForInsert(
                        {"id": len(bRows), "key": 1, TimespanReprClass.NAME: Timespan(None, timestamps[1])}
                    ),
                )
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                db.insert(
                    bTable,
                    _convertRowForInsert(
                        {
                            "id": len(bRows),
                            "key": 1,
                            TimespanReprClass.NAME: Timespan(timestamps[0], timestamps[2]),
                        }
                    ),
                )
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                db.insert(
                    bTable,
                    _convertRowForInsert(
                        {"id": len(bRows), "key": 1, TimespanReprClass.NAME: Timespan(timestamps[2], None)}
                    ),
                )
        # Test NULL checks in SELECT queries, on both tables.
        aRepr = TimespanReprClass.from_columns(aTable.columns)
        self.assertEqual(
            [row[TimespanReprClass.NAME] is None for row in aRows],
            [
                row.f
                for row in self.query_list(
                    db, sqlalchemy.sql.select(aRepr.isNull().label("f")).order_by(aTable.columns.id)
                )
            ],
        )
        bRepr = TimespanReprClass.from_columns(bTable.columns)
        self.assertEqual(
            [False for row in bRows],
            [
                row.f
                for row in self.query_list(
                    db, sqlalchemy.sql.select(bRepr.isNull().label("f")).order_by(bTable.columns.id)
                )
            ],
        )
        # Test relationships expressions that relate in-database timespans to
        # Python-literal timespans, all from the more complete 'a' set; check
        # that this is consistent with Python-only relationship tests.
        for rhsRow in aRows:
            if rhsRow[TimespanReprClass.NAME] is None:
                continue
            with self.subTest(rhsRow=repr(rhsRow)):
                # Expected values computed purely in Python; NULL timespans
                # yield NULL (None) relationship results.
                expected = {}
                for lhsRow in aRows:
                    if lhsRow[TimespanReprClass.NAME] is None:
                        expected[lhsRow["id"]] = (None, None, None, None)
                    else:
                        expected[lhsRow["id"]] = (
                            lhsRow[TimespanReprClass.NAME].overlaps(rhsRow[TimespanReprClass.NAME]),
                            lhsRow[TimespanReprClass.NAME].contains(rhsRow[TimespanReprClass.NAME]),
                            lhsRow[TimespanReprClass.NAME] < rhsRow[TimespanReprClass.NAME],
                            lhsRow[TimespanReprClass.NAME] > rhsRow[TimespanReprClass.NAME],
                        )
                rhsRepr = TimespanReprClass.fromLiteral(rhsRow[TimespanReprClass.NAME])
                sql = sqlalchemy.sql.select(
                    aTable.columns.id.label("lhs"),
                    aRepr.overlaps(rhsRepr).label("overlaps"),
                    aRepr.contains(rhsRepr).label("contains"),
                    (aRepr < rhsRepr).label("less_than"),
                    (aRepr > rhsRepr).label("greater_than"),
                ).select_from(aTable)
                queried = {
                    row.lhs: (row.overlaps, row.contains, row.less_than, row.greater_than)
                    for row in self.query_list(db, sql)
                }
                self.assertEqual(expected, queried)
        # Test relationship expressions that relate in-database timespans to
        # each other, all from the more complete 'a' set; check that this is
        # consistent with Python-only relationship tests.
        expected = {}
        for lhs, rhs in itertools.product(aRows, aRows):
            lhsT = lhs[TimespanReprClass.NAME]
            rhsT = rhs[TimespanReprClass.NAME]
            if lhsT is not None and rhsT is not None:
                expected[lhs["id"], rhs["id"]] = (
                    lhsT.overlaps(rhsT),
                    lhsT.contains(rhsT),
                    lhsT < rhsT,
                    lhsT > rhsT,
                )
            else:
                expected[lhs["id"], rhs["id"]] = (None, None, None, None)
        # Cross-join the table with itself to compare all pairs in SQL.
        lhsSubquery = aTable.alias("lhs")
        rhsSubquery = aTable.alias("rhs")
        lhsRepr = TimespanReprClass.from_columns(lhsSubquery.columns)
        rhsRepr = TimespanReprClass.from_columns(rhsSubquery.columns)
        sql = sqlalchemy.sql.select(
            lhsSubquery.columns.id.label("lhs"),
            rhsSubquery.columns.id.label("rhs"),
            lhsRepr.overlaps(rhsRepr).label("overlaps"),
            lhsRepr.contains(rhsRepr).label("contains"),
            (lhsRepr < rhsRepr).label("less_than"),
            (lhsRepr > rhsRepr).label("greater_than"),
        ).select_from(lhsSubquery.join(rhsSubquery, onclause=sqlalchemy.sql.literal(True)))
        queried = {
            (row.lhs, row.rhs): (row.overlaps, row.contains, row.less_than, row.greater_than)
            for row in self.query_list(db, sql)
        }
        self.assertEqual(expected, queried)
        # Test relationship expressions between in-database timespans and
        # Python-literal instantaneous times.
        for t in timestamps:
            with self.subTest(t=repr(t)):
                expected = {}
                for lhsRow in aRows:
                    if lhsRow[TimespanReprClass.NAME] is None:
                        expected[lhsRow["id"]] = (None, None, None, None)
                    else:
                        expected[lhsRow["id"]] = (
                            lhsRow[TimespanReprClass.NAME].contains(t),
                            lhsRow[TimespanReprClass.NAME].overlaps(t),
                            lhsRow[TimespanReprClass.NAME] < t,
                            lhsRow[TimespanReprClass.NAME] > t,
                        )
                rhs = sqlalchemy.sql.literal(t, type_=ddl.AstropyTimeNsecTai)
                sql = sqlalchemy.sql.select(
                    aTable.columns.id.label("lhs"),
                    aRepr.contains(rhs).label("contains"),
                    aRepr.overlaps(rhs).label("overlaps_point"),
                    (aRepr < rhs).label("less_than"),
                    (aRepr > rhs).label("greater_than"),
                ).select_from(aTable)
                queried = {
                    row.lhs: (row.contains, row.overlaps_point, row.less_than, row.greater_than)
                    for row in self.query_list(db, sql)
                }
                self.assertEqual(expected, queried)

1220 def testConstantRows(self): 

1221 """Test Database.constant_rows.""" 

1222 new_db = self.makeEmptyDatabase() 

1223 with new_db.declareStaticTables(create=True) as context: 

1224 static = context.addTableTuple(STATIC_TABLE_SPECS) 

1225 b_ids = new_db.insert( 

1226 static.b, 

1227 {"name": "b1", "value": 11}, 

1228 {"name": "b2", "value": 12}, 

1229 {"name": "b3", "value": 13}, 

1230 returnIds=True, 

1231 ) 

1232 values_spec = ddl.TableSpec( 

1233 [ 

1234 ddl.FieldSpec(name="b", dtype=sqlalchemy.BigInteger), 

1235 ddl.FieldSpec(name="s", dtype=sqlalchemy.String(8)), 

1236 ddl.FieldSpec(name="r", dtype=ddl.Base64Region()), 

1237 ], 

1238 ) 

1239 values_data = [ 

1240 {"b": b_ids[0], "s": "b1", "r": None}, 

1241 {"b": b_ids[2], "s": "b3", "r": Circle.empty()}, 

1242 ] 

1243 values = new_db.constant_rows(values_spec.fields, *values_data) 

1244 select_values_alone = sqlalchemy.sql.select( 

1245 values.columns["b"], values.columns["s"], values.columns["r"] 

1246 ) 

1247 self.assertCountEqual( 

1248 [row._mapping for row in self.query_list(new_db, select_values_alone)], 

1249 values_data, 

1250 ) 

1251 select_values_joined = sqlalchemy.sql.select( 

1252 values.columns["s"].label("name"), static.b.columns["value"].label("value") 

1253 ).select_from(values.join(static.b, onclause=static.b.columns["id"] == values.columns["b"])) 

1254 self.assertCountEqual( 

1255 [row._mapping for row in self.query_list(new_db, select_values_joined)], 

1256 [{"value": 11, "name": "b1"}, {"value": 13, "name": "b3"}], 

1257 ) 

1258 

    def test_aggregate(self) -> None:
        """Test Database.apply_any_aggregate, ddl.Base64Region.union_aggregate,
        and TimespanDatabaseRepresentation.apply_any_aggregate.
        """
        db = self.makeEmptyDatabase()
        with db.declareStaticTables(create=True) as context:
            # Note: no primary key, so duplicate 'id' values are allowed;
            # that is exercised deliberately below.
            t = context.addTable(
                "t",
                ddl.TableSpec(
                    [
                        ddl.FieldSpec("id", sqlalchemy.BigInteger(), nullable=False),
                        ddl.FieldSpec("name", sqlalchemy.String(16), nullable=False),
                        ddl.FieldSpec.for_region(),
                    ]
                    + list(db.getTimespanRepresentation().makeFieldSpecs(nullable=True)),
                ),
            )
        pixelization = Mq3cPixelization(10)
        start = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
        offset = astropy.time.TimeDelta(60, format="sec")
        timespans = [Timespan(begin=start + offset * n, end=start + offset * (n + 1)) for n in range(3)]
        ts_cls = db.getTimespanRepresentation()
        ts_col = ts_cls.from_columns(t.columns)
        # Two rows share id=3 (with different regions) so the GROUP BY /
        # DISTINCT ON and union_aggregate queries have something to collapse.
        db.insert(
            t,
            ts_cls.update(timespans[0], result={"id": 1, "name": "a", "region": pixelization.quad(12058870)}),
            ts_cls.update(timespans[1], result={"id": 2, "name": "a", "region": pixelization.quad(12058871)}),
            ts_cls.update(timespans[2], result={"id": 3, "name": "b", "region": pixelization.quad(12058872)}),
            ts_cls.update(timespans[2], result={"id": 3, "name": "b", "region": pixelization.quad(12058873)}),
        )
        # This should use DISTINCT ON in PostgreSQL and GROUP BY in SQLite.
        if db.has_distinct_on:
            sql = (
                sqlalchemy.select(
                    t.c.id.label("i"),
                    t.c.name.label("n"),
                    *ts_col.flatten("t"),
                )
                .select_from(t)
                .distinct(t.c.id)
            )
        elif db.has_any_aggregate:
            sql = (
                sqlalchemy.select(
                    t.c.id.label("i"),
                    db.apply_any_aggregate(t.c.name).label("n"),
                    *ts_col.apply_any_aggregate(db.apply_any_aggregate).flatten("t"),
                )
                .select_from(t)
                .group_by(t.c.id)
            )
        else:
            raise AssertionError(
                "PostgreSQL should support DISTINCT ON and SQLite should support no-op any aggregates."
            )
        self.assertCountEqual(
            [(row.i, row.n, ts_cls.extract(row._mapping, "t")) for row in self.query_list(db, sql)],
            [(1, "a", timespans[0]), (2, "a", timespans[1]), (3, "b", timespans[2])],
        )
        # Test union_aggregate in all versions of both database, with a GROUP
        # BY that does not need apply_any_aggregate.
        self.assertCountEqual(
            [
                (row.i, row.r.encode())
                for row in self.query_list(
                    db,
                    sqlalchemy.select(
                        t.c.id.label("i"), ddl.Base64Region.union_aggregate(t.c.region).label("r")
                    )
                    .select_from(t)
                    .group_by(t.c.id),
                )
            ],
            [
                (1, pixelization.quad(12058870).encode()),
                (2, pixelization.quad(12058871).encode()),
                (3, UnionRegion(pixelization.quad(12058872), pixelization.quad(12058873)).encode()),
            ],
        )
        if db.has_any_aggregate:
            # This should run in SQLite and PostgreSQL 16+.
            self.assertCountEqual(
                [
                    (row.i, row.n, row.r.encode())
                    for row in self.query_list(
                        db,
                        sqlalchemy.select(
                            t.c.id.label("i"),
                            db.apply_any_aggregate(t.c.name).label("n"),
                            ddl.Base64Region.union_aggregate(t.c.region).label("r"),
                        )
                        .select_from(t)
                        .group_by(t.c.id),
                    )
                ],
                [
                    (1, "a", pixelization.quad(12058870).encode()),
                    (2, "a", pixelization.quad(12058871).encode()),
                    (3, "b", UnionRegion(pixelization.quad(12058872), pixelization.quad(12058873)).encode()),
                ],
            )