Coverage for python/lsst/daf/butler/registry/tests/_database.py: 6%

492 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-07-12 10:56 -0700

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

from __future__ import annotations

__all__ = ["DatabaseTests"]

import asyncio
import itertools
import warnings
from abc import ABC, abstractmethod
from collections import namedtuple
from collections.abc import Iterable, Iterator
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from typing import Any, ContextManager

import astropy.time
import sqlalchemy
from lsst.sphgeom import Circle, ConvexPolygon, UnitVector3d

from ...core import Timespan, ddl
from ..interfaces import Database, DatabaseConflictError, ReadOnlyDatabaseError, SchemaAlreadyDefinedError

41 

# Container for the three static-table specifications used throughout these
# tests; the field names ("a", "b", "c") double as the table names.
StaticTablesTuple = namedtuple("StaticTablesTuple", ["a", "b", "c"])

STATIC_TABLE_SPECS = StaticTablesTuple(
    # "a": string primary key plus a nullable region blob.
    a=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("name", dtype=sqlalchemy.String, length=16, primaryKey=True),
            ddl.FieldSpec("region", dtype=ddl.Base64Region, nbytes=128),
        ]
    ),
    # "b": autoincrement primary key, with an additional unique constraint
    # on "name" so tests can exercise non-primary-key conflicts.
    b=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("id", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True),
            ddl.FieldSpec("name", dtype=sqlalchemy.String, length=16, nullable=False),
            ddl.FieldSpec("value", dtype=sqlalchemy.SmallInteger, nullable=True),
        ],
        unique=[("name",)],
    ),
    # "c": autoincrement primary key and a nullable foreign key into "b";
    # the ON DELETE SET NULL behavior is exercised in testInsertQueryDelete.
    c=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("id", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True),
            ddl.FieldSpec("b_id", dtype=sqlalchemy.BigInteger, nullable=True),
        ],
        foreignKeys=[
            ddl.ForeignKeySpec("b", source=("b_id",), target=("id",), onDelete="SET NULL"),
        ],
    ),
)

69 

# Specification for the table created at runtime (as "d") in the dynamic-table
# tests; it carries ON DELETE CASCADE foreign keys into both "c" and "a".
DYNAMIC_TABLE_SPEC = ddl.TableSpec(
    fields=[
        ddl.FieldSpec("c_id", dtype=sqlalchemy.BigInteger, primaryKey=True),
        ddl.FieldSpec("a_name", dtype=sqlalchemy.String, length=16, nullable=False),
    ],
    foreignKeys=[
        ddl.ForeignKeySpec("c", source=("c_id",), target=("id",), onDelete="CASCADE"),
        ddl.ForeignKeySpec("a", source=("a_name",), target=("name",), onDelete="CASCADE"),
    ],
)

80 

# Specification for the temporary tables made in testTemporaryTables; a
# two-column compound primary key and (deliberately) no foreign keys.
TEMPORARY_TABLE_SPEC = ddl.TableSpec(
    fields=[
        ddl.FieldSpec("a_name", dtype=sqlalchemy.String, length=16, primaryKey=True),
        ddl.FieldSpec("b_id", dtype=sqlalchemy.BigInteger, primaryKey=True),
    ],
)

87 

88 

89@contextmanager 

90def _patch_getExistingTable(db: Database) -> Database: 

91 """Patch getExistingTable method in a database instance to test concurrent 

92 creation of tables. This patch obviously depends on knowning internals of 

93 ``ensureTableExists()`` implementation. 

94 """ 

95 original_method = db.getExistingTable 

96 

97 def _getExistingTable(name: str, spec: ddl.TableSpec) -> sqlalchemy.schema.Table | None: 

98 # Return None on first call, but forward to original method after that 

99 db.getExistingTable = original_method 

100 return None 

101 

102 db.getExistingTable = _getExistingTable 

103 yield db 

104 db.getExistingTable = original_method 

105 

106 

107class DatabaseTests(ABC): 

108 """Generic tests for the `Database` interface that can be subclassed to 

109 generate tests for concrete implementations. 

110 """ 

111 

    @abstractmethod
    def makeEmptyDatabase(self, origin: int = 0) -> Database:
        """Return an empty `Database` with the given origin, or an
        automatically-generated one if ``origin`` is `None`.

        NOTE(review): the default is ``0``, not `None`, so the
        automatically-generated branch described above is only reachable
        when a caller passes ``origin=None`` explicitly — confirm against
        the concrete subclass implementations.
        """
        raise NotImplementedError()

118 

    @abstractmethod
    def asReadOnly(self, database: Database) -> ContextManager[Database]:
        """Return a context manager for a read-only connection into the given
        database.

        The original database should be considered unusable within the context
        but safe to use again afterwards (this allows the context manager to
        block write access by temporarily changing user permissions to really
        guarantee that write operations are not performed).

        The object yielded by the context manager is a `Database` whose write
        operations are expected to raise `ReadOnlyDatabaseError`.
        """
        raise NotImplementedError()

130 

    @abstractmethod
    def getNewConnection(self, database: Database, *, writeable: bool) -> Database:
        """Return a new `Database` instance that points to the same underlying
        storage as the given one.

        ``writeable`` presumably controls whether the new connection permits
        write operations — confirm against concrete implementations.
        """
        raise NotImplementedError()

137 

138 def query_list( 

139 self, database: Database, executable: sqlalchemy.sql.expression.SelectBase 

140 ) -> list[sqlalchemy.engine.Row]: 

141 """Run a SELECT or other read-only query against the database and 

142 return the results as a list. 

143 

144 This is a thin wrapper around database.query() that just avoids 

145 context-manager boilerplate that is usefully verbose in production code 

146 but just noise in tests. 

147 """ 

148 with database.transaction(): 

149 with database.query(executable) as result: 

150 return result.fetchall() 

151 

152 def query_scalar(self, database: Database, executable: sqlalchemy.sql.expression.SelectBase) -> Any: 

153 """Run a SELECT query that yields a single column and row against the 

154 database and return its value. 

155 

156 This is a thin wrapper around database.query() that just avoids 

157 context-manager boilerplate that is usefully verbose in production code 

158 but just noise in tests. 

159 """ 

160 with database.query(executable) as result: 

161 return result.scalar() 

