Coverage for python/lsst/daf/butler/registry/tests/_database.py: 7%

491 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-10-02 08:00 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27from __future__ import annotations 

28 

29__all__ = ["DatabaseTests"] 

30 

import asyncio
import itertools
import warnings
from abc import ABC, abstractmethod
from collections import namedtuple
from collections.abc import Iterable, Iterator
from concurrent.futures import ThreadPoolExecutor
from contextlib import AbstractContextManager, contextmanager
from typing import Any

import astropy.time
import sqlalchemy
from lsst.sphgeom import Circle, ConvexPolygon, UnitVector3d

from ...core import Timespan, ddl
from ..interfaces import Database, DatabaseConflictError, ReadOnlyDatabaseError, SchemaAlreadyDefinedError

47 

# Named container with one slot per static test table: ``a``, ``b`` and ``c``.
StaticTablesTuple = namedtuple("StaticTablesTuple", ("a", "b", "c"))

49 

STATIC_TABLE_SPECS = StaticTablesTuple(
    # a: 16-character string primary key plus a region column stored as
    # ``ddl.Base64Region``.
    a=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("name", primaryKey=True, length=16, dtype=sqlalchemy.String),
            ddl.FieldSpec("region", nbytes=128, dtype=ddl.Base64Region),
        ]
    ),
    # b: autoincrement integer key, a non-null name that must be unique, and
    # a nullable small-integer value.
    b=ddl.TableSpec(
        unique=[("name",)],
        fields=[
            ddl.FieldSpec("id", primaryKey=True, autoincrement=True, dtype=sqlalchemy.BigInteger),
            ddl.FieldSpec("name", nullable=False, length=16, dtype=sqlalchemy.String),
            ddl.FieldSpec("value", nullable=True, dtype=sqlalchemy.SmallInteger),
        ],
    ),
    # c: autoincrement integer key and an optional reference to b.id that is
    # reset to NULL when the referenced b row is deleted.
    c=ddl.TableSpec(
        foreignKeys=[
            ddl.ForeignKeySpec("b", onDelete="SET NULL", target=("id",), source=("b_id",)),
        ],
        fields=[
            ddl.FieldSpec("id", primaryKey=True, autoincrement=True, dtype=sqlalchemy.BigInteger),
            ddl.FieldSpec("b_id", nullable=True, dtype=sqlalchemy.BigInteger),
        ],
    ),
)

75 

# Spec for the table the tests create on demand (as "d") with
# ``ensureTableExists``. Each row links a c row to an a row and is removed
# when either of them is deleted (both foreign keys cascade).
DYNAMIC_TABLE_SPEC = ddl.TableSpec(
    foreignKeys=[
        ddl.ForeignKeySpec("c", onDelete="CASCADE", target=("id",), source=("c_id",)),
        ddl.ForeignKeySpec("a", onDelete="CASCADE", target=("name",), source=("a_name",)),
    ],
    fields=[
        ddl.FieldSpec("c_id", primaryKey=True, dtype=sqlalchemy.BigInteger),
        ddl.FieldSpec("a_name", nullable=False, length=16, dtype=sqlalchemy.String),
    ],
)

86 

# Spec for the temporary tables used in the tests: a composite primary key
# pairing an a.name value with a b.id value, with no foreign keys declared.
TEMPORARY_TABLE_SPEC = ddl.TableSpec(
    fields=[
        ddl.FieldSpec("a_name", primaryKey=True, length=16, dtype=sqlalchemy.String),
        ddl.FieldSpec("b_id", primaryKey=True, dtype=sqlalchemy.BigInteger),
    ],
)

93 

94 

95@contextmanager 

96def _patch_getExistingTable(db: Database) -> Database: 

97 """Patch getExistingTable method in a database instance to test concurrent 

98 creation of tables. This patch obviously depends on knowning internals of 

99 ``ensureTableExists()`` implementation. 

100 """ 

101 original_method = db.getExistingTable 

102 

103 def _getExistingTable(name: str, spec: ddl.TableSpec) -> sqlalchemy.schema.Table | None: 

104 # Return None on first call, but forward to original method after that 

105 db.getExistingTable = original_method 

106 return None 

107 

108 db.getExistingTable = _getExistingTable 

109 yield db 

110 db.getExistingTable = original_method 

111 

112 

113class DatabaseTests(ABC): 

114 """Generic tests for the `Database` interface that can be subclassed to 

115 generate tests for concrete implementations. 

116 """ 

117 

@abstractmethod
def makeEmptyDatabase(self, origin: int = 0) -> Database:
    """Return a new, empty `Database` instance.

    Parameters
    ----------
    origin : `int`, optional
        Origin value for the new database. Defaults to ``0``.

    Returns
    -------
    database : `Database`
        An empty database for a single test to use.
    """
    # NOTE(review): an earlier docstring said ``origin=None`` requests an
    # automatically generated origin, but the signature is ``int = 0`` and no
    # caller here passes None; confirm before relying on None support.
    raise NotImplementedError()

124 

@abstractmethod
def asReadOnly(self, database: Database) -> AbstractContextManager[Database]:
    """Open a read-only view onto ``database`` as a context manager.

    While the context is active, the original ``database`` must be treated
    as unusable; it may be used again once the context exits. This gives
    implementations the freedom to enforce read-only access, for example by
    temporarily changing user permissions so that writes are guaranteed not
    to happen.
    """
    raise NotImplementedError

136 

@abstractmethod
def getNewConnection(self, database: Database, *, writeable: bool) -> Database:
    """Open a second `Database` object backed by the same storage as
    ``database``.

    ``writeable`` (keyword-only) selects whether the new connection may write.
    """
    raise NotImplementedError

143 

def query_list(
    self, database: Database, executable: sqlalchemy.sql.expression.SelectBase
) -> list[sqlalchemy.engine.Row]:
    """Execute a read-only query (typically a SELECT) and return all rows.

    The query runs through ``database.query()`` inside
    ``database.transaction()``. This helper only exists to hide that
    context-manager boilerplate, which is useful in production code but
    noise in tests.
    """
    with database.transaction():
        with database.query(executable) as result:
            rows = result.fetchall()
    return rows

156 

def query_scalar(self, database: Database, executable: sqlalchemy.sql.expression.SelectBase) -> Any:
    """Run a query that yields a single row and column; return that value.

    Like ``query_list``, the query is executed through ``database.query()``
    inside ``database.transaction()``, so both helpers run queries under the
    same conditions. This is a thin wrapper that avoids context-manager
    boilerplate that is useful in production code but noise in tests.
    """
    with database.transaction(), database.query(executable) as result:
        return result.scalar()

167 

