Coverage for python/lsst/daf/butler/registry/tests/_database.py: 7%

511 statements  

coverage.py v7.5.0, created at 2024-05-02 10:24 +0000

# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

from ... import ddl

__all__ = ["DatabaseTests"]

import asyncio
import itertools
import warnings
from abc import ABC, abstractmethod
from collections import namedtuple
from collections.abc import Iterable, Iterator
from concurrent.futures import ThreadPoolExecutor
from contextlib import AbstractContextManager, contextmanager
from typing import Any

import astropy.time
import sqlalchemy
from lsst.sphgeom import Circle, ConvexPolygon, Mq3cPixelization, UnionRegion, UnitVector3d

from ..._timespan import Timespan
from ..interfaces import Database, DatabaseConflictError, ReadOnlyDatabaseError, SchemaAlreadyDefinedError

StaticTablesTuple = namedtuple("StaticTablesTuple", ["a", "b", "c"])

STATIC_TABLE_SPECS = StaticTablesTuple(
    a=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("name", dtype=sqlalchemy.String, length=16, primaryKey=True),
            ddl.FieldSpec("region", dtype=ddl.Base64Region, nbytes=128),
        ]
    ),
    b=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("id", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True),
            ddl.FieldSpec("name", dtype=sqlalchemy.String, length=16, nullable=False),
            ddl.FieldSpec("value", dtype=sqlalchemy.SmallInteger, nullable=True),
        ],
        unique=[("name",)],
    ),
    c=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("id", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True),
            ddl.FieldSpec("b_id", dtype=sqlalchemy.BigInteger, nullable=True),
        ],
        foreignKeys=[
            ddl.ForeignKeySpec("b", source=("b_id",), target=("id",), onDelete="SET NULL"),
        ],
    ),
)
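# The static test schema: "a" has a string primary key and an optional
# region; "b" has an autoincrement primary key plus a unique constraint on
# "name"; "c" has an autoincrement primary key and a nullable foreign key
# into "b" whose onDelete="SET NULL" behavior is exercised in
# testInsertQueryDelete.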

DYNAMIC_TABLE_SPEC = ddl.TableSpec(
    fields=[
        ddl.FieldSpec("c_id", dtype=sqlalchemy.BigInteger, primaryKey=True),
        ddl.FieldSpec("a_name", dtype=sqlalchemy.String, length=16, nullable=False),
    ],
    foreignKeys=[
        ddl.ForeignKeySpec("c", source=("c_id",), target=("id",), onDelete="CASCADE"),
        ddl.ForeignKeySpec("a", source=("a_name",), target=("name",), onDelete="CASCADE"),
    ],
)
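# Spec for the dynamic table "d" created at runtime via ensureTableExists();
# both foreign keys cascade deletes, which testInsertQueryDelete relies on
# when it empties table "a".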

TEMPORARY_TABLE_SPEC = ddl.TableSpec(
    fields=[
        ddl.FieldSpec("a_name", dtype=sqlalchemy.String, length=16, primaryKey=True),
        ddl.FieldSpec("b_id", dtype=sqlalchemy.BigInteger, primaryKey=True),
    ],
)
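# The temporary-table spec has no foreign keys: temporary tables generally
# cannot reference permanent tables, so a_name and b_id are plain columns
# checked only by the tests themselves.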


@contextmanager
def _patch_getExistingTable(db: Database) -> Iterator[Database]:
    """Patch the getExistingTable method in a database instance to test
    concurrent creation of tables. This patch obviously depends on knowing
    the internals of the ``ensureTableExists()`` implementation.
    """
    original_method = db.getExistingTable

    def _getExistingTable(name: str, spec: ddl.TableSpec) -> sqlalchemy.schema.Table | None:
        # Return None on first call, but forward to original method after that
        db.getExistingTable = original_method
        return None

    db.getExistingTable = _getExistingTable
    yield db
    db.getExistingTable = original_method
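
# Minimal usage sketch (mirroring testDynamicTablesConcurrency below, with
# ``db`` any Database whose static schema already exists): the patched
# getExistingTable() reports the table as missing on its first call, so
# ensureTableExists() takes its creation path and must cope with the table
# already existing:
#
#     with _patch_getExistingTable(db) as patched:
#         table = patched.ensureTableExists("d", DYNAMIC_TABLE_SPEC)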


class DatabaseTests(ABC):
    """Generic tests for the `Database` interface that can be subclassed to
    generate tests for concrete implementations.
    """

    @abstractmethod
    def makeEmptyDatabase(self, origin: int = 0) -> Database:
        """Return an empty `Database` with the given origin, or an
        automatically-generated one if ``origin`` is `None`.

        Parameters
        ----------
        origin : `int` or `None`
            Origin to use for the database.

        Returns
        -------
        db : `Database`
            Empty database with given origin or auto-generated origin.
        """
        raise NotImplementedError()

    @abstractmethod
    def asReadOnly(self, database: Database) -> AbstractContextManager[Database]:
        """Return a context manager for a read-only connection into the given
        database.

        Parameters
        ----------
        database : `Database`
            The database to use.

        Yields
        ------
        `Database`
            The new database connection.

        Notes
        -----
        The original database should be considered unusable within the context
        but safe to use again afterwards (this allows the context manager to
        block write access by temporarily changing user permissions to really
        guarantee that write operations are not performed).
        """
        raise NotImplementedError()

    @abstractmethod
    def getNewConnection(self, database: Database, *, writeable: bool) -> Database:
        """Return a new `Database` instance that points to the same underlying
        storage as the given one.

        Parameters
        ----------
        database : `Database`
            The current database.
        writeable : `bool`
            Whether the connection should be writeable or not.

        Returns
        -------
        db : `Database`
            The new database connection.
        """
        raise NotImplementedError()

    def query_list(
        self, database: Database, executable: sqlalchemy.sql.expression.SelectBase
    ) -> list[sqlalchemy.engine.Row]:
        """Run a SELECT or other read-only query against the database and
        return the results as a list.

        Parameters
        ----------
        database : `Database`
            The database to use.
        executable : `sqlalchemy.sql.expression.SelectBase`
            Expression to execute.

        Returns
        -------
        results : `list` of `sqlalchemy.engine.Row`
            The results.

        Notes
        -----
        This is a thin wrapper around database.query() that just avoids
        context-manager boilerplate that is usefully verbose in production
        code but just noise in tests.
        """
        with database.transaction(), database.query(executable) as result:
            return result.fetchall()

    def query_scalar(self, database: Database, executable: sqlalchemy.sql.expression.SelectBase) -> Any:
        """Run a SELECT query that yields a single column and row against the
        database and return its value.

        Parameters
        ----------
        database : `Database`
            The database to use.
        executable : `sqlalchemy.sql.expression.SelectBase`
            Expression to execute.

        Returns
        -------
        results : `~typing.Any`
            The results.

        Notes
        -----
        This is a thin wrapper around database.query() that just avoids
        context-manager boilerplate that is usefully verbose in production
        code but just noise in tests.
        """
        with database.query(executable) as result:
            return result.scalar()

    def checkTable(self, spec: ddl.TableSpec, table: sqlalchemy.schema.Table):
        self.assertCountEqual(spec.fields.names, table.columns.keys())
        # Checking more than this currently seems fragile, as it might
        # restrict what Database implementations do; we don't care if the
        # spec is actually preserved in terms of types and constraints as
        # long as we can use the returned table as if it were.

    def checkStaticSchema(self, tables: StaticTablesTuple):
        self.checkTable(STATIC_TABLE_SPECS.a, tables.a)
        self.checkTable(STATIC_TABLE_SPECS.b, tables.b)
        self.checkTable(STATIC_TABLE_SPECS.c, tables.c)

    def testDeclareStaticTables(self):
        """Tests for `Database.declareStaticSchema` and the methods it
        delegates to.
        """
        # Create the static schema in a new, empty database.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        self.checkStaticSchema(tables)
        # Check that we can load that schema even from a read-only connection.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                tables = context.addTableTuple(STATIC_TABLE_SPECS)
            self.checkStaticSchema(tables)

    def testDeclareStaticTablesTwice(self):
        """Tests for `Database.declareStaticSchema` being called twice."""
        # Create the static schema in a new, empty database.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        self.checkStaticSchema(tables)
        # The second time it should raise.
        with self.assertRaises(SchemaAlreadyDefinedError):
            with newDatabase.declareStaticTables(create=True) as context:
                tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Check the schema; it should still contain all tables, and maybe
        # some extra.
        with newDatabase.declareStaticTables(create=False) as context:
            self.assertLessEqual(frozenset(STATIC_TABLE_SPECS._fields), context._tableNames)

    def testRepr(self):
        """Test that repr does not return a generic thing."""
        newDatabase = self.makeEmptyDatabase()
        rep = repr(newDatabase)
        # Check that stringification works and gives us something different.
        self.assertNotEqual(rep, str(newDatabase))
        self.assertNotIn("object at 0x", rep, "Check default repr was not used")
        self.assertIn("://", rep)

    def testDynamicTables(self):
        """Tests for `Database.ensureTableExists` and
        `Database.getExistingTable`.
        """
        # Need to start with the static schema.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            context.addTableTuple(STATIC_TABLE_SPECS)
        # Try to ensure the dynamic table exists in a read-only version of
        # that database, which should fail because we can't create it.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                context.addTableTuple(STATIC_TABLE_SPECS)
            with self.assertRaises(ReadOnlyDatabaseError):
                existingReadOnlyDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        # Just getting the dynamic table before it exists should return None.
        self.assertIsNone(newDatabase.getExistingTable("d", DYNAMIC_TABLE_SPEC))
        # Ensure the new table exists back in the original database, which
        # should create it.
        table = newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        self.checkTable(DYNAMIC_TABLE_SPEC, table)
        # Ensuring that it exists should just return the exact same table
        # instance again.
        self.assertIs(newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC), table)
        # Try again from the read-only database.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                context.addTableTuple(STATIC_TABLE_SPECS)
            # Just getting the dynamic table should now work...
            self.assertIsNotNone(existingReadOnlyDatabase.getExistingTable("d", DYNAMIC_TABLE_SPEC))
            # ...as should ensuring that it exists, since it now does.
            existingReadOnlyDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
            self.checkTable(DYNAMIC_TABLE_SPEC, table)
        # Trying to get the table with a different specification (at least
        # in terms of what columns are present) should raise.
        with self.assertRaises(DatabaseConflictError):
            newDatabase.ensureTableExists(
                "d",
                ddl.TableSpec(
                    fields=[ddl.FieldSpec("name", dtype=sqlalchemy.String, length=4, primaryKey=True)]
                ),
            )
        # Calling ensureTableExists inside a transaction block is an error,
        # even if it would do nothing.
        with newDatabase.transaction():
            with self.assertRaises(AssertionError):
                newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)

    def testDynamicTablesConcurrency(self):
        """Tests for `Database.ensureTableExists` concurrent use."""
        # We cannot really run things concurrently in a deterministic way;
        # here we just simulate a situation where the table is created by
        # another process between the call to getExistingTable() and the
        # actual table creation.
        db1 = self.makeEmptyDatabase()
        with db1.declareStaticTables(create=True) as context:
            context.addTableTuple(STATIC_TABLE_SPECS)
        self.assertIsNone(db1.getExistingTable("d", DYNAMIC_TABLE_SPEC))

        # Make a dynamic table using a separate connection.
        db2 = self.getNewConnection(db1, writeable=True)
        with db2.declareStaticTables(create=False) as context:
            context.addTableTuple(STATIC_TABLE_SPECS)
        table = db2.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        self.checkTable(DYNAMIC_TABLE_SPEC, table)

        # Call it again but trick it into thinking that the table is not
        # there. This test depends on knowing the implementation of
        # ensureTableExists(), which initially calls getExistingTable() to
        # check whether the table may exist; the patch intercepts that call
        # and returns None.
        with _patch_getExistingTable(db1):
            table = db1.ensureTableExists("d", DYNAMIC_TABLE_SPEC)

    def testTemporaryTables(self):
        """Tests for `Database.temporary_table`, and `Database.insert` with
        the ``select`` argument.
        """
        # Need to start with the static schema; also insert some test data.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            static = context.addTableTuple(STATIC_TABLE_SPECS)
        newDatabase.insert(static.a, {"name": "a1", "region": None}, {"name": "a2", "region": None})
        bIds = newDatabase.insert(
            static.b,
            {"name": "b1", "value": 11},
            {"name": "b2", "value": 12},
            {"name": "b3", "value": 13},
            returnIds=True,
        )
        # Create the table.
        with newDatabase.session():
            with newDatabase.temporary_table(TEMPORARY_TABLE_SPEC, "e1") as table1:
                self.checkTable(TEMPORARY_TABLE_SPEC, table1)
                # Insert via an INSERT INTO ... SELECT query.
                newDatabase.insert(
                    table1,
                    select=sqlalchemy.sql.select(
                        static.a.columns.name.label("a_name"), static.b.columns.id.label("b_id")
                    )
                    .select_from(static.a.join(static.b, onclause=sqlalchemy.sql.literal(True)))
                    .where(
                        sqlalchemy.sql.and_(
                            static.a.columns.name == "a1",
                            static.b.columns.value <= 12,
                        )
                    ),
                )
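                # Joining with onclause=sqlalchemy.sql.literal(True) is a
                # portable SQLAlchemy spelling of a CROSS JOIN of a and b;
                # the WHERE clause above then keeps only the combinations of
                # "a1" with the first two b rows (value <= 12).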

                # Check that the inserted rows are present.
                self.assertCountEqual(
                    [{"a_name": "a1", "b_id": bId} for bId in bIds[:2]],
                    [row._asdict() for row in self.query_list(newDatabase, table1.select())],
                )
                # Create another one via a read-only connection to the
                # database. We _do_ allow temporary table modifications in
                # read-only databases.
                with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
                    with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                        context.addTableTuple(STATIC_TABLE_SPECS)
                    with existingReadOnlyDatabase.temporary_table(TEMPORARY_TABLE_SPEC) as table2:
                        self.checkTable(TEMPORARY_TABLE_SPEC, table2)
                        # Those tables should not be the same, despite having
                        # the same ddl.
                        self.assertIsNot(table1, table2)
                        # Do a slightly different insert into this table, to
                        # check that it works in a read-only database. This
                        # time we pass column names as a kwarg to insert
                        # instead of by labeling the columns in the select.
                        existingReadOnlyDatabase.insert(
                            table2,
                            select=sqlalchemy.sql.select(static.a.columns.name, static.b.columns.id)
                            .select_from(static.a.join(static.b, onclause=sqlalchemy.sql.literal(True)))
                            .where(
                                sqlalchemy.sql.and_(
                                    static.a.columns.name == "a2",
                                    static.b.columns.value >= 12,
                                )
                            ),
                            names=["a_name", "b_id"],
                        )
                        # Check that the inserted rows are present.
                        self.assertCountEqual(
                            [{"a_name": "a2", "b_id": bId} for bId in bIds[1:]],
                            [row._asdict() for row in self.query_list(existingReadOnlyDatabase, table2.select())],
                        )
        # Exiting the context managers will drop the temporary tables from
        # the read-only DB. It's unspecified whether attempting to use it
        # after this point is an error or just never returns any results,
        # so we can't test what it does, only that it's not an error.

    def testSchemaSeparation(self):
        """Test that creating two different `Database` instances allows us
        to create different tables with the same name in each.
        """
        db1 = self.makeEmptyDatabase(origin=1)
        with db1.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        self.checkStaticSchema(tables)

        db2 = self.makeEmptyDatabase(origin=2)
        # Make the DDL here intentionally different so we'll definitely
        # notice if db1 and db2 are pointing at the same schema.
        spec = ddl.TableSpec(fields=[ddl.FieldSpec("id", dtype=sqlalchemy.Integer, primaryKey=True)])
        with db2.declareStaticTables(create=True) as context:
            table = context.addTable("a", spec)
        self.checkTable(spec, table)

    def testInsertQueryDelete(self):
        """Test the `Database.insert`, `Database.query`, and `Database.delete`
        methods, as well as the `Base64Region` type and the ``onDelete``
        argument to `ddl.ForeignKeySpec`.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert a single, non-autoincrement row that contains a region and
        # query to get it back.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row = {"name": "a1", "region": region}
        db.insert(tables.a, row)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row])
        # Insert multiple autoincrement rows but do not try to get the IDs
        # back immediately.
        db.insert(tables.b, {"name": "b1", "value": 10}, {"name": "b2", "value": 20})
        results = [r._asdict() for r in self.query_list(db, tables.b.select().order_by("id"))]
        self.assertEqual(len(results), 2)
        for row in results:
            self.assertIn(row["name"], ("b1", "b2"))
            self.assertIsInstance(row["id"], int)
        self.assertGreater(results[1]["id"], results[0]["id"])
        # Insert multiple autoincrement rows and get the IDs back from insert.
        rows = [{"name": "b3", "value": 30}, {"name": "b4", "value": 40}]
        ids = db.insert(tables.b, *rows, returnIds=True)
        results = [
            r._asdict()
            for r in self.query_list(db, tables.b.select().where(tables.b.columns.id > results[1]["id"]))
        ]
        expected = [dict(row, id=id) for row, id in zip(rows, ids, strict=True)]
        self.assertCountEqual(results, expected)
        self.assertTrue(all(result["id"] is not None for result in results))
        # Insert multiple rows into a table with an autoincrement primary key,
        # then use the returned IDs to insert into a dynamic table.
        rows = [{"b_id": results[0]["id"]}, {"b_id": None}]
        ids = db.insert(tables.c, *rows, returnIds=True)
        results = [r._asdict() for r in self.query_list(db, tables.c.select())]
        expected = [dict(row, id=id) for row, id in zip(rows, ids, strict=True)]
        self.assertCountEqual(results, expected)
        self.assertTrue(all(result["id"] is not None for result in results))
        # Add the dynamic table.
        d = db.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        # Insert into it.
        rows = [{"c_id": id, "a_name": "a1"} for id in ids]
        db.insert(d, *rows)
        results = [r._asdict() for r in self.query_list(db, d.select())]
        self.assertCountEqual(rows, results)
        # Insert multiple rows into a table with an autoincrement primary key,
        # but pass in a value for the autoincrement key.
        rows2 = [
            {"id": 700, "b_id": None},
            {"id": 701, "b_id": None},
        ]
        db.insert(tables.c, *rows2)
        results = [r._asdict() for r in self.query_list(db, tables.c.select())]
        self.assertCountEqual(results, expected + rows2)
        self.assertTrue(all(result["id"] is not None for result in results))

        # Define 'SELECT COUNT(*)' query for later use.
        count = sqlalchemy.sql.select(sqlalchemy.sql.func.count())
        # Get the values we inserted into table b.
        bValues = [r._asdict() for r in self.query_list(db, tables.b.select())]
        # Remove two rows from table b by ID.
        n = db.delete(tables.b, ["id"], {"id": bValues[0]["id"]}, {"id": bValues[1]["id"]})
        self.assertEqual(n, 2)
        # Remove the other two rows from table b by name.
        n = db.delete(tables.b, ["name"], {"name": bValues[2]["name"]}, {"name": bValues[3]["name"]})
        self.assertEqual(n, 2)
        # There should now be no rows in table b.
        self.assertEqual(self.query_scalar(db, count.select_from(tables.b)), 0)
        # All b_id values in table c should now be NULL, because there's an
        # onDelete='SET NULL' foreign key.
        self.assertEqual(
            self.query_scalar(
                db,
                count.select_from(tables.c).where(tables.c.columns.b_id != None),  # noqa:E711
            ),
            0,
        )
        # Remove all rows in table a (there's only one); this should remove
        # all rows in d due to onDelete='CASCADE'.
        n = db.delete(tables.a, [])
        self.assertEqual(n, 1)
        self.assertEqual(self.query_scalar(db, count.select_from(tables.a)), 0)
        self.assertEqual(self.query_scalar(db, count.select_from(d)), 0)

    def testDeleteWhere(self):
        """Tests for `Database.deleteWhere`."""
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        db.insert(tables.b, *[{"id": i, "name": f"b{i}"} for i in range(10)])
        count = sqlalchemy.sql.select(sqlalchemy.sql.func.count())

        n = db.deleteWhere(tables.b, tables.b.columns.id.in_([0, 1, 2]))
        self.assertEqual(n, 3)
        self.assertEqual(self.query_scalar(db, count.select_from(tables.b)), 7)

        n = db.deleteWhere(
            tables.b,
            tables.b.columns.id.in_(
                sqlalchemy.sql.select(tables.b.columns.id).where(tables.b.columns.id > 5)
            ),
        )
        self.assertEqual(n, 4)
        self.assertEqual(self.query_scalar(db, count.select_from(tables.b)), 3)

        n = db.deleteWhere(tables.b, tables.b.columns.name == "b5")
        self.assertEqual(n, 1)
        self.assertEqual(self.query_scalar(db, count.select_from(tables.b)), 2)

        n = db.deleteWhere(tables.b, sqlalchemy.sql.literal(True))
        self.assertEqual(n, 2)
        self.assertEqual(self.query_scalar(db, count.select_from(tables.b)), 0)

    def testUpdate(self):
        """Tests for `Database.update`."""
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert two rows into table a, both without regions.
        db.insert(tables.a, {"name": "a1"}, {"name": "a2"})
        # Update one of the rows with a region.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        n = db.update(tables.a, {"name": "k"}, {"k": "a2", "region": region})
        self.assertEqual(n, 1)
        sql = sqlalchemy.sql.select(tables.a.columns.name, tables.a.columns.region).select_from(tables.a)
        self.assertCountEqual(
            [r._asdict() for r in self.query_list(db, sql)],
            [{"name": "a1", "region": None}, {"name": "a2", "region": region}],
        )

    def testSync(self):
        """Tests for `Database.sync`."""
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert a row with sync, because it doesn't exist yet.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10}, returning=["id"])
        self.assertTrue(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )
        # Repeat that operation, which should do nothing but return the
        # requested values.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10}, returning=["id"])
        self.assertFalse(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )
        # Repeat the operation without the 'extra' arg, which should also just
        # return the existing row.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, returning=["id"])
        self.assertFalse(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )
        # Repeat the operation with a different value in 'extra'. That still
        # shouldn't be an error, because 'extra' is only used if we really do
        # insert. Also drop the 'returning' argument.
        _, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 20})
        self.assertFalse(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )
        # Repeat the operation with the correct value in 'compared' instead of
        # 'extra'.
        _, inserted = db.sync(tables.b, keys={"name": "b1"}, compared={"value": 10})
        self.assertFalse(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )
        # Repeat the operation with an incorrect value in 'compared'; this
        # should raise.
        with self.assertRaises(DatabaseConflictError):
            db.sync(tables.b, keys={"name": "b1"}, compared={"value": 20})
        # Try to sync in a read-only database. This should work if and only
        # if the matching row already exists.
        with self.asReadOnly(db) as rodb:
            with rodb.declareStaticTables(create=False) as context:
                tables = context.addTableTuple(STATIC_TABLE_SPECS)
            _, inserted = rodb.sync(tables.b, keys={"name": "b1"})
            self.assertFalse(inserted)
            self.assertEqual(
                [{"id": values["id"], "name": "b1", "value": 10}],
                [r._asdict() for r in self.query_list(rodb, tables.b.select())],
            )
            with self.assertRaises(ReadOnlyDatabaseError):
                rodb.sync(tables.b, keys={"name": "b2"}, extra={"value": 20})
        # Repeat the operation with a different value in 'compared' and ask to
        # update.
        _, updated = db.sync(tables.b, keys={"name": "b1"}, compared={"value": 20}, update=True)
        self.assertEqual(updated, {"value": 10})
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 20}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )

    def testReplace(self):
        """Tests for `Database.replace`."""
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Use 'replace' to insert a single row that contains a region and
        # query to get it back.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row1 = {"name": "a1", "region": region}
        db.replace(tables.a, row1)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1])
        # Insert another row without a region.
        row2 = {"name": "a2", "region": None}
        db.replace(tables.a, row2)
        self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2])
        # Use replace to re-insert both of those rows again, which should do
        # nothing.
        db.replace(tables.a, row1, row2)
        self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2])
        # Replace row1 with a row with no region, while reinserting row2.
        row1a = {"name": "a1", "region": None}
        db.replace(tables.a, row1a, row2)
        self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1a, row2])
        # Replace both rows, returning row1 to its original state, while
        # adding a new one. Pass them in in a different order.
        row2a = {"name": "a2", "region": region}
        row3 = {"name": "a3", "region": None}
        db.replace(tables.a, row3, row2a, row1)
        self.assertCountEqual(
            [r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2a, row3]
        )

    def testEnsure(self):
        """Tests for `Database.ensure`."""
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Use 'ensure' to insert a single row that contains a region and
        # query to get it back.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row1 = {"name": "a1", "region": region}
        self.assertEqual(db.ensure(tables.a, row1), 1)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1])
        # Insert another row without a region.
        row2 = {"name": "a2", "region": None}
        self.assertEqual(db.ensure(tables.a, row2), 1)
        self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2])
        # Use ensure to re-insert both of those rows again, which should do
        # nothing.
        self.assertEqual(db.ensure(tables.a, row1, row2), 0)
        self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2])
        # Attempt to insert row1's key with no region, while
        # reinserting row2. This should also do nothing.
        row1a = {"name": "a1", "region": None}
        self.assertEqual(db.ensure(tables.a, row1a, row2), 0)
        self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2])
        # Attempt to insert new rows for both existing keys, this time also
        # adding a new row. Pass them in in a different order. Only the new
        # row should be added.
        row2a = {"name": "a2", "region": region}
        row3 = {"name": "a3", "region": None}
        self.assertEqual(db.ensure(tables.a, row3, row2a, row1a), 1)
        self.assertCountEqual(
            [r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2, row3]
        )
        # Add some data to a table with both a primary key and a different
        # unique constraint.
        row_b = {"id": 5, "name": "five", "value": 50}
        db.insert(tables.b, row_b)
        # Attempt ensure with primary_key_only=False and a conflict for the
        # non-PK constraint. This should do nothing.
        db.ensure(tables.b, {"id": 10, "name": "five", "value": 200})
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.b.select())], [row_b])
        # Now use primary_key_only=True with a conflict in only the non-PK
        # field. This should be an integrity error and nothing should change.
        with self.assertRaises(sqlalchemy.exc.IntegrityError):
            db.ensure(tables.b, {"id": 10, "name": "five", "value": 200}, primary_key_only=True)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.b.select())], [row_b])
        # With primary_key_only=True a conflict in the primary key is ignored
        # regardless of whether there is a conflict elsewhere.
        db.ensure(tables.b, {"id": 5, "name": "ten", "value": 100}, primary_key_only=True)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.b.select())], [row_b])
        db.ensure(tables.b, {"id": 5, "name": "five", "value": 100}, primary_key_only=True)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.b.select())], [row_b])

    def testTransactionNesting(self):
        """Test that transactions can be nested with the behavior in the
        presence of exceptions working as documented.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert one row so we can trigger integrity errors by trying to
        # insert a duplicate of it below.
        db.insert(tables.a, {"name": "a1"})
        # First test: error recovery via explicit savepoint=True in the inner
        # transaction.
        with db.transaction():
            # This insert should succeed, and should not be rolled back
            # because the assertRaises context should catch any exception
            # before it propagates up to the outer transaction.
            db.insert(tables.a, {"name": "a2"})
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                with db.transaction(savepoint=True):
                    # This insert should succeed, but should be rolled back.
                    db.insert(tables.a, {"name": "a4"})
                    # This insert should fail (duplicate primary key), raising
                    # an exception.
                    db.insert(tables.a, {"name": "a1"})
        self.assertCountEqual(
            [r._asdict() for r in self.query_list(db, tables.a.select())],
            [{"name": "a1", "region": None}, {"name": "a2", "region": None}],
        )
        # Second test: error recovery via implicit savepoint=True, when the
        # innermost transaction is inside a savepoint=True transaction.
        with db.transaction():
            # This insert should succeed, and should not be rolled back
            # because the assertRaises context should catch any exception
            # before it propagates up to the outer transaction.
            db.insert(tables.a, {"name": "a3"})
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                with db.transaction(savepoint=True):
                    # This insert should succeed, but should be rolled back.
                    db.insert(tables.a, {"name": "a4"})
                    with db.transaction():
                        # This insert should succeed, but should be rolled
                        # back.
                        db.insert(tables.a, {"name": "a5"})
                        # This insert should fail (duplicate primary key),
                        # raising an exception.
                        db.insert(tables.a, {"name": "a1"})
        self.assertCountEqual(
            [r._asdict() for r in self.query_list(db, tables.a.select())],
            [{"name": "a1", "region": None}, {"name": "a2", "region": None}, {"name": "a3", "region": None}],
        )

    def testTransactionLocking(self):
        """Test that `Database.transaction` can be used to acquire a lock
        that prohibits concurrent writes.
        """
        db1 = self.makeEmptyDatabase(origin=1)
        with db1.declareStaticTables(create=True) as context:
            tables1 = context.addTableTuple(STATIC_TABLE_SPECS)
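
        # Choreography of the helpers below: _side1 sleeps 1 s, opens a
        # transaction (optionally taking a lock), SELECTs table a, sleeps
        # 2 s, INSERTs "a1", and SELECTs again; _side2 sleeps 2 s and
        # INSERTs "a2" via a second connection in a worker thread. If the
        # lock is taken and works, _side2's INSERT blocks until _side1's
        # transaction ends.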

        async def _side1(lock: Iterable[str] = ()) -> tuple[set[str], set[str]]:
            """One side of the concurrent locking test.

            Parameters
            ----------
            lock : `~collections.abc.Iterable` of `str`
                Locks.

            Notes
            -----
            This optionally locks the table (and maybe the whole database),
            does a select for its contents, inserts a new row, and then
            selects again, with some waiting in between to make sure the
            other side has a chance to _attempt_ to insert in between. If the
            locking is enabled and works, the difference between the selects
            should just be the insert done on this thread.
            """
            # Give Side2 a chance to create a connection
            await asyncio.sleep(1.0)
            with db1.transaction(lock=lock):
                names1 = {row.name for row in self.query_list(db1, tables1.a.select())}
                # Give Side2 a chance to insert (which will be blocked if
                # we've acquired a lock).
                await asyncio.sleep(2.0)
                db1.insert(tables1.a, {"name": "a1"})
                names2 = {row.name for row in self.query_list(db1, tables1.a.select())}
            return names1, names2

        async def _side2() -> None:
            """Other side of the concurrent locking test.

            Notes
            -----
            This side just waits a bit and then tries to insert a row into
            the table that the other side is trying to lock. Hopefully that
            waiting is enough to give the other side a chance to acquire the
            lock and thus make this side block until the lock is released.
            If this side manages to do the insert before side1 acquires the
            lock, we'll just warn about not succeeding at testing the
            locking, because we can only make that unlikely, not impossible.
            """

            def _toRunInThread():
                """Create new SQLite connection for use in thread.

                SQLite locking isn't asyncio-friendly unless we actually
                run it in another thread. And SQLite gets very unhappy if
                we try to use a connection from multiple threads, so we have
                to create the new connection here instead of out in the main
                body of the test function.
                """
                db2 = self.getNewConnection(db1, writeable=True)
                with db2.declareStaticTables(create=False) as context:
                    tables2 = context.addTableTuple(STATIC_TABLE_SPECS)
                with db2.transaction():
                    db2.insert(tables2.a, {"name": "a2"})

            await asyncio.sleep(2.0)
            loop = asyncio.get_running_loop()
            with ThreadPoolExecutor() as pool:
                await loop.run_in_executor(pool, _toRunInThread)

        async def _testProblemsWithNoLocking() -> None:
            """Run side1 and side2 with no locking, attempting to demonstrate
            the problem that locking is supposed to solve. If we get unlucky
            with scheduling, side2 will just happen to insert after side1 is
            done, and we won't have anything definitive. We just warn in that
            case because we really don't want spurious test failures.
            """
            task1 = asyncio.create_task(_side1())
            task2 = asyncio.create_task(_side2())

            names1, names2 = await task1
            await task2
            if "a2" in names1:
                warnings.warn(
                    "Unlucky scheduling in no-locking test: concurrent INSERT happened before first SELECT.",
                    stacklevel=1,
                )
                self.assertEqual(names1, {"a2"})
                self.assertEqual(names2, {"a1", "a2"})
            elif "a2" not in names2:
                warnings.warn(
                    "Unlucky scheduling in no-locking test: concurrent INSERT "
                    "happened after second SELECT even without locking.",
                    stacklevel=1,
                )
                self.assertEqual(names1, set())
                self.assertEqual(names2, {"a1"})
            else:
                # This is the expected case: both INSERTs happen between the
                # two SELECTs. If we don't get this almost all of the time we
                # should adjust the sleep amounts.
                self.assertEqual(names1, set())
                self.assertEqual(names2, {"a1", "a2"})

        asyncio.run(_testProblemsWithNoLocking())

        # Clean up after first test.
        db1.delete(tables1.a, ["name"], {"name": "a1"}, {"name": "a2"})

        async def _testSolutionWithLocking() -> None:
            """Run side1 and side2 with locking, which should make side2
            block its insert until side1 releases its lock.
            """
            task1 = asyncio.create_task(_side1(lock=[tables1.a]))
            task2 = asyncio.create_task(_side2())

            names1, names2 = await task1
            await task2
            if "a2" in names1:
                warnings.warn(
                    "Unlucky scheduling in locking test: concurrent INSERT happened before first SELECT.",
                    stacklevel=1,
                )
                self.assertEqual(names1, {"a2"})
                self.assertEqual(names2, {"a1", "a2"})
            else:
                # This is the expected case: the side2 INSERT happens after
                # the last SELECT on side1. This can also happen due to
                # unlucky scheduling, and we have no way to detect that here,
                # but the similar "no-locking" test has at least some chance
                # of being affected by the same problem and warning about it.
                self.assertEqual(names1, set())
                self.assertEqual(names2, {"a1"})

        asyncio.run(_testSolutionWithLocking())

    def testTimespanDatabaseRepresentation(self):
        """Tests for `TimespanDatabaseRepresentation` and the `Database`
        methods that interact with it.
        """
        # Make some test timespans to play with, with the full suite of
        # topological relationships.
        start = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
        offset = astropy.time.TimeDelta(60, format="sec")
        timestamps = [start + offset * n for n in range(3)]
        aTimespans = [Timespan(begin=None, end=None)]
        aTimespans.extend(Timespan(begin=None, end=t) for t in timestamps)
        aTimespans.extend(Timespan(begin=t, end=None) for t in timestamps)
        aTimespans.extend(Timespan.fromInstant(t) for t in timestamps)
        aTimespans.append(Timespan.makeEmpty())
        aTimespans.extend(Timespan(begin=t1, end=t2) for t1, t2 in itertools.combinations(timestamps, 2))
        # Make another list of timespans that span the full range but don't
        # overlap. This is a subset of the previous list.
        bTimespans = [Timespan(begin=None, end=timestamps[0])]
        bTimespans.extend(
            Timespan(begin=t1, end=t2) for t1, t2 in zip(timestamps[:-1], timestamps[1:], strict=True)
        )
        bTimespans.append(Timespan(begin=timestamps[-1], end=None))
        # Make a database and create a table with that database's timespan
        # representation. This one will have no exclusion constraint and
        # a nullable timespan.
        db = self.makeEmptyDatabase(origin=1)
        TimespanReprClass = db.getTimespanRepresentation()
        aSpec = ddl.TableSpec(
            fields=[
                ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
            ],
        )
        for fieldSpec in TimespanReprClass.makeFieldSpecs(nullable=True):
            aSpec.fields.add(fieldSpec)
        with db.declareStaticTables(create=True) as context:
            aTable = context.addTable("a", aSpec)
        self.maxDiff = None

        def _convertRowForInsert(row: dict) -> dict:
            """Convert a row containing a Timespan instance into one suitable
            for insertion into the database.
            """
            result = row.copy()
            ts = result.pop(TimespanReprClass.NAME)
            return TimespanReprClass.update(ts, result=result)
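
        # update() above flattens a Timespan into the representation's
        # backend-specific columns (named per
        # TimespanReprClass.getFieldNames()); _convertRowFromSelect below
        # inverts that with extract().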

        def _convertRowFromSelect(row: dict) -> dict:
            """Convert a row from the database into one containing a Timespan.

            Parameters
            ----------
            row : `dict`
                Original row.

            Returns
            -------
            row : `dict`
                The updated row.
            """
            result = row.copy()
            timespan = TimespanReprClass.extract(result)
            for name in TimespanReprClass.getFieldNames():
                del result[name]
            result[TimespanReprClass.NAME] = timespan
            return result

        # Insert rows into table A, in chunks just to make things
        # interesting. Include one with a NULL timespan.
        aRows = [{"id": n, TimespanReprClass.NAME: t} for n, t in enumerate(aTimespans)]
        aRows.append({"id": len(aRows), TimespanReprClass.NAME: None})
        db.insert(aTable, _convertRowForInsert(aRows[0]))
        db.insert(aTable, *[_convertRowForInsert(r) for r in aRows[1:3]])
        db.insert(aTable, *[_convertRowForInsert(r) for r in aRows[3:]])
        # Add another one with a NULL timespan, but this time by invoking
        # the server-side default.
        aRows.append({"id": len(aRows)})
        db.insert(aTable, aRows[-1])
        aRows[-1][TimespanReprClass.NAME] = None
        # Test basic round-trip through the database.
        self.assertEqual(
            aRows,
            [
                _convertRowFromSelect(row._asdict())
                for row in self.query_list(db, aTable.select().order_by(aTable.columns.id))
            ],
        )
        # Create another table B with a not-null timespan and (if the
        # database supports it) an exclusion constraint. Use
        # ensureTableExists this time to check that mode of table creation
        # vs. timespans.
        bSpec = ddl.TableSpec(
            fields=[
                ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
                ddl.FieldSpec(name="key", dtype=sqlalchemy.Integer, nullable=False),
            ],
        )
        for fieldSpec in TimespanReprClass.makeFieldSpecs(nullable=False):
            bSpec.fields.add(fieldSpec)
        if TimespanReprClass.hasExclusionConstraint():
            bSpec.exclusion.add(("key", TimespanReprClass))
        bTable = db.ensureTableExists("b", bSpec)
        # Insert rows into table B, again in chunks. Each Timespan appears
        # twice, but with different values for the 'key' field (which should
        # still be okay for any exclusion constraint we may have defined).
        bRows = [{"id": n, "key": 1, TimespanReprClass.NAME: t} for n, t in enumerate(bTimespans)]
        offset = len(bRows)
        bRows.extend(
            {"id": n + offset, "key": 2, TimespanReprClass.NAME: t} for n, t in enumerate(bTimespans)
        )
        db.insert(bTable, *[_convertRowForInsert(r) for r in bRows[:2]])
        db.insert(bTable, _convertRowForInsert(bRows[2]))
        db.insert(bTable, *[_convertRowForInsert(r) for r in bRows[3:]])
        # Insert a row with no timespan into table B. This should invoke the
        # server-side default, which is a timespan over (-∞, ∞). We set
        # key=3 to avoid upsetting an exclusion constraint that might exist.
        bRows.append({"id": len(bRows), "key": 3})
        db.insert(bTable, bRows[-1])
        bRows[-1][TimespanReprClass.NAME] = Timespan(None, None)
        # Test basic round-trip through the database.
        self.assertEqual(
            bRows,
            [
                _convertRowFromSelect(row._asdict())
                for row in self.query_list(db, bTable.select().order_by(bTable.columns.id))
            ],
        )
        # Test that we can't insert timespan=None into this table.
        with self.assertRaises(sqlalchemy.exc.IntegrityError):
            db.insert(
                bTable, _convertRowForInsert({"id": len(bRows), "key": 4, TimespanReprClass.NAME: None})
            )
        # IFF this database supports exclusion constraints, test that they
        # also prevent inserts.
        if TimespanReprClass.hasExclusionConstraint():
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                db.insert(
                    bTable,
                    _convertRowForInsert(
                        {"id": len(bRows), "key": 1, TimespanReprClass.NAME: Timespan(None, timestamps[1])}
                    ),
                )
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                db.insert(
                    bTable,
                    _convertRowForInsert(
                        {
                            "id": len(bRows),
                            "key": 1,
                            TimespanReprClass.NAME: Timespan(timestamps[0], timestamps[2]),
                        }
                    ),
                )
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                db.insert(
                    bTable,
                    _convertRowForInsert(
                        {"id": len(bRows), "key": 1, TimespanReprClass.NAME: Timespan(timestamps[2], None)}
                    ),
                )
        # Test NULL checks in SELECT queries, on both tables.
        aRepr = TimespanReprClass.from_columns(aTable.columns)
        self.assertEqual(
            [row[TimespanReprClass.NAME] is None for row in aRows],
            [
                row.f
                for row in self.query_list(
                    db, sqlalchemy.sql.select(aRepr.isNull().label("f")).order_by(aTable.columns.id)
                )
            ],
        )
        bRepr = TimespanReprClass.from_columns(bTable.columns)
        self.assertEqual(
            [False for row in bRows],
            [
                row.f
                for row in self.query_list(
                    db, sqlalchemy.sql.select(bRepr.isNull().label("f")).order_by(bTable.columns.id)
                )
            ],
        )
        # Test relationship expressions that relate in-database timespans to
        # Python-literal timespans, all from the more complete 'a' set; check
        # that this is consistent with Python-only relationship tests.
        for rhsRow in aRows:
            if rhsRow[TimespanReprClass.NAME] is None:
                continue
            with self.subTest(rhsRow=rhsRow):
                expected = {}
                for lhsRow in aRows:
                    if lhsRow[TimespanReprClass.NAME] is None:
                        expected[lhsRow["id"]] = (None, None, None, None)
                    else:
                        expected[lhsRow["id"]] = (
                            lhsRow[TimespanReprClass.NAME].overlaps(rhsRow[TimespanReprClass.NAME]),
                            lhsRow[TimespanReprClass.NAME].contains(rhsRow[TimespanReprClass.NAME]),
                            lhsRow[TimespanReprClass.NAME] < rhsRow[TimespanReprClass.NAME],
                            lhsRow[TimespanReprClass.NAME] > rhsRow[TimespanReprClass.NAME],
                        )
                rhsRepr = TimespanReprClass.fromLiteral(rhsRow[TimespanReprClass.NAME])
                sql = sqlalchemy.sql.select(
                    aTable.columns.id.label("lhs"),
                    aRepr.overlaps(rhsRepr).label("overlaps"),
                    aRepr.contains(rhsRepr).label("contains"),
                    (aRepr < rhsRepr).label("less_than"),
                    (aRepr > rhsRepr).label("greater_than"),
                ).select_from(aTable)
                queried = {
                    row.lhs: (row.overlaps, row.contains, row.less_than, row.greater_than)
                    for row in self.query_list(db, sql)
                }
                self.assertEqual(expected, queried)
        # Test relationship expressions that relate in-database timespans to
        # each other, all from the more complete 'a' set; check that this is
        # consistent with Python-only relationship tests.
        expected = {}
        for lhs, rhs in itertools.product(aRows, aRows):
            lhsT = lhs[TimespanReprClass.NAME]
            rhsT = rhs[TimespanReprClass.NAME]
            if lhsT is not None and rhsT is not None:
                expected[lhs["id"], rhs["id"]] = (
                    lhsT.overlaps(rhsT),
                    lhsT.contains(rhsT),
                    lhsT < rhsT,
                    lhsT > rhsT,
                )
            else:
                expected[lhs["id"], rhs["id"]] = (None, None, None, None)
        lhsSubquery = aTable.alias("lhs")
        rhsSubquery = aTable.alias("rhs")
        lhsRepr = TimespanReprClass.from_columns(lhsSubquery.columns)
        rhsRepr = TimespanReprClass.from_columns(rhsSubquery.columns)
        sql = sqlalchemy.sql.select(
            lhsSubquery.columns.id.label("lhs"),
            rhsSubquery.columns.id.label("rhs"),
            lhsRepr.overlaps(rhsRepr).label("overlaps"),
            lhsRepr.contains(rhsRepr).label("contains"),
            (lhsRepr < rhsRepr).label("less_than"),
            (lhsRepr > rhsRepr).label("greater_than"),
        ).select_from(lhsSubquery.join(rhsSubquery, onclause=sqlalchemy.sql.literal(True)))
        queried = {
            (row.lhs, row.rhs): (row.overlaps, row.contains, row.less_than, row.greater_than)
            for row in self.query_list(db, sql)
        }
        self.assertEqual(expected, queried)
        # Test relationship expressions between in-database timespans and
        # Python-literal instantaneous times.
        for t in timestamps:
            with self.subTest(t=t):
                expected = {}
                for lhsRow in aRows:
                    if lhsRow[TimespanReprClass.NAME] is None:
                        expected[lhsRow["id"]] = (None, None, None, None)
                    else:
                        expected[lhsRow["id"]] = (
                            lhsRow[TimespanReprClass.NAME].contains(t),
                            lhsRow[TimespanReprClass.NAME].overlaps(t),
                            lhsRow[TimespanReprClass.NAME] < t,
                            lhsRow[TimespanReprClass.NAME] > t,
                        )
                rhs = sqlalchemy.sql.literal(t, type_=ddl.AstropyTimeNsecTai)
                sql = sqlalchemy.sql.select(
                    aTable.columns.id.label("lhs"),
                    aRepr.contains(rhs).label("contains"),
                    aRepr.overlaps(rhs).label("overlaps_point"),
                    (aRepr < rhs).label("less_than"),
                    (aRepr > rhs).label("greater_than"),
                ).select_from(aTable)
                queried = {
                    row.lhs: (row.contains, row.overlaps_point, row.less_than, row.greater_than)
                    for row in self.query_list(db, sql)
                }
                self.assertEqual(expected, queried)

    def testConstantRows(self):
        """Test `Database.constant_rows`."""
        new_db = self.makeEmptyDatabase()
        with new_db.declareStaticTables(create=True) as context:
            static = context.addTableTuple(STATIC_TABLE_SPECS)
        b_ids = new_db.insert(
            static.b,
            {"name": "b1", "value": 11},
            {"name": "b2", "value": 12},
            {"name": "b3", "value": 13},
            returnIds=True,
        )
        values_spec = ddl.TableSpec(
            [
                ddl.FieldSpec(name="b", dtype=sqlalchemy.BigInteger),
                ddl.FieldSpec(name="s", dtype=sqlalchemy.String(8)),
                ddl.FieldSpec(name="r", dtype=ddl.Base64Region()),
            ],
        )
        values_data = [
            {"b": b_ids[0], "s": "b1", "r": None},
            {"b": b_ids[2], "s": "b3", "r": Circle.empty()},
        ]
        values = new_db.constant_rows(values_spec.fields, *values_data)
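        # constant_rows() builds a SQLAlchemy FROM clause out of the literal
        # rows (a VALUES construct or the backend's equivalent), so below it
        # can be selected from and joined exactly like a real table.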

        select_values_alone = sqlalchemy.sql.select(
            values.columns["b"], values.columns["s"], values.columns["r"]
        )
        self.assertCountEqual(
            [row._mapping for row in self.query_list(new_db, select_values_alone)],
            values_data,
        )
        select_values_joined = sqlalchemy.sql.select(
            values.columns["s"].label("name"), static.b.columns["value"].label("value")
        ).select_from(values.join(static.b, onclause=static.b.columns["id"] == values.columns["b"]))
        self.assertCountEqual(
            [row._mapping for row in self.query_list(new_db, select_values_joined)],
            [{"value": 11, "name": "b1"}, {"value": 13, "name": "b3"}],
        )

    def test_aggregate(self) -> None:
        """Test `Database.apply_any_aggregate`,
        `ddl.Base64Region.union_aggregate`, and
        `TimespanDatabaseRepresentation.apply_any_aggregate`.
        """
        db = self.makeEmptyDatabase()
        with db.declareStaticTables(create=True) as context:
            t = context.addTable(
                "t",
                ddl.TableSpec(
                    [
                        ddl.FieldSpec("id", sqlalchemy.BigInteger(), nullable=False),
                        ddl.FieldSpec("name", sqlalchemy.String(16), nullable=False),
                        ddl.FieldSpec.for_region(),
                    ]
                    + list(db.getTimespanRepresentation().makeFieldSpecs(nullable=True)),
                ),
            )
        pixelization = Mq3cPixelization(10)
        start = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
        offset = astropy.time.TimeDelta(60, format="sec")
        timespans = [Timespan(begin=start + offset * n, end=start + offset * (n + 1)) for n in range(3)]
        ts_cls = db.getTimespanRepresentation()
        ts_col = ts_cls.from_columns(t.columns)
        db.insert(
            t,
            ts_cls.update(timespans[0], result={"id": 1, "name": "a", "region": pixelization.quad(12058870)}),
            ts_cls.update(timespans[1], result={"id": 2, "name": "a", "region": pixelization.quad(12058871)}),
            ts_cls.update(timespans[2], result={"id": 3, "name": "b", "region": pixelization.quad(12058872)}),
            ts_cls.update(timespans[2], result={"id": 3, "name": "b", "region": pixelization.quad(12058873)}),
        )
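        # Note that "id" is not a primary key in this table, so two of the
        # rows above share id=3; the DISTINCT ON / GROUP BY queries below
        # therefore have duplicate groups to collapse.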

        # This should use DISTINCT ON in PostgreSQL and GROUP BY in SQLite.
        if db.has_distinct_on:
            sql = (
                sqlalchemy.select(
                    t.c.id.label("i"),
                    t.c.name.label("n"),
                    *ts_col.flatten("t"),
                )
                .select_from(t)
                .distinct(t.c.id)
            )
        elif db.has_any_aggregate:
            sql = (
                sqlalchemy.select(
                    t.c.id.label("i"),
                    db.apply_any_aggregate(t.c.name).label("n"),
                    *ts_col.apply_any_aggregate(db.apply_any_aggregate).flatten("t"),
                )
                .select_from(t)
                .group_by(t.c.id)
            )
        else:
            raise AssertionError(
                "PostgreSQL should support DISTINCT ON and SQLite should support no-op any aggregates."
            )
        self.assertCountEqual(
            [(row.i, row.n, ts_cls.extract(row._mapping, "t")) for row in self.query_list(db, sql)],
            [(1, "a", timespans[0]), (2, "a", timespans[1]), (3, "b", timespans[2])],
        )
        # Test union_aggregate in all versions of both databases, with a
        # GROUP BY that does not need apply_any_aggregate.
        self.assertCountEqual(
            [
                (row.i, row.r.encode())
                for row in self.query_list(
                    db,
                    sqlalchemy.select(
                        t.c.id.label("i"), ddl.Base64Region.union_aggregate(t.c.region).label("r")
                    )
                    .select_from(t)
                    .group_by(t.c.id),
                )
            ],
            [
                (1, pixelization.quad(12058870).encode()),
                (2, pixelization.quad(12058871).encode()),
                (3, UnionRegion(pixelization.quad(12058872), pixelization.quad(12058873)).encode()),
            ],
        )
        if db.has_any_aggregate:
            # This should run in SQLite and PostgreSQL 16+.
            self.assertCountEqual(
                [
                    (row.i, row.n, row.r.encode())
                    for row in self.query_list(
                        db,
                        sqlalchemy.select(
                            t.c.id.label("i"),
                            db.apply_any_aggregate(t.c.name).label("n"),
                            ddl.Base64Region.union_aggregate(t.c.region).label("r"),
                        )
                        .select_from(t)
                        .group_by(t.c.id),
                    )
                ],
                [
                    (1, "a", pixelization.quad(12058870).encode()),
                    (2, "a", pixelization.quad(12058871).encode()),
                    (3, "b", UnionRegion(pixelization.quad(12058872), pixelization.quad(12058873)).encode()),
                ],
            )