162 

    def checkTable(self, spec: ddl.TableSpec, table: sqlalchemy.schema.Table):
        """Assert that ``table`` has exactly the column names listed in
        ``spec``, in any order.
        """
        self.assertCountEqual(spec.fields.names, table.columns.keys())
        # Checking more than this currently seems fragile, as it might restrict
        # what Database implementations do; we don't care if the spec is
        # actually preserved in terms of types and constraints as long as we
        # can use the returned table as if it was.

169 

170 def checkStaticSchema(self, tables: StaticTablesTuple): 

171 self.checkTable(STATIC_TABLE_SPECS.a, tables.a) 

172 self.checkTable(STATIC_TABLE_SPECS.b, tables.b) 

173 self.checkTable(STATIC_TABLE_SPECS.c, tables.c) 

174 

175 def testDeclareStaticTables(self): 

176 """Tests for `Database.declareStaticSchema` and the methods it 

177 delegates to. 

178 """ 

179 # Create the static schema in a new, empty database. 

180 newDatabase = self.makeEmptyDatabase() 

181 with newDatabase.declareStaticTables(create=True) as context: 

182 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

183 self.checkStaticSchema(tables) 

184 # Check that we can load that schema even from a read-only connection. 

185 with self.asReadOnly(newDatabase) as existingReadOnlyDatabase: 

186 with existingReadOnlyDatabase.declareStaticTables(create=False) as context: 

187 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

188 self.checkStaticSchema(tables) 

189 

190 def testDeclareStaticTablesTwice(self): 

191 """Tests for `Database.declareStaticSchema` being called twice.""" 

192 # Create the static schema in a new, empty database. 

193 newDatabase = self.makeEmptyDatabase() 

194 with newDatabase.declareStaticTables(create=True) as context: 

195 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

196 self.checkStaticSchema(tables) 

197 # Second time it should raise 

198 with self.assertRaises(SchemaAlreadyDefinedError): 

199 with newDatabase.declareStaticTables(create=True) as context: 

200 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

201 # Check schema, it should still contain all tables, and maybe some 

202 # extra. 

203 with newDatabase.declareStaticTables(create=False) as context: 

204 self.assertLessEqual(frozenset(STATIC_TABLE_SPECS._fields), context._tableNames) 

205 

206 def testRepr(self): 

207 """Test that repr does not return a generic thing.""" 

208 newDatabase = self.makeEmptyDatabase() 

209 rep = repr(newDatabase) 

210 # Check that stringification works and gives us something different 

211 self.assertNotEqual(rep, str(newDatabase)) 

212 self.assertNotIn("object at 0x", rep, "Check default repr was not used") 

213 self.assertIn("://", rep) 

214 

    def testDynamicTables(self):
        """Tests for `Database.ensureTableExists` and
        `Database.getExistingTable`.
        """
        # Need to start with the static schema.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            context.addTableTuple(STATIC_TABLE_SPECS)
        # Try to ensure the dynamic table exists in a read-only version of
        # that database, which should fail because we can't create it.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                context.addTableTuple(STATIC_TABLE_SPECS)
            with self.assertRaises(ReadOnlyDatabaseError):
                existingReadOnlyDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        # Just getting the dynamic table before it exists should return None.
        self.assertIsNone(newDatabase.getExistingTable("d", DYNAMIC_TABLE_SPEC))
        # Ensure the new table exists back in the original database, which
        # should create it.
        table = newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        self.checkTable(DYNAMIC_TABLE_SPEC, table)
        # Ensuring that it exists should just return the exact same table
        # instance again.
        self.assertIs(newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC), table)
        # Try again from the read-only database.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                context.addTableTuple(STATIC_TABLE_SPECS)
            # Just getting the dynamic table should now work...
            self.assertIsNotNone(existingReadOnlyDatabase.getExistingTable("d", DYNAMIC_TABLE_SPEC))
            # ...as should ensuring that it exists, since it now does.
            existingReadOnlyDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
            self.checkTable(DYNAMIC_TABLE_SPEC, table)
        # Trying to get the table with a different specification (at least
        # in terms of what columns are present) should raise.
        with self.assertRaises(DatabaseConflictError):
            newDatabase.ensureTableExists(
                "d",
                ddl.TableSpec(
                    fields=[ddl.FieldSpec("name", dtype=sqlalchemy.String, length=4, primaryKey=True)]
                ),
            )
        # Calling ensureTableExists inside a transaction block is an error,
        # even if it would do nothing.
        with newDatabase.transaction():
            with self.assertRaises(AssertionError):
                newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)

262 

263 def testDynamicTablesConcurrency(self): 

264 """Tests for `Database.ensureTableExists` concurrent use.""" 

265 # We cannot really run things concurrently in a deterministic way, here 

266 # we just simulate a situation when the table is created by other 

267 # process between the call to getExistingTable() and actual table 

268 # creation. 

269 db1 = self.makeEmptyDatabase() 

270 with db1.declareStaticTables(create=True) as context: 

271 context.addTableTuple(STATIC_TABLE_SPECS) 

272 self.assertIsNone(db1.getExistingTable("d", DYNAMIC_TABLE_SPEC)) 

273 

274 # Make a dynamic table using separate connection 

275 db2 = self.getNewConnection(db1, writeable=True) 

276 with db2.declareStaticTables(create=False) as context: 

277 context.addTableTuple(STATIC_TABLE_SPECS) 

278 table = db2.ensureTableExists("d", DYNAMIC_TABLE_SPEC) 

279 self.checkTable(DYNAMIC_TABLE_SPEC, table) 

280 

281 # Call it again but trick it into thinking that table is not there. 

282 # This test depends on knowing implementation of ensureTableExists() 

283 # which initially calls getExistingTable() to check that table may 

284 # exist, the patch intercepts that call and returns None. 

285 with _patch_getExistingTable(db1): 

286 table = db1.ensureTableExists("d", DYNAMIC_TABLE_SPEC) 