def checkTable(self, spec: ddl.TableSpec, table: sqlalchemy.schema.Table):
    """Assert that ``table`` has exactly the columns named in ``spec``
    (in any order).

    Types and constraints are deliberately not compared: checking more would
    constrain how `Database` implementations represent tables. The tests
    only need the returned table to be usable as if the spec had been
    preserved.
    """
    expected_names = spec.fields.names
    actual_names = table.columns.keys()
    self.assertCountEqual(expected_names, actual_names)

174 

def checkStaticSchema(self, tables: StaticTablesTuple):
    """Check each of ``tables.a``, ``tables.b`` and ``tables.c`` against
    the matching entry of ``STATIC_TABLE_SPECS``.
    """
    for field in STATIC_TABLE_SPECS._fields:
        self.checkTable(getattr(STATIC_TABLE_SPECS, field), getattr(tables, field))

179 

def testDeclareStaticTables(self):
    """Test `Database.declareStaticTables` (and the methods it delegates
    to), both on a new database and on a read-only connection to it.
    """
    db = self.makeEmptyDatabase()
    # Create the static schema in the new, empty database.
    with db.declareStaticTables(create=True) as context:
        created = context.addTableTuple(STATIC_TABLE_SPECS)
    self.checkStaticSchema(created)
    # The same schema must also be loadable through a read-only connection.
    with self.asReadOnly(db) as readOnlyDb:
        with readOnlyDb.declareStaticTables(create=False) as context:
            loaded = context.addTableTuple(STATIC_TABLE_SPECS)
        self.checkStaticSchema(loaded)

194 

def testDeclareStaticTablesTwice(self):
    """Test that a second ``declareStaticTables(create=True)`` on the same
    database raises `SchemaAlreadyDefinedError`.
    """
    db = self.makeEmptyDatabase()
    # The first declaration creates the static schema.
    with db.declareStaticTables(create=True) as context:
        tables = context.addTableTuple(STATIC_TABLE_SPECS)
    self.checkStaticSchema(tables)
    # Declaring it again with create=True must fail.
    with self.assertRaises(SchemaAlreadyDefinedError):
        with db.declareStaticTables(create=True) as context:
            context.addTableTuple(STATIC_TABLE_SPECS)
    # Every static table must still be known; extra tables are allowed.
    expectedNames = frozenset(STATIC_TABLE_SPECS._fields)
    with db.declareStaticTables(create=False) as context:
        self.assertLessEqual(expectedNames, context._tableNames)

210 

211 def testRepr(self): 

212 """Test that repr does not return a generic thing.""" 

213 newDatabase = self.makeEmptyDatabase() 

214 rep = repr(newDatabase) 

215 # Check that stringification works and gives us something different 

216 self.assertNotEqual(rep, str(newDatabase)) 

217 self.assertNotIn("object at 0x", rep, "Check default repr was not used") 

218 self.assertIn("://", rep) 

219 

def testDynamicTables(self):
    """Tests for `Database.ensureTableExists` and
    `Database.getExistingTable`.
    """
    # Need to start with the static schema.
    newDatabase = self.makeEmptyDatabase()
    with newDatabase.declareStaticTables(create=True) as context:
        context.addTableTuple(STATIC_TABLE_SPECS)
    # Ensuring the dynamic table exists in a read-only version of that
    # database must fail, because the table would have to be created.
    with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
        with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
            context.addTableTuple(STATIC_TABLE_SPECS)
        with self.assertRaises(ReadOnlyDatabaseError):
            existingReadOnlyDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
    # Just getting the dynamic table before it exists should return None.
    self.assertIsNone(newDatabase.getExistingTable("d", DYNAMIC_TABLE_SPEC))
    # Ensuring the table exists in the original database creates it.
    table = newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
    self.checkTable(DYNAMIC_TABLE_SPEC, table)
    # Ensuring it again should return the exact same table instance.
    self.assertIs(newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC), table)
    # Try again from a read-only connection, now that the table exists.
    with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
        with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
            context.addTableTuple(STATIC_TABLE_SPECS)
        # Just getting the dynamic table should now work...
        existing = existingReadOnlyDatabase.getExistingTable("d", DYNAMIC_TABLE_SPEC)
        self.assertIsNotNone(existing)
        self.checkTable(DYNAMIC_TABLE_SPEC, existing)
        # ...as should ensuring that it exists. Check the table returned by
        # the read-only connection itself, not the writeable one's `table`.
        readOnlyTable = existingReadOnlyDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        self.checkTable(DYNAMIC_TABLE_SPEC, readOnlyTable)
    # Asking for the table with a different specification (at least in terms
    # of which columns are present) must raise.
    with self.assertRaises(DatabaseConflictError):
        newDatabase.ensureTableExists(
            "d",
            ddl.TableSpec(
                fields=[ddl.FieldSpec("name", dtype=sqlalchemy.String, length=4, primaryKey=True)]
            ),
        )
    # Calling ensureTableExists inside a transaction block is an error,
    # even if it would do nothing.
    with newDatabase.transaction():
        with self.assertRaises(AssertionError):
            newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)

267 

def testDynamicTablesConcurrency(self):
    """Tests for `Database.ensureTableExists` concurrent use.

    True concurrency cannot be exercised deterministically, so this
    simulates another process creating the table between the call to
    ``getExistingTable()`` and the actual table creation.
    """
    db1 = self.makeEmptyDatabase()
    with db1.declareStaticTables(create=True) as context:
        context.addTableTuple(STATIC_TABLE_SPECS)
    self.assertIsNone(db1.getExistingTable("d", DYNAMIC_TABLE_SPEC))

    # Make the dynamic table using a separate connection.
    db2 = self.getNewConnection(db1, writeable=True)
    with db2.declareStaticTables(create=False) as context:
        context.addTableTuple(STATIC_TABLE_SPECS)
    table = db2.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
    self.checkTable(DYNAMIC_TABLE_SPEC, table)

    # Call it again on the first connection, but trick it into thinking the
    # table is not there. This depends on ensureTableExists() first calling
    # getExistingTable(); the patch intercepts that call and returns None.
    with _patch_getExistingTable(db1):
        table = db1.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
    # Even on this simulated-race path, the returned table must match the
    # spec.
    self.checkTable(DYNAMIC_TABLE_SPEC, table)

292 

293 def testTemporaryTables(self): 

294 """Tests for `Database.temporary_table`, and `Database.insert` with the 

295 ``select`` argument. 

296 """ 

297 # Need to start with the static schema; also insert some test data. 

298 newDatabase = self.makeEmptyDatabase() 

299 with newDatabase.declareStaticTables(create=True) as context: 

300 static = context.addTableTuple(STATIC_TABLE_SPECS) 