287 

    def testTemporaryTables(self):
        """Tests for `Database.temporary_table`, and `Database.insert` with the
        ``select`` argument.
        """
        # Need to start with the static schema; also insert some test data.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            static = context.addTableTuple(STATIC_TABLE_SPECS)
        newDatabase.insert(static.a, {"name": "a1", "region": None}, {"name": "a2", "region": None})
        bIds = newDatabase.insert(
            static.b,
            {"name": "b1", "value": 11},
            {"name": "b2", "value": 12},
            {"name": "b3", "value": 13},
            returnIds=True,
        )
        # Create the table.
        with newDatabase.session():
            with newDatabase.temporary_table(TEMPORARY_TABLE_SPEC, "e1") as table1:
                self.checkTable(TEMPORARY_TABLE_SPEC, table1)
                # Insert via a INSERT INTO ... SELECT query.
                newDatabase.insert(
                    table1,
                    select=sqlalchemy.sql.select(
                        static.a.columns.name.label("a_name"), static.b.columns.id.label("b_id")
                    )
                    .select_from(static.a.join(static.b, onclause=sqlalchemy.sql.literal(True)))
                    .where(
                        sqlalchemy.sql.and_(
                            static.a.columns.name == "a1",
                            static.b.columns.value <= 12,
                        )
                    ),
                )
                # Check that the inserted rows are present.
                self.assertCountEqual(
                    [{"a_name": "a1", "b_id": bId} for bId in bIds[:2]],
                    [row._asdict() for row in self.query_list(newDatabase, table1.select())],
                )
                # Create another one via a read-only connection to the
                # database. We _do_ allow temporary table modifications in
                # read-only databases.
                with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
                    with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                        context.addTableTuple(STATIC_TABLE_SPECS)
                    with existingReadOnlyDatabase.temporary_table(TEMPORARY_TABLE_SPEC) as table2:
                        self.checkTable(TEMPORARY_TABLE_SPEC, table2)
                        # Those tables should not be the same, despite having
                        # the same ddl.
                        self.assertIsNot(table1, table2)
                        # Do a slightly different insert into this table, to
                        # check that it works in a read-only database. This
                        # time we pass column names as a kwarg to insert
                        # instead of by labeling the columns in the select.
                        existingReadOnlyDatabase.insert(
                            table2,
                            select=sqlalchemy.sql.select(static.a.columns.name, static.b.columns.id)
                            .select_from(static.a.join(static.b, onclause=sqlalchemy.sql.literal(True)))
                            .where(
                                sqlalchemy.sql.and_(
                                    static.a.columns.name == "a2",
                                    static.b.columns.value >= 12,
                                )
                            ),
                            names=["a_name", "b_id"],
                        )
                        # Check that the inserted rows are present.
                        self.assertCountEqual(
                            [{"a_name": "a2", "b_id": bId} for bId in bIds[1:]],
                            [row._asdict() for row in self.query_list(existingReadOnlyDatabase, table2.select())],
                        )
                # Exiting the context managers will drop the temporary tables
                # from the read-only DB. It's unspecified whether attempting
                # to use it after this point is an error or just never returns
                # any results, so we can't test what it does, only that it's
                # not an error.

363 

364 def testSchemaSeparation(self): 

365 """Test that creating two different `Database` instances allows us 

366 to create different tables with the same name in each. 

367 """ 

368 db1 = self.makeEmptyDatabase(origin=1) 

369 with db1.declareStaticTables(create=True) as context: 

370 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

371 self.checkStaticSchema(tables) 

372 

373 db2 = self.makeEmptyDatabase(origin=2) 

374 # Make the DDL here intentionally different so we'll definitely 

375 # notice if db1 and db2 are pointing at the same schema. 

376 spec = ddl.TableSpec(fields=[ddl.FieldSpec("id", dtype=sqlalchemy.Integer, primaryKey=True)]) 

377 with db2.declareStaticTables(create=True) as context: 

378 # Make the DDL here intentionally different so we'll definitely 

379 # notice if db1 and db2 are pointing at the same schema. 

380 table = context.addTable("a", spec) 

381 self.checkTable(spec, table) 