301 newDatabase.insert(static.a, {"name": "a1", "region": None}, {"name": "a2", "region": None}) 

302 bIds = newDatabase.insert( 

303 static.b, 

304 {"name": "b1", "value": 11}, 

305 {"name": "b2", "value": 12}, 

306 {"name": "b3", "value": 13}, 

307 returnIds=True, 

308 ) 

309 # Create the table. 

# table1 is created inside newDatabase.session(); presumably it only
# exists while these contexts are active (TODO confirm against
# Database.temporary_table).
310 with newDatabase.session(): 

311 with newDatabase.temporary_table(TEMPORARY_TABLE_SPEC, "e1") as table1: 

312 self.checkTable(TEMPORARY_TABLE_SPEC, table1) 

313 # Insert via a INSERT INTO ... SELECT query. 

314 newDatabase.insert( 

315 table1, 

316 select=sqlalchemy.sql.select( 

317 static.a.columns.name.label("a_name"), static.b.columns.id.label("b_id") 

318 ) 

319 .select_from(static.a.join(static.b, onclause=sqlalchemy.sql.literal(True))) 

320 .where( 

321 sqlalchemy.sql.and_( 

322 static.a.columns.name == "a1", 

323 static.b.columns.value <= 12, 

324 ) 

325 ), 

326 ) 

327 # Check that the inserted rows are present. 

328 self.assertCountEqual( 

329 [{"a_name": "a1", "b_id": bId} for bId in bIds[:2]], 

330 [row._asdict() for row in self.query_list(newDatabase, table1.select())], 

331 ) 

# NOTE(review): asReadOnly() documents that the original database should
# be considered unusable inside its context, yet table1 (created through
# newDatabase) is still referenced below; confirm this overlap is intended.
332 # Create another one via a read-only connection to the 

333 # database. We _do_ allow temporary table modifications in 

334 # read-only databases. 

335 with self.asReadOnly(newDatabase) as existingReadOnlyDatabase: 

336 with existingReadOnlyDatabase.declareStaticTables(create=False) as context: 

337 context.addTableTuple(STATIC_TABLE_SPECS) 

# Unlike table1, no explicit table name is passed here, and no
# session() is opened on the read-only connection.
338 with existingReadOnlyDatabase.temporary_table(TEMPORARY_TABLE_SPEC) as table2: 

339 self.checkTable(TEMPORARY_TABLE_SPEC, table2) 

340 # Those tables should not be the same, despite having 

341 # the same ddl. 

342 self.assertIsNot(table1, table2) 

343 # Do a slightly different insert into this table, to 

344 # check that it works in a read-only database. This 

345 # time we pass column names as a kwarg to insert 

346 # instead of by labeling the columns in the select. 

# NOTE(review): the SELECT below uses the `static` tables declared
# through newDatabase, not tables declared on the read-only connection.
347 existingReadOnlyDatabase.insert( 

348 table2, 

349 select=sqlalchemy.sql.select(static.a.columns.name, static.b.columns.id) 

350 .select_from(static.a.join(static.b, onclause=sqlalchemy.sql.literal(True))) 

351 .where( 

352 sqlalchemy.sql.and_( 

353 static.a.columns.name == "a2", 

354 static.b.columns.value >= 12, 

355 ) 

356 ), 

357 names=["a_name", "b_id"], 

358 ) 

359 # Check that the inserted rows are present. 

360 self.assertCountEqual( 

361 [{"a_name": "a2", "b_id": bId} for bId in bIds[1:]], 

362 [row._asdict() for row in self.query_list(existingReadOnlyDatabase, table2.select())], 

363 ) 

364 # Exiting the context managers will drop the temporary tables from 

365 # the read-only DB. It's unspecified whether attempting to use it 

366 # after this point is an error or just never returns any results, 

367 # so we can't test what it does, only that it's not an error. 
# NOTE(review): no statement follows the comment above; the method ends
# here, so the "not an error" usage it describes is not exercised.

368 

def testSchemaSeparation(self):
    """Test that two distinct `Database` instances can each hold a
    different table under the same name.
    """
    first = self.makeEmptyDatabase(origin=1)
    with first.declareStaticTables(create=True) as context:
        firstTables = context.addTableTuple(STATIC_TABLE_SPECS)
    self.checkStaticSchema(firstTables)

    second = self.makeEmptyDatabase(origin=2)
    # Deliberately use DDL that differs from the first database's table "a",
    # so we would notice if both instances pointed at the same schema.
    spec = ddl.TableSpec(fields=[ddl.FieldSpec("id", dtype=sqlalchemy.Integer, primaryKey=True)])
    with second.declareStaticTables(create=True) as context:
        secondTable = context.addTable("a", spec)
    self.checkTable(spec, secondTable)

387 