382 

    def testInsertQueryDelete(self):
        """Test the `Database.insert`, `Database.query`, and `Database.delete`
        methods, as well as the `Base64Region` type and the ``onDelete``
        argument to `ddl.ForeignKeySpec`.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert a single, non-autoincrement row that contains a region and
        # query to get it back.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row = {"name": "a1", "region": region}
        db.insert(tables.a, row)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row])
        # Insert multiple autoincrement rows but do not try to get the IDs
        # back immediately.
        db.insert(tables.b, {"name": "b1", "value": 10}, {"name": "b2", "value": 20})
        results = [r._asdict() for r in self.query_list(db, tables.b.select().order_by("id"))]
        self.assertEqual(len(results), 2)
        for row in results:
            self.assertIn(row["name"], ("b1", "b2"))
            self.assertIsInstance(row["id"], int)
        self.assertGreater(results[1]["id"], results[0]["id"])
        # Insert multiple autoincrement rows and get the IDs back from insert.
        rows = [{"name": "b3", "value": 30}, {"name": "b4", "value": 40}]
        ids = db.insert(tables.b, *rows, returnIds=True)
        results = [
            r._asdict()
            for r in self.query_list(db, tables.b.select().where(tables.b.columns.id > results[1]["id"]))
        ]
        expected = [dict(row, id=id) for row, id in zip(rows, ids)]
        self.assertCountEqual(results, expected)
        self.assertTrue(all(result["id"] is not None for result in results))
        # Insert multiple rows into a table with an autoincrement primary key,
        # then use the returned IDs to insert into a dynamic table.
        rows = [{"b_id": results[0]["id"]}, {"b_id": None}]
        ids = db.insert(tables.c, *rows, returnIds=True)
        results = [r._asdict() for r in self.query_list(db, tables.c.select())]
        expected = [dict(row, id=id) for row, id in zip(rows, ids)]
        self.assertCountEqual(results, expected)
        self.assertTrue(all(result["id"] is not None for result in results))
        # Add the dynamic table.
        d = db.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        # Insert into it.
        rows = [{"c_id": id, "a_name": "a1"} for id in ids]
        db.insert(d, *rows)
        results = [r._asdict() for r in self.query_list(db, d.select())]
        self.assertCountEqual(rows, results)
        # Insert multiple rows into a table with an autoincrement primary key,
        # but pass in a value for the autoincrement key.
        rows2 = [
            {"id": 700, "b_id": None},
            {"id": 701, "b_id": None},
        ]
        db.insert(tables.c, *rows2)
        results = [r._asdict() for r in self.query_list(db, tables.c.select())]
        self.assertCountEqual(results, expected + rows2)
        self.assertTrue(all(result["id"] is not None for result in results))

        # Define 'SELECT COUNT(*)' query for later use.
        count = sqlalchemy.sql.select(sqlalchemy.sql.func.count())
        # Get the values we inserted into table b.
        bValues = [r._asdict() for r in self.query_list(db, tables.b.select())]
        # Remove two rows from table b by ID.
        n = db.delete(tables.b, ["id"], {"id": bValues[0]["id"]}, {"id": bValues[1]["id"]})
        self.assertEqual(n, 2)
        # Remove the other two rows from table b by name.
        n = db.delete(tables.b, ["name"], {"name": bValues[2]["name"]}, {"name": bValues[3]["name"]})
        self.assertEqual(n, 2)
        # There should now be no rows in table b.
        self.assertEqual(self.query_scalar(db, count.select_from(tables.b)), 0)
        # All b_id values in table c should now be NULL, because there's an
        # onDelete='SET NULL' foreign key.
        self.assertEqual(
            self.query_scalar(
                db,
                count.select_from(tables.c).where(tables.c.columns.b_id != None),  # noqa:E711
            ),
            0,
        )
        # Remove all rows in table a (there's only one); this should remove all
        # rows in d due to onDelete='CASCADE'.
        n = db.delete(tables.a, [])
        self.assertEqual(n, 1)
        self.assertEqual(self.query_scalar(db, count.select_from(tables.a)), 0)
        self.assertEqual(self.query_scalar(db, count.select_from(d)), 0)

469 

470 def testDeleteWhere(self): 

471 """Tests for `Database.deleteWhere`.""" 

472 db = self.makeEmptyDatabase(origin=1) 

473 with db.declareStaticTables(create=True) as context: 

474 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

475 db.insert(tables.b, *[{"id": i, "name": f"b{i}"} for i in range(10)]) 

476 count = sqlalchemy.sql.select(sqlalchemy.sql.func.count()) 

477 

478 n = db.deleteWhere(tables.b, tables.b.columns.id.in_([0, 1, 2])) 

479 self.assertEqual(n, 3) 

480 self.assertEqual(self.query_scalar(db, count.select_from(tables.b)), 7) 

481 

482 n = db.deleteWhere( 

483 tables.b, 

484 tables.b.columns.id.in_( 

485 sqlalchemy.sql.select(tables.b.columns.id).where(tables.b.columns.id > 5) 

486 ), 

487 ) 

488 self.assertEqual(n, 4) 

489 self.assertEqual(self.query_scalar(db, count.select_from(tables.b)), 3) 

490 

491 n = db.deleteWhere(tables.b, tables.b.columns.name == "b5") 

492 self.assertEqual(n, 1) 

493 self.assertEqual(self.query_scalar(db, count.select_from(tables.b)), 2) 

494 

495 n = db.deleteWhere(tables.b, sqlalchemy.sql.literal(True)) 

496 self.assertEqual(n, 2) 

497 self.assertEqual(self.query_scalar(db, count.select_from(tables.b)), 0) 

498 

499 def testUpdate(self): 

500 """Tests for `Database.update`.""" 

501 db = self.makeEmptyDatabase(origin=1) 

502 with db.declareStaticTables(create=True) as context: 

503 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

504 # Insert two rows into table a, both without regions. 

505 db.insert(tables.a, {"name": "a1"}, {"name": "a2"}) 

506 # Update one of the rows with a region. 

507 region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1))) 

508 n = db.update(tables.a, {"name": "k"}, {"k": "a2", "region": region}) 

509 self.assertEqual(n, 1) 

510 sql = sqlalchemy.sql.select(tables.a.columns.name, tables.a.columns.region).select_from(tables.a) 

511 self.assertCountEqual( 

512 [r._asdict() for r in self.query_list(db, sql)], 

513 [{"name": "a1", "region": None}, {"name": "a2", "region": region}], 

514 ) 

515 

    def testSync(self):
        """Tests for `Database.sync`, covering insert-if-missing, comparison
        of existing values, read-only behavior, and conditional update.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert a row with sync, because it doesn't exist yet.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10}, returning=["id"])
        self.assertTrue(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )
        # Repeat that operation, which should do nothing but return the
        # requested values.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10}, returning=["id"])
        self.assertFalse(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )
        # Repeat the operation without the 'extra' arg, which should also just
        # return the existing row.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, returning=["id"])
        self.assertFalse(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )
        # Repeat the operation with a different value in 'extra'. That still
        # shouldn't be an error, because 'extra' is only used if we really do
        # insert. Also drop the 'returning' argument.
        _, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 20})
        self.assertFalse(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )
        # Repeat the operation with the correct value in 'compared' instead of
        # 'extra'.
        _, inserted = db.sync(tables.b, keys={"name": "b1"}, compared={"value": 10})
        self.assertFalse(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )
        # Repeat the operation with an incorrect value in 'compared'; this
        # should raise.
        with self.assertRaises(DatabaseConflictError):
            db.sync(tables.b, keys={"name": "b1"}, compared={"value": 20})
        # Try to sync in a read-only database. This should work if and only
        # if the matching row already exists.
        with self.asReadOnly(db) as rodb:
            with rodb.declareStaticTables(create=False) as context:
                tables = context.addTableTuple(STATIC_TABLE_SPECS)
            _, inserted = rodb.sync(tables.b, keys={"name": "b1"})
            self.assertFalse(inserted)
            self.assertEqual(
                [{"id": values["id"], "name": "b1", "value": 10}],
                [r._asdict() for r in self.query_list(rodb, tables.b.select())],
            )
            with self.assertRaises(ReadOnlyDatabaseError):
                rodb.sync(tables.b, keys={"name": "b2"}, extra={"value": 20})
        # Repeat the operation with a different value in 'compared' and ask to
        # update; the returned dict holds the values that were replaced.
        _, updated = db.sync(tables.b, keys={"name": "b1"}, compared={"value": 20}, update=True)
        self.assertEqual(updated, {"value": 10})
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 20}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )

586 

587 def testReplace(self): 

588 """Tests for `Database.replace`.""" 

589 db = self.makeEmptyDatabase(origin=1) 

590 with db.declareStaticTables(create=True) as context: 

591 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

592 # Use 'replace' to insert a single row that contains a region and 

593 # query to get it back. 

594 region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1))) 

595 row1 = {"name": "a1", "region": region} 

596 db.replace(tables.a, row1) 

597 self.assertEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1]) 

598 # Insert another row without a region. 

599 row2 = {"name": "a2", "region": None} 