def testInsertQueryDelete(self):
    """Test `Database.insert`, `Database.query` and `Database.delete`,
    together with the `Base64Region` type and the ``onDelete`` argument to
    `ddl.ForeignKeySpec`.
    """
    db = self.makeEmptyDatabase(origin=1)
    with db.declareStaticTables(create=True) as context:
        tables = context.addTableTuple(STATIC_TABLE_SPECS)

    def fetch(sql):
        # Run a query and return its rows as plain dictionaries.
        return [r._asdict() for r in self.query_list(db, sql)]

    # A single non-autoincrement row carrying a region round-trips intact.
    region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
    regionRow = {"name": "a1", "region": region}
    db.insert(tables.a, regionRow)
    self.assertEqual(fetch(tables.a.select()), [regionRow])

    # Several autoincrement rows, inserted without asking for their IDs.
    db.insert(tables.b, {"name": "b1", "value": 10}, {"name": "b2", "value": 20})
    firstB = fetch(tables.b.select().order_by("id"))
    self.assertEqual(len(firstB), 2)
    for bRow in firstB:
        self.assertIn(bRow["name"], ("b1", "b2"))
        self.assertIsInstance(bRow["id"], int)
    self.assertGreater(firstB[1]["id"], firstB[0]["id"])

    # More autoincrement rows, this time getting the IDs back from insert.
    newB = [{"name": "b3", "value": 30}, {"name": "b4", "value": 40}]
    bIds = db.insert(tables.b, *newB, returnIds=True)
    laterB = fetch(tables.b.select().where(tables.b.columns.id > firstB[1]["id"]))
    self.assertCountEqual(laterB, [dict(r, id=rowId) for r, rowId in zip(newB, bIds, strict=True)])
    self.assertTrue(all(r["id"] is not None for r in laterB))

    # Rows for table c (autoincrement key): one references b, one does not.
    cRows = [{"b_id": laterB[0]["id"]}, {"b_id": None}]
    cIds = db.insert(tables.c, *cRows, returnIds=True)
    cContents = fetch(tables.c.select())
    expectedC = [dict(r, id=rowId) for r, rowId in zip(cRows, cIds, strict=True)]
    self.assertCountEqual(cContents, expectedC)
    self.assertTrue(all(r["id"] is not None for r in cContents))

    # Use the returned c IDs to populate the dynamic table.
    d = db.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
    dRows = [{"c_id": rowId, "a_name": "a1"} for rowId in cIds]
    db.insert(d, *dRows)
    self.assertCountEqual(dRows, fetch(d.select()))

    # Explicit values for the autoincrement key are accepted as well.
    explicitC = [
        {"id": 700, "b_id": None},
        {"id": 701, "b_id": None},
    ]
    db.insert(tables.c, *explicitC)
    cContents = fetch(tables.c.select())
    self.assertCountEqual(cContents, expectedC + explicitC)
    self.assertTrue(all(r["id"] is not None for r in cContents))

    # 'SELECT COUNT(*)', reused by the deletion checks below.
    count = sqlalchemy.sql.select(sqlalchemy.sql.func.count())
    bValues = fetch(tables.b.select())
    # Delete two rows of b by primary key...
    self.assertEqual(db.delete(tables.b, ["id"], {"id": bValues[0]["id"]}, {"id": bValues[1]["id"]}), 2)
    # ...and the other two rows by name, which empties the table.
    self.assertEqual(
        db.delete(tables.b, ["name"], {"name": bValues[2]["name"]}, {"name": bValues[3]["name"]}), 2
    )
    self.assertEqual(self.query_scalar(db, count.select_from(tables.b)), 0)
    # The onDelete='SET NULL' foreign key must have cleared every c.b_id.
    self.assertEqual(
        self.query_scalar(
            db,
            count.select_from(tables.c).where(tables.c.columns.b_id != None),  # noqa:E711
        ),
        0,
    )
    # Removing every row of a (just one) cascades to d via onDelete='CASCADE'.
    self.assertEqual(db.delete(tables.a, []), 1)
    self.assertEqual(self.query_scalar(db, count.select_from(tables.a)), 0)
    self.assertEqual(self.query_scalar(db, count.select_from(d)), 0)

474 

def testDeleteWhere(self):
    """Test `Database.deleteWhere` with several kinds of WHERE clause."""
    db = self.makeEmptyDatabase(origin=1)
    with db.declareStaticTables(create=True) as context:
        tables = context.addTableTuple(STATIC_TABLE_SPECS)
    b = tables.b
    db.insert(b, *[{"id": i, "name": f"b{i}"} for i in range(10)])
    count = sqlalchemy.sql.select(sqlalchemy.sql.func.count())

    def deleteAndCheck(where, expectedDeleted, expectedRemaining):
        # Delete the matching rows from b, then check both the reported
        # number of deletions and the number of rows left.
        self.assertEqual(db.deleteWhere(b, where), expectedDeleted)
        self.assertEqual(self.query_scalar(db, count.select_from(b)), expectedRemaining)

    # IN with a literal list of IDs.
    deleteAndCheck(b.columns.id.in_([0, 1, 2]), 3, 7)
    # IN with a subquery on the same table.
    deleteAndCheck(b.columns.id.in_(sqlalchemy.sql.select(b.columns.id).where(b.columns.id > 5)), 4, 3)
    # Plain equality on a non-key column.
    deleteAndCheck(b.columns.name == "b5", 1, 2)
    # A constant-true clause removes whatever is left.
    deleteAndCheck(sqlalchemy.sql.literal(True), 2, 0)

503 

def testUpdate(self):
    """Test `Database.update`."""
    db = self.makeEmptyDatabase(origin=1)
    with db.declareStaticTables(create=True) as context:
        tables = context.addTableTuple(STATIC_TABLE_SPECS)
    a = tables.a
    # Two rows, neither with a region.
    db.insert(a, {"name": "a1"}, {"name": "a2"})
    # Give one of them a region. The {"name": "k"} mapping makes the row's
    # "k" entry select which "name" to update (here "a2").
    region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
    updatedCount = db.update(a, {"name": "k"}, {"k": "a2", "region": region})
    self.assertEqual(updatedCount, 1)
    query = sqlalchemy.sql.select(a.columns.name, a.columns.region).select_from(a)
    self.assertCountEqual(
        [r._asdict() for r in self.query_list(db, query)],
        [{"name": "a1", "region": None}, {"name": "a2", "region": region}],
    )

520 

def testSync(self):
    """Tests for `Database.sync`."""
    db = self.makeEmptyDatabase(origin=1)
    with db.declareStaticTables(create=True) as context:
        tables = context.addTableTuple(STATIC_TABLE_SPECS)

    def checkOnlyB1(database, table, rowId, value):
        # Table b must contain exactly the "b1" row, with the given id/value.
        self.assertEqual(
            [{"id": rowId, "name": "b1", "value": value}],
            [r._asdict() for r in self.query_list(database, table.select())],
        )

    # Insert a row with sync, because it doesn't exist yet.
    values, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10}, returning=["id"])
    self.assertTrue(inserted)
    checkOnlyB1(db, tables.b, values["id"], 10)
    # Repeat that operation, which should do nothing but return the
    # requested values.
    values, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10}, returning=["id"])
    self.assertFalse(inserted)
    checkOnlyB1(db, tables.b, values["id"], 10)
    # Repeat without the 'extra' arg, which should also just return the
    # existing row.
    values, inserted = db.sync(tables.b, keys={"name": "b1"}, returning=["id"])
    self.assertFalse(inserted)
    checkOnlyB1(db, tables.b, values["id"], 10)
    # A different value in 'extra' is still not an error, because 'extra' is
    # only used if we really do insert. Also drop the 'returning' argument.
    _, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 20})
    self.assertFalse(inserted)
    checkOnlyB1(db, tables.b, values["id"], 10)
    # The correct value in 'compared' instead of 'extra'.
    _, inserted = db.sync(tables.b, keys={"name": "b1"}, compared={"value": 10})
    self.assertFalse(inserted)
    checkOnlyB1(db, tables.b, values["id"], 10)
    # An incorrect value in 'compared' must raise.
    with self.assertRaises(DatabaseConflictError):
        db.sync(tables.b, keys={"name": "b1"}, compared={"value": 20})
    # In a read-only database, sync works if and only if the matching row
    # already exists. The read-only connection gets its own table objects so
    # that `tables` keeps referring to the writeable database afterwards.
    with self.asReadOnly(db) as rodb:
        with rodb.declareStaticTables(create=False) as context:
            roTables = context.addTableTuple(STATIC_TABLE_SPECS)
        _, inserted = rodb.sync(roTables.b, keys={"name": "b1"})
        self.assertFalse(inserted)
        checkOnlyB1(rodb, roTables.b, values["id"], 10)
        with self.assertRaises(ReadOnlyDatabaseError):
            rodb.sync(roTables.b, keys={"name": "b2"}, extra={"value": 20})
    # A different value in 'compared' with update=True updates the row; the
    # second result holds the previous values of the changed columns.
    _, updated = db.sync(tables.b, keys={"name": "b1"}, compared={"value": 20}, update=True)
    self.assertEqual(updated, {"value": 10})
    checkOnlyB1(db, tables.b, values["id"], 20)

591 

def testReplace(self):
    """Test `Database.replace`."""
    db = self.makeEmptyDatabase(origin=1)
    with db.declareStaticTables(create=True) as context:
        tables = context.addTableTuple(STATIC_TABLE_SPECS)

    def contentsOfA():
        # All rows currently in table a, as plain dictionaries.
        return [r._asdict() for r in self.query_list(db, tables.a.select())]

    # replace() can insert a brand-new row, here one carrying a region.
    region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
    row1 = {"name": "a1", "region": region}
    db.replace(tables.a, row1)
    self.assertEqual(contentsOfA(), [row1])
    # A second new row, without a region.
    row2 = {"name": "a2", "region": None}
    db.replace(tables.a, row2)
    self.assertCountEqual(contentsOfA(), [row1, row2])
    # Replacing both rows with identical values changes nothing.
    db.replace(tables.a, row1, row2)
    self.assertCountEqual(contentsOfA(), [row1, row2])
    # Overwrite row1 with a region-less version while re-sending row2.
    row1a = {"name": "a1", "region": None}
    db.replace(tables.a, row1a, row2)
    self.assertCountEqual(contentsOfA(), [row1a, row2])
    # Restore row1, give a2 a region and add a third row, passing the rows
    # in a different order.
    row2a = {"name": "a2", "region": region}
    row3 = {"name": "a3", "region": None}
    db.replace(tables.a, row3, row2a, row1)
    self.assertCountEqual(contentsOfA(), [row1, row2a, row3])

623 

def testEnsure(self):
    """Test `Database.ensure`."""
    db = self.makeEmptyDatabase(origin=1)
    with db.declareStaticTables(create=True) as context:
        tables = context.addTableTuple(STATIC_TABLE_SPECS)
    a, b = tables.a, tables.b

    def rowsIn(table):
        # Current contents of ``table`` as plain dictionaries.
        return [r._asdict() for r in self.query_list(db, table.select())]

    # ensure() inserts a missing row (here one with a region) and reports
    # one row written.
    region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
    row1 = {"name": "a1", "region": region}
    self.assertEqual(db.ensure(a, row1), 1)
    self.assertEqual(rowsIn(a), [row1])
    # A second row, without a region.
    row2 = {"name": "a2", "region": None}
    self.assertEqual(db.ensure(a, row2), 1)
    self.assertCountEqual(rowsIn(a), [row1, row2])
    # Re-ensuring both existing rows is a no-op.
    self.assertEqual(db.ensure(a, row1, row2), 0)
    self.assertCountEqual(rowsIn(a), [row1, row2])
    # row1's key without a region, plus row2 again: still a no-op, because
    # existing rows are not overwritten.
    row1a = {"name": "a1", "region": None}
    self.assertEqual(db.ensure(a, row1a, row2), 0)
    self.assertCountEqual(rowsIn(a), [row1, row2])
    # New values for both existing keys plus one new row, passed in a
    # different order: only the new row is added.
    row2a = {"name": "a2", "region": region}
    row3 = {"name": "a3", "region": None}
    self.assertEqual(db.ensure(a, row3, row2a, row1a), 1)
    self.assertCountEqual(rowsIn(a), [row1, row2, row3])
    # Table b has both a primary key and a separate unique constraint.
    rowB = {"id": 5, "name": "five", "value": 50}
    db.insert(b, rowB)
    # Without passing primary_key_only (i.e. relying on its default), a
    # conflict only in the non-PK unique column does nothing.
    db.ensure(b, {"id": 10, "name": "five", "value": 200})
    self.assertEqual(rowsIn(b), [rowB])
    # With primary_key_only=True, that same non-PK conflict is an integrity
    # error and nothing changes.
    with self.assertRaises(sqlalchemy.exc.IntegrityError):
        db.ensure(b, {"id": 10, "name": "five", "value": 200}, primary_key_only=True)
    self.assertEqual(rowsIn(b), [rowB])
    # With primary_key_only=True, a primary-key conflict is ignored whether
    # or not other columns conflict as well.
    db.ensure(b, {"id": 5, "name": "ten", "value": 100}, primary_key_only=True)
    self.assertEqual(rowsIn(b), [rowB])
    db.ensure(b, {"id": 5, "name": "five", "value": 100}, primary_key_only=True)
    self.assertEqual(rowsIn(b), [rowB])

676 

def testTransactionNesting(self):
    """Verify nested-transaction rollback semantics.

    When an exception escapes a savepoint (either requested explicitly with
    ``savepoint=True`` or inherited by a plain transaction nested inside
    one), only the work done inside that savepoint should be undone; work
    done earlier in the enclosing transaction must survive.
    """
    database = self.makeEmptyDatabase(origin=1)
    with database.declareStaticTables(create=True) as context:
        tables = context.addTableTuple(STATIC_TABLE_SPECS)

    def current_rows() -> list[dict]:
        # Every row presently stored in table "a", as plain dicts.
        return [row._asdict() for row in self.query_list(database, tables.a.select())]

    def as_rows(*names: str) -> list[dict]:
        # Expected-row builder: rows in table "a" here never carry a region.
        return [{"name": name, "region": None} for name in names]

    # Seed one row; re-inserting its key later is how we provoke an
    # IntegrityError on demand.
    database.insert(tables.a, {"name": "a1"})

    # Scenario 1: the failing insert sits directly inside an explicit
    # savepoint=True transaction.
    with database.transaction():
        # Kept: assertRaises swallows the error before it can reach the
        # outer transaction.
        database.insert(tables.a, {"name": "a2"})
        with (
            self.assertRaises(sqlalchemy.exc.IntegrityError),
            database.transaction(savepoint=True),
        ):
            # Succeeds for now, but is discarded by the savepoint rollback.
            database.insert(tables.a, {"name": "a4"})
            # Duplicate primary key: raises.
            database.insert(tables.a, {"name": "a1"})
    self.assertCountEqual(current_rows(), as_rows("a1", "a2"))

    # Scenario 2: the failing insert sits in a plain transaction that is
    # itself nested inside a savepoint=True transaction, so the savepoint
    # behavior applies implicitly.
    with database.transaction():
        # Kept, for the same reason as "a2" above.
        database.insert(tables.a, {"name": "a3"})
        with (
            self.assertRaises(sqlalchemy.exc.IntegrityError),
            database.transaction(savepoint=True),
        ):
            # Both of these succeed for now but are rolled back.
            database.insert(tables.a, {"name": "a4"})
            with database.transaction():
                database.insert(tables.a, {"name": "a5"})
                # Duplicate primary key: raises.
                database.insert(tables.a, {"name": "a1"})
    self.assertCountEqual(current_rows(), as_rows("a1", "a2", "a3"))