600 db.replace(tables.a, row2) 

601 self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2]) 

602 # Use replace to re-insert both of those rows again, which should do 

603 # nothing. 

604 db.replace(tables.a, row1, row2) 

605 self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2]) 

606 # Replace row1 with a row with no region, while reinserting row2. 

607 row1a = {"name": "a1", "region": None} 

608 db.replace(tables.a, row1a, row2) 

609 self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1a, row2]) 

610 # Replace both rows, returning row1 to its original state, while adding 

611 # a new one. Pass them in in a different order. 

612 row2a = {"name": "a2", "region": region} 

613 row3 = {"name": "a3", "region": None} 

614 db.replace(tables.a, row3, row2a, row1) 

615 self.assertCountEqual( 

616 [r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2a, row3] 

617 ) 

618 

    def testEnsure(self):
        """Tests for `Database.ensure`, whose return value is the number of
        rows actually inserted.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Use 'ensure' to insert a single row that contains a region and
        # query to get it back.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row1 = {"name": "a1", "region": region}
        self.assertEqual(db.ensure(tables.a, row1), 1)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1])
        # Insert another row without a region.
        row2 = {"name": "a2", "region": None}
        self.assertEqual(db.ensure(tables.a, row2), 1)
        self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2])
        # Use ensure to re-insert both of those rows again, which should do
        # nothing.
        self.assertEqual(db.ensure(tables.a, row1, row2), 0)
        self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2])
        # Attempt to insert row1's key with no region, while
        # reinserting row2. This should also do nothing.
        row1a = {"name": "a1", "region": None}
        self.assertEqual(db.ensure(tables.a, row1a, row2), 0)
        self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2])
        # Attempt to insert new rows for both existing keys, this time also
        # adding a new row. Pass them in in a different order. Only the new
        # row should be added.
        row2a = {"name": "a2", "region": region}
        row3 = {"name": "a3", "region": None}
        self.assertEqual(db.ensure(tables.a, row3, row2a, row1a), 1)
        self.assertCountEqual(
            [r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2, row3]
        )
        # Add some data to a table with both a primary key and a different
        # unique constraint.
        row_b = {"id": 5, "name": "five", "value": 50}
        db.insert(tables.b, row_b)
        # Attempt ensure with primary_key_only=False and a conflict for the
        # non-PK constraint. This should do nothing.
        db.ensure(tables.b, {"id": 10, "name": "five", "value": 200})
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.b.select())], [row_b])
        # Now use primary_key_only=True with conflict in only the non-PK field.
        # This should be an integrity error and nothing should change.
        with self.assertRaises(sqlalchemy.exc.IntegrityError):
            db.ensure(tables.b, {"id": 10, "name": "five", "value": 200}, primary_key_only=True)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.b.select())], [row_b])
        # With primary_key_only=True a conflict in the primary key is ignored
        # regardless of whether there is a conflict elsewhere.
        db.ensure(tables.b, {"id": 5, "name": "ten", "value": 100}, primary_key_only=True)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.b.select())], [row_b])
        db.ensure(tables.b, {"id": 5, "name": "five", "value": 100}, primary_key_only=True)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.b.select())], [row_b])

671 

672 def testTransactionNesting(self): 

673 """Test that transactions can be nested with the behavior in the 

674 presence of exceptions working as documented. 