727 ) 

728 

def testTransactionLocking(self):
    """Test that `Database.transaction` can be used to acquire a lock
    that prohibits concurrent writes.
    """
    db1 = self.makeEmptyDatabase(origin=1)
    with db1.declareStaticTables(create=True) as context:
        tables1 = context.addTableTuple(STATIC_TABLE_SPECS)

    async def side1(lock: Iterable[sqlalchemy.schema.Table] = ()) -> tuple[set[str], set[str]]:
        """One side of the concurrent locking test.

        This optionally locks the given tables (some backends may lock more,
        e.g. the whole database), does a select for the contents of table
        ``a``, inserts a new row, and then selects again, with some waiting
        in between to make sure the other side has a chance to _attempt_ to
        insert in between. If the locking is enabled and works, the
        difference between the selects should just be the insert done on
        this side.

        Parameters
        ----------
        lock : `~collections.abc.Iterable` [ `sqlalchemy.schema.Table` ]
            Tables passed through to ``Database.transaction(lock=...)``.
            Empty (the default) means no explicit locking.

        Returns
        -------
        names1 : `set` [ `str` ]
            Names in table ``a`` before this side's insert.
        names2 : `set` [ `str` ]
            Names in table ``a`` after this side's insert.
        """
        # Give side2 a chance to create a connection.
        await asyncio.sleep(1.0)
        with db1.transaction(lock=lock):
            names1 = {row.name for row in self.query_list(db1, tables1.a.select())}
            # Give side2 a chance to insert (which will be blocked if
            # we've acquired a lock).
            await asyncio.sleep(2.0)
            db1.insert(tables1.a, {"name": "a1"})
            names2 = {row.name for row in self.query_list(db1, tables1.a.select())}
        return names1, names2

    async def side2() -> None:
        """Other side of the concurrent locking test.

        This side just waits a bit and then tries to insert a row into the
        table that the other side is trying to lock. Hopefully that
        waiting is enough to give the other side a chance to acquire the
        lock and thus make this side block until the lock is released. If
        this side manages to do the insert before side1 acquires the lock,
        we'll just warn about not succeeding at testing the locking,
        because we can only make that unlikely, not impossible.
        """

        def toRunInThread():
            """Create new SQLite connection for use in thread.

            SQLite locking isn't asyncio-friendly unless we actually
            run it in another thread. And SQLite gets very unhappy if
            we try to use a connection from multiple threads, so we have
            to create the new connection here instead of out in the main
            body of the test function.
            """
            db2 = self.getNewConnection(db1, writeable=True)
            with db2.declareStaticTables(create=False) as context:
                tables2 = context.addTableTuple(STATIC_TABLE_SPECS)
            with db2.transaction():
                db2.insert(tables2.a, {"name": "a2"})

        await asyncio.sleep(2.0)
        loop = asyncio.get_running_loop()
        # Run the (blocking) insert in a worker thread so that, if it waits
        # on side1's lock, the event loop can still resume side1.
        with ThreadPoolExecutor() as pool:
            await loop.run_in_executor(pool, toRunInThread)

    async def testProblemsWithNoLocking() -> None:
        """Run side1 and side2 with no locking, attempting to demonstrate
        the problem that locking is supposed to solve. If we get unlucky
        with scheduling, side2 will just happen to insert after side1 is
        done, and we won't have anything definitive. We just warn in that
        case because we really don't want spurious test failures.
        """
        task1 = asyncio.create_task(side1())
        task2 = asyncio.create_task(side2())

        names1, names2 = await task1
        await task2
        if "a2" in names1:
            warnings.warn(
                "Unlucky scheduling in no-locking test: concurrent INSERT happened before first SELECT.",
                stacklevel=1,
            )
            self.assertEqual(names1, {"a2"})
            self.assertEqual(names2, {"a1", "a2"})
        elif "a2" not in names2:
            warnings.warn(
                "Unlucky scheduling in no-locking test: concurrent INSERT "
                "happened after second SELECT even without locking.",
                stacklevel=1,
            )
            self.assertEqual(names1, set())
            self.assertEqual(names2, {"a1"})
        else:
            # This is the expected case: both INSERTS happen between the
            # two SELECTS. If we don't get this almost all of the time we
            # should adjust the sleep amounts.
            self.assertEqual(names1, set())
            self.assertEqual(names2, {"a1", "a2"})

    asyncio.run(testProblemsWithNoLocking())

    # Clean up after first test.
    db1.delete(tables1.a, ["name"], {"name": "a1"}, {"name": "a2"})

    async def testSolutionWithLocking() -> None:
        """Run side1 and side2 with locking, which should make side2 block
        its insert until side1 releases its lock.
        """
        task1 = asyncio.create_task(side1(lock=[tables1.a]))
        task2 = asyncio.create_task(side2())

        names1, names2 = await task1
        await task2
        if "a2" in names1:
            warnings.warn(
                "Unlucky scheduling in locking test: concurrent INSERT happened before first SELECT.",
                stacklevel=1,
            )
            self.assertEqual(names1, {"a2"})
            self.assertEqual(names2, {"a1", "a2"})
        else:
            # This is the expected case: the side2 INSERT happens after the
            # last SELECT on side1. This can also happen due to unlucky
            # scheduling, and we have no way to detect that here, but the
            # similar "no-locking" test has at least some chance of being
            # affected by the same problem and warning about it.
            self.assertEqual(names1, set())
            self.assertEqual(names2, {"a1"})

    asyncio.run(testSolutionWithLocking())

855 

def testTimespanDatabaseRepresentation(self):
    """Tests for `TimespanDatabaseRepresentation` and the `Database`
    methods that interact with it.
    """
    # Make some test timespans to play with, with the full suite of
    # topological relationships.
    start = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
    # Spacing between consecutive test timestamps.
    step = astropy.time.TimeDelta(60, format="sec")
    timestamps = [start + step * n for n in range(3)]
    aTimespans = [Timespan(begin=None, end=None)]
    aTimespans.extend(Timespan(begin=None, end=t) for t in timestamps)
    aTimespans.extend(Timespan(begin=t, end=None) for t in timestamps)
    aTimespans.extend(Timespan.fromInstant(t) for t in timestamps)
    aTimespans.append(Timespan.makeEmpty())
    aTimespans.extend(Timespan(begin=t1, end=t2) for t1, t2 in itertools.combinations(timestamps, 2))
    # Make another list of timespans that span the full range but don't
    # overlap. This is a subset of the previous list.
    bTimespans = [Timespan(begin=None, end=timestamps[0])]
    bTimespans.extend(Timespan(begin=t1, end=t2) for t1, t2 in itertools.pairwise(timestamps))
    bTimespans.append(Timespan(begin=timestamps[-1], end=None))
    # Make a database and create a table with that database's timespan
    # representation. This one will have no exclusion constraint and
    # a nullable timespan.
    db = self.makeEmptyDatabase(origin=1)
    TimespanReprClass = db.getTimespanRepresentation()
    aSpec = ddl.TableSpec(
        fields=[
            ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
        ],
    )
    for fieldSpec in TimespanReprClass.makeFieldSpecs(nullable=True):
        aSpec.fields.add(fieldSpec)
    with db.declareStaticTables(create=True) as context:
        aTable = context.addTable("a", aSpec)
    self.maxDiff = None

    def convertRowForInsert(row: dict) -> dict:
        """Convert a row containing a Timespan instance into one suitable
        for insertion into the database.

        Parameters
        ----------
        row : `dict`
            Row whose ``TimespanReprClass.NAME`` entry is a `Timespan` or
            `None`. Not modified.

        Returns
        -------
        row : `dict`
            Copy of ``row`` with the timespan entry replaced by the
            representation's column values.
        """
        result = row.copy()
        ts = result.pop(TimespanReprClass.NAME)
        return TimespanReprClass.update(ts, result=result)

    def convertRowFromSelect(row: dict) -> dict:
        """Convert a row from the database into one containing a Timespan.

        Parameters
        ----------
        row : `dict`
            Original row. Not modified.

        Returns
        -------
        row : `dict`
            The updated row.
        """
        result = row.copy()
        timespan = TimespanReprClass.extract(result)
        for name in TimespanReprClass.getFieldNames():
            del result[name]
        result[TimespanReprClass.NAME] = timespan
        return result

    # Insert rows into table A, in chunks just to make things interesting.
    # Include one with a NULL timespan.
    aRows = [{"id": n, TimespanReprClass.NAME: t} for n, t in enumerate(aTimespans)]
    aRows.append({"id": len(aRows), TimespanReprClass.NAME: None})
    db.insert(aTable, convertRowForInsert(aRows[0]))
    db.insert(aTable, *[convertRowForInsert(r) for r in aRows[1:3]])
    db.insert(aTable, *[convertRowForInsert(r) for r in aRows[3:]])
    # Add another one with a NULL timespan, but this time by invoking
    # the server-side default.
    aRows.append({"id": len(aRows)})
    db.insert(aTable, aRows[-1])
    aRows[-1][TimespanReprClass.NAME] = None
    # Test basic round-trip through database.
    self.assertEqual(
        aRows,
        [
            convertRowFromSelect(row._asdict())
            for row in self.query_list(db, aTable.select().order_by(aTable.columns.id))
        ],
    )
    # Create another table B with a not-null timespan and (if the database
    # supports it), an exclusion constraint. Use ensureTableExists this
    # time to check that mode of table creation vs. timespans.
    bSpec = ddl.TableSpec(
        fields=[
            ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
            ddl.FieldSpec(name="key", dtype=sqlalchemy.Integer, nullable=False),
        ],
    )
    for fieldSpec in TimespanReprClass.makeFieldSpecs(nullable=False):
        bSpec.fields.add(fieldSpec)
    if TimespanReprClass.hasExclusionConstraint():
        bSpec.exclusion.add(("key", TimespanReprClass))
    bTable = db.ensureTableExists("b", bSpec)
    # Insert rows into table B, again in chunks. Each Timespan appears
    # twice, but with different values for the 'key' field (which should
    # still be okay for any exclusion constraint we may have defined).
    bRows = [{"id": n, "key": 1, TimespanReprClass.NAME: t} for n, t in enumerate(bTimespans)]
    # Row ids for the key=2 copies start after the key=1 rows.
    idOffset = len(bRows)
    bRows.extend(
        {"id": n + idOffset, "key": 2, TimespanReprClass.NAME: t} for n, t in enumerate(bTimespans)
    )
    db.insert(bTable, *[convertRowForInsert(r) for r in bRows[:2]])
    db.insert(bTable, convertRowForInsert(bRows[2]))
    db.insert(bTable, *[convertRowForInsert(r) for r in bRows[3:]])
    # Insert a row with no timespan into table B. This should invoke the
    # server-side default, which is a timespan over (-∞, ∞). We set
    # key=3 to avoid upsetting an exclusion constraint that might exist.
    bRows.append({"id": len(bRows), "key": 3})
    db.insert(bTable, bRows[-1])
    bRows[-1][TimespanReprClass.NAME] = Timespan(None, None)
    # Test basic round-trip through database.
    self.assertEqual(
        bRows,
        [
            convertRowFromSelect(row._asdict())
            for row in self.query_list(db, bTable.select().order_by(bTable.columns.id))
        ],
    )
    # Test that we can't insert timespan=None into this table.
    with self.assertRaises(sqlalchemy.exc.IntegrityError):
        db.insert(bTable, convertRowForInsert({"id": len(bRows), "key": 4, TimespanReprClass.NAME: None}))
    # IFF this database supports exclusion constraints, test that they
    # also prevent inserts: each of these overlaps an existing key=1 row.
    if TimespanReprClass.hasExclusionConstraint():
        for conflicting in (
            Timespan(None, timestamps[1]),
            Timespan(timestamps[0], timestamps[2]),
            Timespan(timestamps[2], None),
        ):
            with self.subTest(conflicting=conflicting), self.assertRaises(sqlalchemy.exc.IntegrityError):
                db.insert(
                    bTable,
                    convertRowForInsert(
                        {"id": len(bRows), "key": 1, TimespanReprClass.NAME: conflicting}
                    ),
                )
    # Test NULL checks in SELECT queries, on both tables.
    aRepr = TimespanReprClass.from_columns(aTable.columns)
    self.assertEqual(
        [row[TimespanReprClass.NAME] is None for row in aRows],
        [
            row.f
            for row in self.query_list(
                db, sqlalchemy.sql.select(aRepr.isNull().label("f")).order_by(aTable.columns.id)
            )
        ],
    )
    bRepr = TimespanReprClass.from_columns(bTable.columns)
    # Table B's timespan is NOT NULL, so every row should report False.
    self.assertEqual(
        [False] * len(bRows),
        [
            row.f
            for row in self.query_list(
                db, sqlalchemy.sql.select(bRepr.isNull().label("f")).order_by(bTable.columns.id)
            )
        ],
    )
    # Test relationships expressions that relate in-database timespans to
    # Python-literal timespans, all from the more complete 'a' set; check
    # that this is consistent with Python-only relationship tests.
    for rhsRow in aRows:
        rhsT = rhsRow[TimespanReprClass.NAME]
        if rhsT is None:
            continue
        with self.subTest(rhsRow=rhsRow):
            expected = {}
            for lhsRow in aRows:
                lhsT = lhsRow[TimespanReprClass.NAME]
                if lhsT is None:
                    # Comparisons against a NULL timespan yield NULL in SQL.
                    expected[lhsRow["id"]] = (None, None, None, None)
                else:
                    expected[lhsRow["id"]] = (
                        lhsT.overlaps(rhsT),
                        lhsT.contains(rhsT),
                        lhsT < rhsT,
                        lhsT > rhsT,
                    )
            rhsRepr = TimespanReprClass.fromLiteral(rhsT)
            sql = sqlalchemy.sql.select(
                aTable.columns.id.label("lhs"),
                aRepr.overlaps(rhsRepr).label("overlaps"),
                aRepr.contains(rhsRepr).label("contains"),
                (aRepr < rhsRepr).label("less_than"),
                (aRepr > rhsRepr).label("greater_than"),
            ).select_from(aTable)
            queried = {
                row.lhs: (row.overlaps, row.contains, row.less_than, row.greater_than)
                for row in self.query_list(db, sql)
            }
            self.assertEqual(expected, queried)
    # Test relationship expressions that relate in-database timespans to
    # each other, all from the more complete 'a' set; check that this is
    # consistent with Python-only relationship tests.
    expected = {}
    for lhs, rhs in itertools.product(aRows, aRows):
        lhsT = lhs[TimespanReprClass.NAME]
        rhsT = rhs[TimespanReprClass.NAME]
        if lhsT is not None and rhsT is not None:
            expected[lhs["id"], rhs["id"]] = (
                lhsT.overlaps(rhsT),
                lhsT.contains(rhsT),
                lhsT < rhsT,
                lhsT > rhsT,
            )
        else:
            expected[lhs["id"], rhs["id"]] = (None, None, None, None)
    lhsSubquery = aTable.alias("lhs")
    rhsSubquery = aTable.alias("rhs")
    lhsRepr = TimespanReprClass.from_columns(lhsSubquery.columns)
    rhsRepr = TimespanReprClass.from_columns(rhsSubquery.columns)
    sql = sqlalchemy.sql.select(
        lhsSubquery.columns.id.label("lhs"),
        rhsSubquery.columns.id.label("rhs"),
        lhsRepr.overlaps(rhsRepr).label("overlaps"),
        lhsRepr.contains(rhsRepr).label("contains"),
        (lhsRepr < rhsRepr).label("less_than"),
        (lhsRepr > rhsRepr).label("greater_than"),
    ).select_from(lhsSubquery.join(rhsSubquery, onclause=sqlalchemy.sql.literal(True)))
    queried = {
        (row.lhs, row.rhs): (row.overlaps, row.contains, row.less_than, row.greater_than)
        for row in self.query_list(db, sql)
    }
    self.assertEqual(expected, queried)
    # Test relationship expressions between in-database timespans and
    # Python-literal instantaneous times.
    for t in timestamps:
        with self.subTest(t=t):
            expected = {}
            for lhsRow in aRows:
                lhsT = lhsRow[TimespanReprClass.NAME]
                if lhsT is None:
                    expected[lhsRow["id"]] = (None, None, None, None)
                else:
                    expected[lhsRow["id"]] = (
                        lhsT.contains(t),
                        lhsT.overlaps(t),
                        lhsT < t,
                        lhsT > t,
                    )
            rhsTime = sqlalchemy.sql.literal(t, type_=ddl.AstropyTimeNsecTai)
            sql = sqlalchemy.sql.select(
                aTable.columns.id.label("lhs"),
                aRepr.contains(rhsTime).label("contains"),
                aRepr.overlaps(rhsTime).label("overlaps_point"),
                (aRepr < rhsTime).label("less_than"),
                (aRepr > rhsTime).label("greater_than"),
            ).select_from(aTable)
            queried = {
                row.lhs: (row.contains, row.overlaps_point, row.less_than, row.greater_than)
                for row in self.query_list(db, sql)
            }
            self.assertEqual(expected, queried)