675 """ 

676 db = self.makeEmptyDatabase(origin=1) 

677 with db.declareStaticTables(create=True) as context: 

678 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

679 # Insert one row so we can trigger integrity errors by trying to insert 

680 # a duplicate of it below. 

681 db.insert(tables.a, {"name": "a1"}) 

682 # First test: error recovery via explicit savepoint=True in the inner 

683 # transaction. 

684 with db.transaction(): 

685 # This insert should succeed, and should not be rolled back because 

686 # the assertRaises context should catch any exception before it 

687 # propagates up to the outer transaction. 

688 db.insert(tables.a, {"name": "a2"}) 

689 with self.assertRaises(sqlalchemy.exc.IntegrityError): 

690 with db.transaction(savepoint=True): 

691 # This insert should succeed, but should be rolled back. 

692 db.insert(tables.a, {"name": "a4"}) 

693 # This insert should fail (duplicate primary key), raising 

694 # an exception. 

695 db.insert(tables.a, {"name": "a1"}) 

696 self.assertCountEqual( 

697 [r._asdict() for r in self.query_list(db, tables.a.select())], 

698 [{"name": "a1", "region": None}, {"name": "a2", "region": None}], 

699 ) 

700 # Second test: error recovery via implicit savepoint=True, when the 

701 # innermost transaction is inside a savepoint=True transaction. 

702 with db.transaction(): 

703 # This insert should succeed, and should not be rolled back 

704 # because the assertRaises context should catch any 

705 # exception before it propagates up to the outer 

706 # transaction. 

707 db.insert(tables.a, {"name": "a3"}) 

708 with self.assertRaises(sqlalchemy.exc.IntegrityError): 

709 with db.transaction(savepoint=True): 

710 # This insert should succeed, but should be rolled back. 

711 db.insert(tables.a, {"name": "a4"}) 

712 with db.transaction(): 

713 # This insert should succeed, but should be rolled 

714 # back. 

715 db.insert(tables.a, {"name": "a5"}) 

716 # This insert should fail (duplicate primary key), 

717 # raising an exception. 

718 db.insert(tables.a, {"name": "a1"}) 

719 self.assertCountEqual( 

720 [r._asdict() for r in self.query_list(db, tables.a.select())], 

721 [{"name": "a1", "region": None}, {"name": "a2", "region": None}, {"name": "a3", "region": None}], 

722 ) 

723 

    def testTransactionLocking(self):
        """Test that `Database.transaction` can be used to acquire a lock
        that prohibits concurrent writes.

        Notes
        -----
        The two "sides" below are coordinated with ``asyncio.sleep`` calls, so
        the test is inherently timing-based; when the scheduling is unlucky it
        warns instead of failing.
        """
        db1 = self.makeEmptyDatabase(origin=1)
        with db1.declareStaticTables(create=True) as context:
            tables1 = context.addTableTuple(STATIC_TABLE_SPECS)

        async def side1(lock: Iterable[str] = ()) -> tuple[set[str], set[str]]:
            """One side of the concurrent locking test.

            This optionally locks the table (and maybe the whole database),
            does a select for its contents, inserts a new row, and then selects
            again, with some waiting in between to make sure the other side has
            a chance to _attempt_ to insert in between.  If the locking is
            enabled and works, the difference between the selects should just
            be the insert done on this thread.
            """
            # Give Side2 a chance to create a connection
            await asyncio.sleep(1.0)
            with db1.transaction(lock=lock):
                names1 = {row.name for row in self.query_list(db1, tables1.a.select())}
                # Give Side2 a chance to insert (which will be blocked if
                # we've acquired a lock).
                await asyncio.sleep(2.0)
                db1.insert(tables1.a, {"name": "a1"})
                names2 = {row.name for row in self.query_list(db1, tables1.a.select())}
                return names1, names2

        async def side2() -> None:
            """Other side of the concurrent locking test.

            This side just waits a bit and then tries to insert a row into the
            table that the other side is trying to lock.  Hopefully that
            waiting is enough to give the other side a chance to acquire the
            lock and thus make this side block until the lock is released.  If
            this side manages to do the insert before side1 acquires the lock,
            we'll just warn about not succeeding at testing the locking,
            because we can only make that unlikely, not impossible.
            """

            def toRunInThread():
                """Create new SQLite connection for use in thread.

                SQLite locking isn't asyncio-friendly unless we actually
                run it in another thread.  And SQLite gets very unhappy if
                we try to use a connection from multiple threads, so we have
                to create the new connection here instead of out in the main
                body of the test function.
                """
                db2 = self.getNewConnection(db1, writeable=True)
                with db2.declareStaticTables(create=False) as context:
                    tables2 = context.addTableTuple(STATIC_TABLE_SPECS)
                with db2.transaction():
                    db2.insert(tables2.a, {"name": "a2"})

            await asyncio.sleep(2.0)
            loop = asyncio.get_running_loop()
            # Run the blocking insert in a worker thread so the event loop
            # (and hence side1's timers) keeps running while it waits.
            with ThreadPoolExecutor() as pool:
                await loop.run_in_executor(pool, toRunInThread)

        async def testProblemsWithNoLocking() -> None:
            """Run side1 and side2 with no locking, attempting to demonstrate
            the problem that locking is supposed to solve.  If we get unlucky
            with scheduling, side2 will just happen to insert after side1 is
            done, and we won't have anything definitive.  We just warn in that
            case because we really don't want spurious test failures.
            """
            task1 = asyncio.create_task(side1())
            task2 = asyncio.create_task(side2())

            names1, names2 = await task1
            await task2
            if "a2" in names1:
                warnings.warn(
                    "Unlucky scheduling in no-locking test: concurrent INSERT happened before first SELECT."
                )
                self.assertEqual(names1, {"a2"})
                self.assertEqual(names2, {"a1", "a2"})
            elif "a2" not in names2:
                warnings.warn(
                    "Unlucky scheduling in no-locking test: concurrent INSERT "
                    "happened after second SELECT even without locking."
                )
                self.assertEqual(names1, set())
                self.assertEqual(names2, {"a1"})
            else:
                # This is the expected case: both INSERTS happen between the
                # two SELECTS.  If we don't get this almost all of the time we
                # should adjust the sleep amounts.
                self.assertEqual(names1, set())
                self.assertEqual(names2, {"a1", "a2"})

        asyncio.run(testProblemsWithNoLocking())

        # Clean up after first test.
        db1.delete(tables1.a, ["name"], {"name": "a1"}, {"name": "a2"})

        async def testSolutionWithLocking() -> None:
            """Run side1 and side2 with locking, which should make side2 block
            its insert until side1 releases its lock.
            """
            task1 = asyncio.create_task(side1(lock=[tables1.a]))
            task2 = asyncio.create_task(side2())

            names1, names2 = await task1
            await task2
            if "a2" in names1:
                warnings.warn(
                    "Unlucky scheduling in locking test: concurrent INSERT happened before first SELECT."
                )
                self.assertEqual(names1, {"a2"})
                self.assertEqual(names2, {"a1", "a2"})
            else:
                # This is the expected case: the side2 INSERT happens after the
                # last SELECT on side1.  This can also happen due to unlucky
                # scheduling, and we have no way to detect that here, but the
                # similar "no-locking" test has at least some chance of being
                # affected by the same problem and warning about it.
                self.assertEqual(names1, set())
                self.assertEqual(names2, {"a1"})

        asyncio.run(testSolutionWithLocking())

    def testTimespanDatabaseRepresentation(self):
        """Tests for `TimespanDatabaseRepresentation` and the `Database`
        methods that interact with it.
        """
        # Make some test timespans to play with, with the full suite of
        # topological relationships.
        start = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
        offset = astropy.time.TimeDelta(60, format="sec")
        timestamps = [start + offset * n for n in range(3)]
        # 'a' timespans: unbounded, half-bounded, instantaneous, empty, and
        # all bounded combinations of the three timestamps.
        aTimespans = [Timespan(begin=None, end=None)]
        aTimespans.extend(Timespan(begin=None, end=t) for t in timestamps)
        aTimespans.extend(Timespan(begin=t, end=None) for t in timestamps)
        aTimespans.extend(Timespan.fromInstant(t) for t in timestamps)
        aTimespans.append(Timespan.makeEmpty())
        aTimespans.extend(Timespan(begin=t1, end=t2) for t1, t2 in itertools.combinations(timestamps, 2))
        # Make another list of timespans that span the full range but don't
        # overlap.  This is a subset of the previous list.
        bTimespans = [Timespan(begin=None, end=timestamps[0])]
        bTimespans.extend(Timespan(begin=t1, end=t2) for t1, t2 in zip(timestamps[:-1], timestamps[1:]))
        bTimespans.append(Timespan(begin=timestamps[-1], end=None))
        # Make a database and create a table with that database's timespan
        # representation.  This one will have no exclusion constraint and
        # a nullable timespan.
        db = self.makeEmptyDatabase(origin=1)
        TimespanReprClass = db.getTimespanRepresentation()
        aSpec = ddl.TableSpec(
            fields=[
                ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
            ],
        )
        for fieldSpec in TimespanReprClass.makeFieldSpecs(nullable=True):
            aSpec.fields.add(fieldSpec)
        with db.declareStaticTables(create=True) as context:
            aTable = context.addTable("a", aSpec)
        self.maxDiff = None

        def convertRowForInsert(row: dict) -> dict:
            """Convert a row containing a Timespan instance into one suitable
            for insertion into the database.

            Parameters
            ----------
            row : `dict`
                Original row; must contain a ``TimespanReprClass.NAME`` key.

            Returns
            -------
            row : `dict`
                The updated row, with the timespan expanded into the
                representation's database columns.
            """
            result = row.copy()
            ts = result.pop(TimespanReprClass.NAME)
            return TimespanReprClass.update(ts, result=result)

        def convertRowFromSelect(row: dict) -> dict:
            """Convert a row from the database into one containing a Timespan.

            Parameters
            ----------
            row : `dict`
                Original row.

            Returns
            -------
            row : `dict`
                The updated row.
            """
            result = row.copy()
            timespan = TimespanReprClass.extract(result)
            for name in TimespanReprClass.getFieldNames():
                del result[name]
            result[TimespanReprClass.NAME] = timespan
            return result

        # Insert rows into table A, in chunks just to make things interesting.
        # Include one with a NULL timespan.
        aRows = [{"id": n, TimespanReprClass.NAME: t} for n, t in enumerate(aTimespans)]
        aRows.append({"id": len(aRows), TimespanReprClass.NAME: None})
        db.insert(aTable, convertRowForInsert(aRows[0]))
        db.insert(aTable, *[convertRowForInsert(r) for r in aRows[1:3]])
        db.insert(aTable, *[convertRowForInsert(r) for r in aRows[3:]])
        # Add another one with a NULL timespan, but this time by invoking
        # the server-side default.
        aRows.append({"id": len(aRows)})
        db.insert(aTable, aRows[-1])
        aRows[-1][TimespanReprClass.NAME] = None
        # Test basic round-trip through database.
        self.assertEqual(
            aRows,
            [
                convertRowFromSelect(row._asdict())
                for row in self.query_list(db, aTable.select().order_by(aTable.columns.id))
            ],
        )
        # Create another table B with a not-null timespan and (if the database
        # supports it), an exclusion constraint.  Use ensureTableExists this
        # time to check that mode of table creation vs. timespans.
        bSpec = ddl.TableSpec(
            fields=[
                ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
                ddl.FieldSpec(name="key", dtype=sqlalchemy.Integer, nullable=False),
            ],
        )
        for fieldSpec in TimespanReprClass.makeFieldSpecs(nullable=False):
            bSpec.fields.add(fieldSpec)
        if TimespanReprClass.hasExclusionConstraint():
            bSpec.exclusion.add(("key", TimespanReprClass))
        bTable = db.ensureTableExists("b", bSpec)
        # Insert rows into table B, again in chunks.  Each Timespan appears
        # twice, but with different values for the 'key' field (which should
        # still be okay for any exclusion constraint we may have defined).
        bRows = [{"id": n, "key": 1, TimespanReprClass.NAME: t} for n, t in enumerate(bTimespans)]
        offset = len(bRows)
        bRows.extend(
            {"id": n + offset, "key": 2, TimespanReprClass.NAME: t} for n, t in enumerate(bTimespans)
        )
        db.insert(bTable, *[convertRowForInsert(r) for r in bRows[:2]])
        db.insert(bTable, convertRowForInsert(bRows[2]))
        db.insert(bTable, *[convertRowForInsert(r) for r in bRows[3:]])
        # Insert a row with no timespan into table B.  This should invoke the
        # server-side default, which is a timespan over (-∞, ∞).  We set
        # key=3 to avoid upsetting an exclusion constraint that might exist.
        bRows.append({"id": len(bRows), "key": 3})
        db.insert(bTable, bRows[-1])
        bRows[-1][TimespanReprClass.NAME] = Timespan(None, None)
        # Test basic round-trip through database.
        self.assertEqual(
            bRows,
            [
                convertRowFromSelect(row._asdict())
                for row in self.query_list(db, bTable.select().order_by(bTable.columns.id))
            ],
        )
        # Test that we can't insert timespan=None into this table.
        with self.assertRaises(sqlalchemy.exc.IntegrityError):
            db.insert(bTable, convertRowForInsert({"id": len(bRows), "key": 4, TimespanReprClass.NAME: None}))
        # IFF this database supports exclusion constraints, test that they
        # also prevent inserts.
        if TimespanReprClass.hasExclusionConstraint():
            # Overlaps the first key=1 row from the left.
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                db.insert(
                    bTable,
                    convertRowForInsert(
                        {"id": len(bRows), "key": 1, TimespanReprClass.NAME: Timespan(None, timestamps[1])}
                    ),
                )
            # Overlaps key=1 rows in the middle of the range.
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                db.insert(
                    bTable,
                    convertRowForInsert(
                        {
                            "id": len(bRows),
                            "key": 1,
                            TimespanReprClass.NAME: Timespan(timestamps[0], timestamps[2]),
                        }
                    ),
                )
            # Overlaps the last key=1 row from the right.
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                db.insert(
                    bTable,
                    convertRowForInsert(
                        {"id": len(bRows), "key": 1, TimespanReprClass.NAME: Timespan(timestamps[2], None)}
                    ),
                )
        # Test NULL checks in SELECT queries, on both tables.
        aRepr = TimespanReprClass.from_columns(aTable.columns)
        self.assertEqual(
            [row[TimespanReprClass.NAME] is None for row in aRows],
            [
                row.f
                for row in self.query_list(
                    db, sqlalchemy.sql.select(aRepr.isNull().label("f")).order_by(aTable.columns.id)
                )
            ],
        )
        # Table B has a NOT NULL timespan, so every isNull() must be False.
        bRepr = TimespanReprClass.from_columns(bTable.columns)
        self.assertEqual(
            [False for row in bRows],
            [
                row.f
                for row in self.query_list(
                    db, sqlalchemy.sql.select(bRepr.isNull().label("f")).order_by(bTable.columns.id)
                )
            ],
        )
        # Test relationships expressions that relate in-database timespans to
        # Python-literal timespans, all from the more complete 'a' set; check
        # that this is consistent with Python-only relationship tests.
        for rhsRow in aRows:
            if rhsRow[TimespanReprClass.NAME] is None:
                continue
            with self.subTest(rhsRow=rhsRow):
                # NULL timespans compare as NULL (None) for all operators.
                expected = {}
                for lhsRow in aRows:
                    if lhsRow[TimespanReprClass.NAME] is None:
                        expected[lhsRow["id"]] = (None, None, None, None)
                    else:
                        expected[lhsRow["id"]] = (
                            lhsRow[TimespanReprClass.NAME].overlaps(rhsRow[TimespanReprClass.NAME]),
                            lhsRow[TimespanReprClass.NAME].contains(rhsRow[TimespanReprClass.NAME]),
                            lhsRow[TimespanReprClass.NAME] < rhsRow[TimespanReprClass.NAME],
                            lhsRow[TimespanReprClass.NAME] > rhsRow[TimespanReprClass.NAME],
                        )
                rhsRepr = TimespanReprClass.fromLiteral(rhsRow[TimespanReprClass.NAME])
                sql = sqlalchemy.sql.select(
                    aTable.columns.id.label("lhs"),
                    aRepr.overlaps(rhsRepr).label("overlaps"),
                    aRepr.contains(rhsRepr).label("contains"),
                    (aRepr < rhsRepr).label("less_than"),
                    (aRepr > rhsRepr).label("greater_than"),
                ).select_from(aTable)
                queried = {
                    row.lhs: (row.overlaps, row.contains, row.less_than, row.greater_than)
                    for row in self.query_list(db, sql)
                }
                self.assertEqual(expected, queried)
        # Test relationship expressions that relate in-database timespans to
        # each other, all from the more complete 'a' set; check that this is
        # consistent with Python-only relationship tests.
        expected = {}
        for lhs, rhs in itertools.product(aRows, aRows):
            lhsT = lhs[TimespanReprClass.NAME]
            rhsT = rhs[TimespanReprClass.NAME]
            if lhsT is not None and rhsT is not None:
                expected[lhs["id"], rhs["id"]] = (
                    lhsT.overlaps(rhsT),
                    lhsT.contains(rhsT),
                    lhsT < rhsT,
                    lhsT > rhsT,
                )
            else:
                expected[lhs["id"], rhs["id"]] = (None, None, None, None)
        # Self-join table A against itself via a cross join (literal True
        # onclause) to compare every pair of rows.
        lhsSubquery = aTable.alias("lhs")
        rhsSubquery = aTable.alias("rhs")
        lhsRepr = TimespanReprClass.from_columns(lhsSubquery.columns)
        rhsRepr = TimespanReprClass.from_columns(rhsSubquery.columns)
        sql = sqlalchemy.sql.select(
            lhsSubquery.columns.id.label("lhs"),
            rhsSubquery.columns.id.label("rhs"),
            lhsRepr.overlaps(rhsRepr).label("overlaps"),
            lhsRepr.contains(rhsRepr).label("contains"),
            (lhsRepr < rhsRepr).label("less_than"),
            (lhsRepr > rhsRepr).label("greater_than"),
        ).select_from(lhsSubquery.join(rhsSubquery, onclause=sqlalchemy.sql.literal(True)))
        queried = {
            (row.lhs, row.rhs): (row.overlaps, row.contains, row.less_than, row.greater_than)
            for row in self.query_list(db, sql)
        }
        self.assertEqual(expected, queried)
        # Test relationship expressions between in-database timespans and
        # Python-literal instantaneous times.
        for t in timestamps:
            with self.subTest(t=t):
                expected = {}
                for lhsRow in aRows:
                    if lhsRow[TimespanReprClass.NAME] is None:
                        expected[lhsRow["id"]] = (None, None, None, None)
                    else:
                        expected[lhsRow["id"]] = (
                            lhsRow[TimespanReprClass.NAME].contains(t),
                            lhsRow[TimespanReprClass.NAME].overlaps(t),
                            lhsRow[TimespanReprClass.NAME] < t,
                            lhsRow[TimespanReprClass.NAME] > t,
                        )
                rhs = sqlalchemy.sql.literal(t, type_=ddl.AstropyTimeNsecTai)
                sql = sqlalchemy.sql.select(
                    aTable.columns.id.label("lhs"),
                    aRepr.contains(rhs).label("contains"),
                    aRepr.overlaps(rhs).label("overlaps_point"),
                    (aRepr < rhs).label("less_than"),
                    (aRepr > rhs).label("greater_than"),
                ).select_from(aTable)
                queried = {
                    row.lhs: (row.contains, row.overlaps_point, row.less_than, row.greater_than)
                    for row in self.query_list(db, sql)
                }
                self.assertEqual(expected, queried)

1116 def testConstantRows(self): 

1117 """Test Database.constant_rows.""" 

1118 new_db = self.makeEmptyDatabase() 

1119 with new_db.declareStaticTables(create=True) as context: 

1120 static = context.addTableTuple(STATIC_TABLE_SPECS) 

1121 b_ids = new_db.insert( 

1122 static.b, 

1123 {"name": "b1", "value": 11}, 

1124 {"name": "b2", "value": 12}, 

1125 {"name": "b3", "value": 13}, 

1126 returnIds=True, 

1127 ) 

1128 values_spec = ddl.TableSpec( 

1129 [ 

1130 ddl.FieldSpec(name="b", dtype=sqlalchemy.BigInteger), 

1131 ddl.FieldSpec(name="s", dtype=sqlalchemy.String(8)), 

1132 ddl.FieldSpec(name="r", dtype=ddl.Base64Region()), 

1133 ], 

1134 ) 

1135 values_data = [ 

1136 {"b": b_ids[0], "s": "b1", "r": None}, 

1137 {"b": b_ids[2], "s": "b3", "r": Circle.empty()}, 

1138 ] 

1139 values = new_db.constant_rows(values_spec.fields, *values_data) 

1140 select_values_alone = sqlalchemy.sql.select( 

1141 values.columns["b"], values.columns["s"], values.columns["r"] 

1142 ) 

1143 self.assertCountEqual( 

1144 [row._mapping for row in self.query_list(new_db, select_values_alone)], 

1145 values_data, 

1146 ) 

1147 select_values_joined = sqlalchemy.sql.select( 

1148 values.columns["s"].label("name"), static.b.columns["value"].label("value") 

1149 ).select_from(values.join(static.b, onclause=static.b.columns["id"] == values.columns["b"])) 

1150 self.assertCountEqual( 

1151 [row._mapping for row in self.query_list(new_db, select_values_joined)], 

1152 [{"value": 11, "name": "b1"}, {"value": 13, "name": "b3"}], 

1153 )