1125 

def testConstantRows(self):
    """Check `Database.constant_rows`.

    The literal rows it produces must be selectable on their own and
    joinable against a real table.
    """
    new_db = self.makeEmptyDatabase()
    with new_db.declareStaticTables(create=True) as context:
        static = context.addTableTuple(STATIC_TABLE_SPECS)
    b_table = static.b
    # Seed b1/11, b2/12, b3/13 and keep the generated ids of the first and
    # third rows so the constant rows can refer to them.
    seed_rows = [{"name": f"b{i}", "value": 10 + i} for i in (1, 2, 3)]
    first_id, _, third_id = new_db.insert(b_table, *seed_rows, returnIds=True)
    constant_fields = ddl.TableSpec(
        [
            ddl.FieldSpec(name="b", dtype=sqlalchemy.BigInteger),
            ddl.FieldSpec(name="s", dtype=sqlalchemy.String(8)),
            ddl.FieldSpec(name="r", dtype=ddl.Base64Region()),
        ],
    ).fields
    constant_data = [
        {"b": first_id, "s": "b1", "r": None},
        {"b": third_id, "s": "b3", "r": Circle.empty()},
    ]
    values = new_db.constant_rows(constant_fields, *constant_data)
    cols = values.columns

    # On their own, the constant rows should come back exactly as given.
    alone = sqlalchemy.sql.select(cols["b"], cols["s"], cols["r"])
    self.assertCountEqual(
        [row._mapping for row in self.query_list(new_db, alone)],
        constant_data,
    )

    # Joined to table "b" on id, each constant row picks up its value.
    on_id = b_table.columns["id"] == cols["b"]
    joined = sqlalchemy.sql.select(
        cols["s"].label("name"), b_table.columns["value"].label("value")
    ).select_from(values.join(b_table, onclause=on_id))
    self.assertCountEqual(
        [row._mapping for row in self.query_list(new_db, joined)],
        [{"value": 11, "name": "b1"}, {"value": 13, "name": "b3"}],
    )