Coverage for python/lsst/daf/butler/registry/tests/_database.py: 7%

491 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-10-25 15:14 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ["DatabaseTests"] 

24 

import asyncio
import itertools
import warnings
from abc import ABC, abstractmethod
from collections import namedtuple
from collections.abc import Iterable, Iterator
from concurrent.futures import ThreadPoolExecutor
from contextlib import AbstractContextManager, contextmanager
from typing import Any

import astropy.time
import sqlalchemy
from lsst.sphgeom import Circle, ConvexPolygon, UnitVector3d

from ...core import Timespan, ddl
from ..interfaces import Database, DatabaseConflictError, ReadOnlyDatabaseError, SchemaAlreadyDefinedError

41 

# Named tuple giving attribute access (a, b, c) to the three static test
# tables declared below.
StaticTablesTuple = namedtuple("StaticTablesTuple", ["a", "b", "c"])

# Static schema shared by most tests:
#  - "a": string primary key plus a nullable spherical-region column.
#  - "b": autoincrement integer primary key, with a separate unique
#    constraint on "name".
#  - "c": autoincrement integer primary key and a nullable foreign key to
#    "b" declared with onDelete="SET NULL".
STATIC_TABLE_SPECS = StaticTablesTuple(
    a=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("name", dtype=sqlalchemy.String, length=16, primaryKey=True),
            ddl.FieldSpec("region", dtype=ddl.Base64Region, nbytes=128),
        ]
    ),
    b=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("id", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True),
            ddl.FieldSpec("name", dtype=sqlalchemy.String, length=16, nullable=False),
            ddl.FieldSpec("value", dtype=sqlalchemy.SmallInteger, nullable=True),
        ],
        unique=[("name",)],
    ),
    c=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("id", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True),
            ddl.FieldSpec("b_id", dtype=sqlalchemy.BigInteger, nullable=True),
        ],
        foreignKeys=[
            ddl.ForeignKeySpec("b", source=("b_id",), target=("id",), onDelete="SET NULL"),
        ],
    ),
)

# Spec for a table created at runtime via Database.ensureTableExists; it
# references both "c" and "a" with onDelete="CASCADE".
DYNAMIC_TABLE_SPEC = ddl.TableSpec(
    fields=[
        ddl.FieldSpec("c_id", dtype=sqlalchemy.BigInteger, primaryKey=True),
        ddl.FieldSpec("a_name", dtype=sqlalchemy.String, length=16, nullable=False),
    ],
    foreignKeys=[
        ddl.ForeignKeySpec("c", source=("c_id",), target=("id",), onDelete="CASCADE"),
        ddl.ForeignKeySpec("a", source=("a_name",), target=("name",), onDelete="CASCADE"),
    ],
)

# Spec for temporary tables: a compound primary key and no foreign keys.
TEMPORARY_TABLE_SPEC = ddl.TableSpec(
    fields=[
        ddl.FieldSpec("a_name", dtype=sqlalchemy.String, length=16, primaryKey=True),
        ddl.FieldSpec("b_id", dtype=sqlalchemy.BigInteger, primaryKey=True),
    ],
)

87 

88 

89@contextmanager 

90def _patch_getExistingTable(db: Database) -> Database: 

91 """Patch getExistingTable method in a database instance to test concurrent 

92 creation of tables. This patch obviously depends on knowning internals of 

93 ``ensureTableExists()`` implementation. 

94 """ 

95 original_method = db.getExistingTable 

96 

97 def _getExistingTable(name: str, spec: ddl.TableSpec) -> sqlalchemy.schema.Table | None: 

98 # Return None on first call, but forward to original method after that 

99 db.getExistingTable = original_method 

100 return None 

101 

102 db.getExistingTable = _getExistingTable 

103 yield db 

104 db.getExistingTable = original_method 

105 

106 

107class DatabaseTests(ABC): 

108 """Generic tests for the `Database` interface that can be subclassed to 

109 generate tests for concrete implementations. 

110 """ 

111 

112 @abstractmethod 

113 def makeEmptyDatabase(self, origin: int = 0) -> Database: 

114 """Return an empty `Database` with the given origin, or an 

115 automatically-generated one if ``origin`` is `None`. 

116 """ 

117 raise NotImplementedError() 

118 

119 @abstractmethod 

120 def asReadOnly(self, database: Database) -> AbstractContextManager[Database]: 

121 """Return a context manager for a read-only connection into the given 

122 database. 

123 

124 The original database should be considered unusable within the context 

125 but safe to use again afterwards (this allows the context manager to 

126 block write access by temporarily changing user permissions to really 

127 guarantee that write operations are not performed). 

128 """ 

129 raise NotImplementedError() 

130 

131 @abstractmethod 

132 def getNewConnection(self, database: Database, *, writeable: bool) -> Database: 

133 """Return a new `Database` instance that points to the same underlying 

134 storage as the given one. 

135 """ 

136 raise NotImplementedError() 

137 

138 def query_list( 

139 self, database: Database, executable: sqlalchemy.sql.expression.SelectBase 

140 ) -> list[sqlalchemy.engine.Row]: 

141 """Run a SELECT or other read-only query against the database and 

142 return the results as a list. 

143 

144 This is a thin wrapper around database.query() that just avoids 

145 context-manager boilerplate that is usefully verbose in production code 

146 but just noise in tests. 

147 """ 

148 with database.transaction(), database.query(executable) as result: 

149 return result.fetchall() 

150 

151 def query_scalar(self, database: Database, executable: sqlalchemy.sql.expression.SelectBase) -> Any: 

152 """Run a SELECT query that yields a single column and row against the 

153 database and return its value. 

154 

155 This is a thin wrapper around database.query() that just avoids 

156 context-manager boilerplate that is usefully verbose in production code 

157 but just noise in tests. 

158 """ 

159 with database.query(executable) as result: 

160 return result.scalar() 

161 

162 def checkTable(self, spec: ddl.TableSpec, table: sqlalchemy.schema.Table): 

163 self.assertCountEqual(spec.fields.names, table.columns.keys()) 

164 # Checking more than this currently seems fragile, as it might restrict 

165 # what Database implementations do; we don't care if the spec is 

166 # actually preserved in terms of types and constraints as long as we 

167 # can use the returned table as if it was. 

168 

169 def checkStaticSchema(self, tables: StaticTablesTuple): 

170 self.checkTable(STATIC_TABLE_SPECS.a, tables.a) 

171 self.checkTable(STATIC_TABLE_SPECS.b, tables.b) 

172 self.checkTable(STATIC_TABLE_SPECS.c, tables.c) 

173 

174 def testDeclareStaticTables(self): 

175 """Tests for `Database.declareStaticSchema` and the methods it 

176 delegates to. 

177 """ 

178 # Create the static schema in a new, empty database. 

179 newDatabase = self.makeEmptyDatabase() 

180 with newDatabase.declareStaticTables(create=True) as context: 

181 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

182 self.checkStaticSchema(tables) 

183 # Check that we can load that schema even from a read-only connection. 

184 with self.asReadOnly(newDatabase) as existingReadOnlyDatabase: 

185 with existingReadOnlyDatabase.declareStaticTables(create=False) as context: 

186 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

187 self.checkStaticSchema(tables) 

188 

189 def testDeclareStaticTablesTwice(self): 

190 """Tests for `Database.declareStaticSchema` being called twice.""" 

191 # Create the static schema in a new, empty database. 

192 newDatabase = self.makeEmptyDatabase() 

193 with newDatabase.declareStaticTables(create=True) as context: 

194 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

195 self.checkStaticSchema(tables) 

196 # Second time it should raise 

197 with self.assertRaises(SchemaAlreadyDefinedError): 

198 with newDatabase.declareStaticTables(create=True) as context: 

199 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

200 # Check schema, it should still contain all tables, and maybe some 

201 # extra. 

202 with newDatabase.declareStaticTables(create=False) as context: 

203 self.assertLessEqual(frozenset(STATIC_TABLE_SPECS._fields), context._tableNames) 

204 

205 def testRepr(self): 

206 """Test that repr does not return a generic thing.""" 

207 newDatabase = self.makeEmptyDatabase() 

208 rep = repr(newDatabase) 

209 # Check that stringification works and gives us something different 

210 self.assertNotEqual(rep, str(newDatabase)) 

211 self.assertNotIn("object at 0x", rep, "Check default repr was not used") 

212 self.assertIn("://", rep) 

213 

    def testDynamicTables(self):
        """Tests for `Database.ensureTableExists` and
        `Database.getExistingTable`.
        """
        # Need to start with the static schema.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            context.addTableTuple(STATIC_TABLE_SPECS)
        # Try to ensure the dynamic table exists in a read-only version of that
        # database, which should fail because we can't create it.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                context.addTableTuple(STATIC_TABLE_SPECS)
            with self.assertRaises(ReadOnlyDatabaseError):
                existingReadOnlyDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        # Just getting the dynamic table before it exists should return None.
        self.assertIsNone(newDatabase.getExistingTable("d", DYNAMIC_TABLE_SPEC))
        # Ensure the new table exists back in the original database, which
        # should create it.
        table = newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        self.checkTable(DYNAMIC_TABLE_SPEC, table)
        # Ensuring that it exists should just return the exact same table
        # instance again.
        self.assertIs(newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC), table)
        # Try again from the read-only database.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                context.addTableTuple(STATIC_TABLE_SPECS)
            # Just getting the dynamic table should now work...
            self.assertIsNotNone(existingReadOnlyDatabase.getExistingTable("d", DYNAMIC_TABLE_SPEC))
            # ...as should ensuring that it exists, since it now does.
            existingReadOnlyDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
            self.checkTable(DYNAMIC_TABLE_SPEC, table)
        # Trying to get the table with a different specification (at least
        # in terms of what columns are present) should raise.
        with self.assertRaises(DatabaseConflictError):
            newDatabase.ensureTableExists(
                "d",
                ddl.TableSpec(
                    fields=[ddl.FieldSpec("name", dtype=sqlalchemy.String, length=4, primaryKey=True)]
                ),
            )
        # Calling ensureTableExists inside a transaction block is an error,
        # even if it would do nothing.
        with newDatabase.transaction():
            with self.assertRaises(AssertionError):
                newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)

261 

    def testDynamicTablesConcurrency(self):
        """Tests for `Database.ensureTableExists` concurrent use."""
        # We cannot really run things concurrently in a deterministic way, here
        # we just simulate a situation when the table is created by other
        # process between the call to getExistingTable() and actual table
        # creation.
        db1 = self.makeEmptyDatabase()
        with db1.declareStaticTables(create=True) as context:
            context.addTableTuple(STATIC_TABLE_SPECS)
        self.assertIsNone(db1.getExistingTable("d", DYNAMIC_TABLE_SPEC))

        # Make a dynamic table using separate connection
        db2 = self.getNewConnection(db1, writeable=True)
        with db2.declareStaticTables(create=False) as context:
            context.addTableTuple(STATIC_TABLE_SPECS)
        table = db2.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        self.checkTable(DYNAMIC_TABLE_SPEC, table)

        # Call it again but trick it into thinking that table is not there.
        # This test depends on knowing implementation of ensureTableExists()
        # which initially calls getExistingTable() to check that table may
        # exist, the patch intercepts that call and returns None.
        with _patch_getExistingTable(db1):
            # Must succeed (not raise) even though the patched existence
            # check reported the table as missing.
            table = db1.ensureTableExists("d", DYNAMIC_TABLE_SPEC)

286 

    def testTemporaryTables(self):
        """Tests for `Database.temporary_table`, and `Database.insert` with the
        ``select`` argument.
        """
        # Need to start with the static schema; also insert some test data.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            static = context.addTableTuple(STATIC_TABLE_SPECS)
        newDatabase.insert(static.a, {"name": "a1", "region": None}, {"name": "a2", "region": None})
        bIds = newDatabase.insert(
            static.b,
            {"name": "b1", "value": 11},
            {"name": "b2", "value": 12},
            {"name": "b3", "value": 13},
            returnIds=True,
        )
        # Create the table.
        with newDatabase.session():
            with newDatabase.temporary_table(TEMPORARY_TABLE_SPEC, "e1") as table1:
                self.checkTable(TEMPORARY_TABLE_SPEC, table1)
                # Insert via a INSERT INTO ... SELECT query.
                newDatabase.insert(
                    table1,
                    select=sqlalchemy.sql.select(
                        static.a.columns.name.label("a_name"), static.b.columns.id.label("b_id")
                    )
                    # literal(True) as the ON clause gives a cross join.
                    .select_from(static.a.join(static.b, onclause=sqlalchemy.sql.literal(True)))
                    .where(
                        sqlalchemy.sql.and_(
                            static.a.columns.name == "a1",
                            static.b.columns.value <= 12,
                        )
                    ),
                )
                # Check that the inserted rows are present.
                self.assertCountEqual(
                    [{"a_name": "a1", "b_id": bId} for bId in bIds[:2]],
                    [row._asdict() for row in self.query_list(newDatabase, table1.select())],
                )
                # Create another one via a read-only connection to the
                # database. We _do_ allow temporary table modifications in
                # read-only databases.
                with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
                    with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                        context.addTableTuple(STATIC_TABLE_SPECS)
                    with existingReadOnlyDatabase.temporary_table(TEMPORARY_TABLE_SPEC) as table2:
                        self.checkTable(TEMPORARY_TABLE_SPEC, table2)
                        # Those tables should not be the same, despite having
                        # the same ddl.
                        self.assertIsNot(table1, table2)
                        # Do a slightly different insert into this table, to
                        # check that it works in a read-only database. This
                        # time we pass column names as a kwarg to insert
                        # instead of by labeling the columns in the select.
                        existingReadOnlyDatabase.insert(
                            table2,
                            select=sqlalchemy.sql.select(static.a.columns.name, static.b.columns.id)
                            .select_from(static.a.join(static.b, onclause=sqlalchemy.sql.literal(True)))
                            .where(
                                sqlalchemy.sql.and_(
                                    static.a.columns.name == "a2",
                                    static.b.columns.value >= 12,
                                )
                            ),
                            names=["a_name", "b_id"],
                        )
                        # Check that the inserted rows are present.
                        self.assertCountEqual(
                            [{"a_name": "a2", "b_id": bId} for bId in bIds[1:]],
                            [row._asdict() for row in self.query_list(existingReadOnlyDatabase, table2.select())],
                        )
            # Exiting the context managers will drop the temporary tables from
            # the read-only DB. It's unspecified whether attempting to use it
            # after this point is an error or just never returns any results,
            # so we can't test what it does, only that it's not an error.

362 

363 def testSchemaSeparation(self): 

364 """Test that creating two different `Database` instances allows us 

365 to create different tables with the same name in each. 

366 """ 

367 db1 = self.makeEmptyDatabase(origin=1) 

368 with db1.declareStaticTables(create=True) as context: 

369 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

370 self.checkStaticSchema(tables) 

371 

372 db2 = self.makeEmptyDatabase(origin=2) 

373 # Make the DDL here intentionally different so we'll definitely 

374 # notice if db1 and db2 are pointing at the same schema. 

375 spec = ddl.TableSpec(fields=[ddl.FieldSpec("id", dtype=sqlalchemy.Integer, primaryKey=True)]) 

376 with db2.declareStaticTables(create=True) as context: 

377 # Make the DDL here intentionally different so we'll definitely 

378 # notice if db1 and db2 are pointing at the same schema. 

379 table = context.addTable("a", spec) 

380 self.checkTable(spec, table) 

381 

    def testInsertQueryDelete(self):
        """Test the `Database.insert`, `Database.query`, and `Database.delete`
        methods, as well as the `Base64Region` type and the ``onDelete``
        argument to `ddl.ForeignKeySpec`.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert a single, non-autoincrement row that contains a region and
        # query to get it back.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row = {"name": "a1", "region": region}
        db.insert(tables.a, row)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row])
        # Insert multiple autoincrement rows but do not try to get the IDs
        # back immediately.
        db.insert(tables.b, {"name": "b1", "value": 10}, {"name": "b2", "value": 20})
        results = [r._asdict() for r in self.query_list(db, tables.b.select().order_by("id"))]
        self.assertEqual(len(results), 2)
        for row in results:
            self.assertIn(row["name"], ("b1", "b2"))
            self.assertIsInstance(row["id"], int)
        self.assertGreater(results[1]["id"], results[0]["id"])
        # Insert multiple autoincrement rows and get the IDs back from insert.
        rows = [{"name": "b3", "value": 30}, {"name": "b4", "value": 40}]
        ids = db.insert(tables.b, *rows, returnIds=True)
        # Select only the rows added above (id greater than the previous max).
        results = [
            r._asdict()
            for r in self.query_list(db, tables.b.select().where(tables.b.columns.id > results[1]["id"]))
        ]
        expected = [dict(row, id=id) for row, id in zip(rows, ids, strict=True)]
        self.assertCountEqual(results, expected)
        self.assertTrue(all(result["id"] is not None for result in results))
        # Insert multiple rows into a table with an autoincrement primary key,
        # then use the returned IDs to insert into a dynamic table.
        rows = [{"b_id": results[0]["id"]}, {"b_id": None}]
        ids = db.insert(tables.c, *rows, returnIds=True)
        results = [r._asdict() for r in self.query_list(db, tables.c.select())]
        expected = [dict(row, id=id) for row, id in zip(rows, ids, strict=True)]
        self.assertCountEqual(results, expected)
        self.assertTrue(all(result["id"] is not None for result in results))
        # Add the dynamic table.
        d = db.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        # Insert into it.
        rows = [{"c_id": id, "a_name": "a1"} for id in ids]
        db.insert(d, *rows)
        results = [r._asdict() for r in self.query_list(db, d.select())]
        self.assertCountEqual(rows, results)
        # Insert multiple rows into a table with an autoincrement primary key,
        # but pass in a value for the autoincrement key.
        rows2 = [
            {"id": 700, "b_id": None},
            {"id": 701, "b_id": None},
        ]
        db.insert(tables.c, *rows2)
        results = [r._asdict() for r in self.query_list(db, tables.c.select())]
        self.assertCountEqual(results, expected + rows2)
        self.assertTrue(all(result["id"] is not None for result in results))

        # Define 'SELECT COUNT(*)' query for later use.
        count = sqlalchemy.sql.select(sqlalchemy.sql.func.count())
        # Get the values we inserted into table b.
        bValues = [r._asdict() for r in self.query_list(db, tables.b.select())]
        # Remove two rows from table b by ID.
        n = db.delete(tables.b, ["id"], {"id": bValues[0]["id"]}, {"id": bValues[1]["id"]})
        self.assertEqual(n, 2)
        # Remove the other two rows from table b by name.
        n = db.delete(tables.b, ["name"], {"name": bValues[2]["name"]}, {"name": bValues[3]["name"]})
        self.assertEqual(n, 2)
        # There should now be no rows in table b.
        self.assertEqual(self.query_scalar(db, count.select_from(tables.b)), 0)
        # All b_id values in table c should now be NULL, because there's an
        # onDelete='SET NULL' foreign key.
        self.assertEqual(
            self.query_scalar(
                db,
                count.select_from(tables.c).where(tables.c.columns.b_id != None),  # noqa:E711
            ),
            0,
        )
        # Remove all rows in table a (there's only one); this should remove all
        # rows in d due to onDelete='CASCADE'.
        n = db.delete(tables.a, [])
        self.assertEqual(n, 1)
        self.assertEqual(self.query_scalar(db, count.select_from(tables.a)), 0)
        self.assertEqual(self.query_scalar(db, count.select_from(d)), 0)

468 

469 def testDeleteWhere(self): 

470 """Tests for `Database.deleteWhere`.""" 

471 db = self.makeEmptyDatabase(origin=1) 

472 with db.declareStaticTables(create=True) as context: 

473 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

474 db.insert(tables.b, *[{"id": i, "name": f"b{i}"} for i in range(10)]) 

475 count = sqlalchemy.sql.select(sqlalchemy.sql.func.count()) 

476 

477 n = db.deleteWhere(tables.b, tables.b.columns.id.in_([0, 1, 2])) 

478 self.assertEqual(n, 3) 

479 self.assertEqual(self.query_scalar(db, count.select_from(tables.b)), 7) 

480 

481 n = db.deleteWhere( 

482 tables.b, 

483 tables.b.columns.id.in_( 

484 sqlalchemy.sql.select(tables.b.columns.id).where(tables.b.columns.id > 5) 

485 ), 

486 ) 

487 self.assertEqual(n, 4) 

488 self.assertEqual(self.query_scalar(db, count.select_from(tables.b)), 3) 

489 

490 n = db.deleteWhere(tables.b, tables.b.columns.name == "b5") 

491 self.assertEqual(n, 1) 

492 self.assertEqual(self.query_scalar(db, count.select_from(tables.b)), 2) 

493 

494 n = db.deleteWhere(tables.b, sqlalchemy.sql.literal(True)) 

495 self.assertEqual(n, 2) 

496 self.assertEqual(self.query_scalar(db, count.select_from(tables.b)), 0) 

497 

498 def testUpdate(self): 

499 """Tests for `Database.update`.""" 

500 db = self.makeEmptyDatabase(origin=1) 

501 with db.declareStaticTables(create=True) as context: 

502 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

503 # Insert two rows into table a, both without regions. 

504 db.insert(tables.a, {"name": "a1"}, {"name": "a2"}) 

505 # Update one of the rows with a region. 

506 region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1))) 

507 n = db.update(tables.a, {"name": "k"}, {"k": "a2", "region": region}) 

508 self.assertEqual(n, 1) 

509 sql = sqlalchemy.sql.select(tables.a.columns.name, tables.a.columns.region).select_from(tables.a) 

510 self.assertCountEqual( 

511 [r._asdict() for r in self.query_list(db, sql)], 

512 [{"name": "a1", "region": None}, {"name": "a2", "region": region}], 

513 ) 

514 

    def testSync(self):
        """Tests for `Database.sync`."""
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert a row with sync, because it doesn't exist yet.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10}, returning=["id"])
        self.assertTrue(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )
        # Repeat that operation, which should do nothing but return the
        # requested values.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10}, returning=["id"])
        self.assertFalse(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )
        # Repeat the operation without the 'extra' arg, which should also just
        # return the existing row.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, returning=["id"])
        self.assertFalse(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )
        # Repeat the operation with a different value in 'extra'. That still
        # shouldn't be an error, because 'extra' is only used if we really do
        # insert. Also drop the 'returning' argument.
        _, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 20})
        self.assertFalse(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )
        # Repeat the operation with the correct value in 'compared' instead of
        # 'extra'.
        _, inserted = db.sync(tables.b, keys={"name": "b1"}, compared={"value": 10})
        self.assertFalse(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )
        # Repeat the operation with an incorrect value in 'compared'; this
        # should raise.
        with self.assertRaises(DatabaseConflictError):
            db.sync(tables.b, keys={"name": "b1"}, compared={"value": 20})
        # Try to sync in a read-only database. This should work if and only
        # if the matching row already exists.
        with self.asReadOnly(db) as rodb:
            with rodb.declareStaticTables(create=False) as context:
                tables = context.addTableTuple(STATIC_TABLE_SPECS)
            _, inserted = rodb.sync(tables.b, keys={"name": "b1"})
            self.assertFalse(inserted)
            self.assertEqual(
                [{"id": values["id"], "name": "b1", "value": 10}],
                [r._asdict() for r in self.query_list(rodb, tables.b.select())],
            )
            with self.assertRaises(ReadOnlyDatabaseError):
                rodb.sync(tables.b, keys={"name": "b2"}, extra={"value": 20})
        # Repeat the operation with a different value in 'compared' and ask to
        # update.
        _, updated = db.sync(tables.b, keys={"name": "b1"}, compared={"value": 20}, update=True)
        # The second element reports the previous values of the updated
        # fields.
        self.assertEqual(updated, {"value": 10})
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 20}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )

585 

586 def testReplace(self): 

587 """Tests for `Database.replace`.""" 

588 db = self.makeEmptyDatabase(origin=1) 

589 with db.declareStaticTables(create=True) as context: 

590 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

591 # Use 'replace' to insert a single row that contains a region and 

592 # query to get it back. 

593 region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1))) 

594 row1 = {"name": "a1", "region": region} 

595 db.replace(tables.a, row1) 

596 self.assertEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1]) 

597 # Insert another row without a region. 

598 row2 = {"name": "a2", "region": None} 

599 db.replace(tables.a, row2) 

600 self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2]) 

601 # Use replace to re-insert both of those rows again, which should do 

602 # nothing. 

603 db.replace(tables.a, row1, row2) 

604 self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2]) 

605 # Replace row1 with a row with no region, while reinserting row2. 

606 row1a = {"name": "a1", "region": None} 

607 db.replace(tables.a, row1a, row2) 

608 self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1a, row2]) 

609 # Replace both rows, returning row1 to its original state, while adding 

610 # a new one. Pass them in in a different order. 

611 row2a = {"name": "a2", "region": region} 

612 row3 = {"name": "a3", "region": None} 

613 db.replace(tables.a, row3, row2a, row1) 

614 self.assertCountEqual( 

615 [r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2a, row3] 

616 ) 

617 

    def testEnsure(self):
        """Tests for `Database.ensure`."""
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Use 'ensure' to insert a single row that contains a region and
        # query to get it back.  The return value is the number of rows
        # actually inserted.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row1 = {"name": "a1", "region": region}
        self.assertEqual(db.ensure(tables.a, row1), 1)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1])
        # Insert another row without a region.
        row2 = {"name": "a2", "region": None}
        self.assertEqual(db.ensure(tables.a, row2), 1)
        self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2])
        # Use ensure to re-insert both of those rows again, which should do
        # nothing.
        self.assertEqual(db.ensure(tables.a, row1, row2), 0)
        self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2])
        # Attempt to insert row1's key with no region, while
        # reinserting row2. This should also do nothing.
        row1a = {"name": "a1", "region": None}
        self.assertEqual(db.ensure(tables.a, row1a, row2), 0)
        self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2])
        # Attempt to insert new rows for both existing keys, this time also
        # adding a new row. Pass them in in a different order. Only the new
        # row should be added.
        row2a = {"name": "a2", "region": region}
        row3 = {"name": "a3", "region": None}
        self.assertEqual(db.ensure(tables.a, row3, row2a, row1a), 1)
        self.assertCountEqual(
            [r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2, row3]
        )
        # Add some data to a table with both a primary key and a different
        # unique constraint.
        row_b = {"id": 5, "name": "five", "value": 50}
        db.insert(tables.b, row_b)
        # Attempt ensure with primary_key_only=False and a conflict for the
        # non-PK constraint. This should do nothing.
        db.ensure(tables.b, {"id": 10, "name": "five", "value": 200})
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.b.select())], [row_b])
        # Now use primary_key_only=True with conflict in only the non-PK field.
        # This should be an integrity error and nothing should change.
        with self.assertRaises(sqlalchemy.exc.IntegrityError):
            db.ensure(tables.b, {"id": 10, "name": "five", "value": 200}, primary_key_only=True)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.b.select())], [row_b])
        # With primary_key_only=True a conflict in the primary key is ignored
        # regardless of whether there is a conflict elsewhere.
        db.ensure(tables.b, {"id": 5, "name": "ten", "value": 100}, primary_key_only=True)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.b.select())], [row_b])
        db.ensure(tables.b, {"id": 5, "name": "five", "value": 100}, primary_key_only=True)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.b.select())], [row_b])

670 

    def testTransactionNesting(self):
        """Test that transactions can be nested with the behavior in the
        presence of exceptions working as documented.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert one row so we can trigger integrity errors by trying to insert
        # a duplicate of it below.
        db.insert(tables.a, {"name": "a1"})
        # First test: error recovery via explicit savepoint=True in the inner
        # transaction.
        with db.transaction():
            # This insert should succeed, and should not be rolled back because
            # the assertRaises context should catch any exception before it
            # propagates up to the outer transaction.
            db.insert(tables.a, {"name": "a2"})
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                with db.transaction(savepoint=True):
                    # This insert should succeed, but should be rolled back.
                    db.insert(tables.a, {"name": "a4"})
                    # This insert should fail (duplicate primary key), raising
                    # an exception.
                    db.insert(tables.a, {"name": "a1"})
        self.assertCountEqual(
            [r._asdict() for r in self.query_list(db, tables.a.select())],
            [{"name": "a1", "region": None}, {"name": "a2", "region": None}],
        )
        # Second test: error recovery via implicit savepoint=True, when the
        # innermost transaction is inside a savepoint=True transaction.
        with db.transaction():
            # This insert should succeed, and should not be rolled back
            # because the assertRaises context should catch any
            # exception before it propagates up to the outer
            # transaction.
            db.insert(tables.a, {"name": "a3"})
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                with db.transaction(savepoint=True):
                    # This insert should succeed, but should be rolled back.
                    db.insert(tables.a, {"name": "a4"})
                    with db.transaction():
                        # This insert should succeed, but should be rolled
                        # back.
                        db.insert(tables.a, {"name": "a5"})
                        # This insert should fail (duplicate primary key),
                        # raising an exception.
                        db.insert(tables.a, {"name": "a1"})
        self.assertCountEqual(
            [r._asdict() for r in self.query_list(db, tables.a.select())],
            [{"name": "a1", "region": None}, {"name": "a2", "region": None}, {"name": "a3", "region": None}],
        )

722 

    def testTransactionLocking(self):
        """Test that `Database.transaction` can be used to acquire a lock
        that prohibits concurrent writes.
        """
        db1 = self.makeEmptyDatabase(origin=1)
        with db1.declareStaticTables(create=True) as context:
            tables1 = context.addTableTuple(STATIC_TABLE_SPECS)

        async def side1(lock: Iterable[str] = ()) -> tuple[set[str], set[str]]:
            """One side of the concurrent locking test.

            This optionally locks the table (and maybe the whole database),
            does a select for its contents, inserts a new row, and then selects
            again, with some waiting in between to make sure the other side has
            a chance to _attempt_ to insert in between.  If the locking is
            enabled and works, the difference between the selects should just
            be the insert done on this thread.

            Parameters
            ----------
            lock : `~collections.abc.Iterable` [ `str` ]
                Tables to lock for the duration of the transaction; empty
                (the default) for no locking.

            Returns
            -------
            names : `tuple` [ `set` [ `str` ], `set` [ `str` ] ]
                The ``name`` values selected before and after this side's
                insert, respectively.
            """
            # Give Side2 a chance to create a connection
            await asyncio.sleep(1.0)
            with db1.transaction(lock=lock):
                names1 = {row.name for row in self.query_list(db1, tables1.a.select())}
                # Give Side2 a chance to insert (which will be blocked if
                # we've acquired a lock).
                await asyncio.sleep(2.0)
                db1.insert(tables1.a, {"name": "a1"})
                names2 = {row.name for row in self.query_list(db1, tables1.a.select())}
            return names1, names2

        async def side2() -> None:
            """Other side of the concurrent locking test.

            This side just waits a bit and then tries to insert a row into the
            table that the other side is trying to lock.  Hopefully that
            waiting is enough to give the other side a chance to acquire the
            lock and thus make this side block until the lock is released.  If
            this side manages to do the insert before side1 acquires the lock,
            we'll just warn about not succeeding at testing the locking,
            because we can only make that unlikely, not impossible.
            """

            def toRunInThread():
                """Create new SQLite connection for use in thread.

                SQLite locking isn't asyncio-friendly unless we actually
                run it in another thread.  And SQLite gets very unhappy if
                we try to use a connection from multiple threads, so we have
                to create the new connection here instead of out in the main
                body of the test function.
                """
                db2 = self.getNewConnection(db1, writeable=True)
                with db2.declareStaticTables(create=False) as context:
                    tables2 = context.addTableTuple(STATIC_TABLE_SPECS)
                with db2.transaction():
                    db2.insert(tables2.a, {"name": "a2"})

            await asyncio.sleep(2.0)
            loop = asyncio.get_running_loop()
            # Run the blocking insert in a worker thread so this coroutine
            # does not block the event loop while waiting on the lock.
            with ThreadPoolExecutor() as pool:
                await loop.run_in_executor(pool, toRunInThread)

        async def testProblemsWithNoLocking() -> None:
            """Run side1 and side2 with no locking, attempting to demonstrate
            the problem that locking is supposed to solve.  If we get unlucky
            with scheduling, side2 will just happen to insert after side1 is
            done, and we won't have anything definitive.  We just warn in that
            case because we really don't want spurious test failures.
            """
            task1 = asyncio.create_task(side1())
            task2 = asyncio.create_task(side2())

            names1, names2 = await task1
            await task2
            if "a2" in names1:
                warnings.warn(
                    "Unlucky scheduling in no-locking test: concurrent INSERT happened before first SELECT.",
                    stacklevel=1,
                )
                self.assertEqual(names1, {"a2"})
                self.assertEqual(names2, {"a1", "a2"})
            elif "a2" not in names2:
                warnings.warn(
                    "Unlucky scheduling in no-locking test: concurrent INSERT "
                    "happened after second SELECT even without locking.",
                    stacklevel=1,
                )
                self.assertEqual(names1, set())
                self.assertEqual(names2, {"a1"})
            else:
                # This is the expected case: both INSERTS happen between the
                # two SELECTS.  If we don't get this almost all of the time we
                # should adjust the sleep amounts.
                self.assertEqual(names1, set())
                self.assertEqual(names2, {"a1", "a2"})

        asyncio.run(testProblemsWithNoLocking())

        # Clean up after first test.
        db1.delete(tables1.a, ["name"], {"name": "a1"}, {"name": "a2"})

        async def testSolutionWithLocking() -> None:
            """Run side1 and side2 with locking, which should make side2 block
            its insert until side1 releases its lock.
            """
            task1 = asyncio.create_task(side1(lock=[tables1.a]))
            task2 = asyncio.create_task(side2())

            names1, names2 = await task1
            await task2
            if "a2" in names1:
                warnings.warn(
                    "Unlucky scheduling in locking test: concurrent INSERT happened before first SELECT.",
                    stacklevel=1,
                )
                self.assertEqual(names1, {"a2"})
                self.assertEqual(names2, {"a1", "a2"})
            else:
                # This is the expected case: the side2 INSERT happens after the
                # last SELECT on side1.  This can also happen due to unlucky
                # scheduling, and we have no way to detect that here, but the
                # similar "no-locking" test has at least some chance of being
                # affected by the same problem and warning about it.
                self.assertEqual(names1, set())
                self.assertEqual(names2, {"a1"})

        asyncio.run(testSolutionWithLocking())

849 

    def testTimespanDatabaseRepresentation(self):
        """Tests for `TimespanDatabaseRepresentation` and the `Database`
        methods that interact with it.
        """
        # Make some test timespans to play with, with the full suite of
        # topological relationships.
        start = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
        offset = astropy.time.TimeDelta(60, format="sec")
        timestamps = [start + offset * n for n in range(3)]
        aTimespans = [Timespan(begin=None, end=None)]
        aTimespans.extend(Timespan(begin=None, end=t) for t in timestamps)
        aTimespans.extend(Timespan(begin=t, end=None) for t in timestamps)
        aTimespans.extend(Timespan.fromInstant(t) for t in timestamps)
        aTimespans.append(Timespan.makeEmpty())
        aTimespans.extend(Timespan(begin=t1, end=t2) for t1, t2 in itertools.combinations(timestamps, 2))
        # Make another list of timespans that span the full range but don't
        # overlap.  This is a subset of the previous list.
        bTimespans = [Timespan(begin=None, end=timestamps[0])]
        bTimespans.extend(
            Timespan(begin=t1, end=t2) for t1, t2 in zip(timestamps[:-1], timestamps[1:], strict=True)
        )
        bTimespans.append(Timespan(begin=timestamps[-1], end=None))
        # Make a database and create a table with that database's timespan
        # representation.  This one will have no exclusion constraint and
        # a nullable timespan.
        db = self.makeEmptyDatabase(origin=1)
        TimespanReprClass = db.getTimespanRepresentation()
        aSpec = ddl.TableSpec(
            fields=[
                ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
            ],
        )
        for fieldSpec in TimespanReprClass.makeFieldSpecs(nullable=True):
            aSpec.fields.add(fieldSpec)
        with db.declareStaticTables(create=True) as context:
            aTable = context.addTable("a", aSpec)
        self.maxDiff = None

        def convertRowForInsert(row: dict) -> dict:
            """Convert a row containing a Timespan instance into one suitable
            for insertion into the database.

            Parameters
            ----------
            row : `dict`
                Row with a `Timespan` (or `None`) under the representation's
                composite-field name.

            Returns
            -------
            row : `dict`
                Row with the timespan expanded into its database columns.
            """
            result = row.copy()
            ts = result.pop(TimespanReprClass.NAME)
            return TimespanReprClass.update(ts, result=result)

        def convertRowFromSelect(row: dict) -> dict:
            """Convert a row from the database into one containing a Timespan.

            Parameters
            ----------
            row : `dict`
                Original row.

            Returns
            -------
            row : `dict`
                The updated row.
            """
            result = row.copy()
            timespan = TimespanReprClass.extract(result)
            # Drop the raw per-column values now folded into the timespan.
            for name in TimespanReprClass.getFieldNames():
                del result[name]
            result[TimespanReprClass.NAME] = timespan
            return result

        # Insert rows into table A, in chunks just to make things interesting.
        # Include one with a NULL timespan.
        aRows = [{"id": n, TimespanReprClass.NAME: t} for n, t in enumerate(aTimespans)]
        aRows.append({"id": len(aRows), TimespanReprClass.NAME: None})
        db.insert(aTable, convertRowForInsert(aRows[0]))
        db.insert(aTable, *[convertRowForInsert(r) for r in aRows[1:3]])
        db.insert(aTable, *[convertRowForInsert(r) for r in aRows[3:]])
        # Add another one with a NULL timespan, but this time by invoking
        # the server-side default.
        aRows.append({"id": len(aRows)})
        db.insert(aTable, aRows[-1])
        aRows[-1][TimespanReprClass.NAME] = None
        # Test basic round-trip through database.
        self.assertEqual(
            aRows,
            [
                convertRowFromSelect(row._asdict())
                for row in self.query_list(db, aTable.select().order_by(aTable.columns.id))
            ],
        )
        # Create another table B with a not-null timespan and (if the database
        # supports it), an exclusion constraint.  Use ensureTableExists this
        # time to check that mode of table creation vs. timespans.
        bSpec = ddl.TableSpec(
            fields=[
                ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
                ddl.FieldSpec(name="key", dtype=sqlalchemy.Integer, nullable=False),
            ],
        )
        for fieldSpec in TimespanReprClass.makeFieldSpecs(nullable=False):
            bSpec.fields.add(fieldSpec)
        if TimespanReprClass.hasExclusionConstraint():
            bSpec.exclusion.add(("key", TimespanReprClass))
        bTable = db.ensureTableExists("b", bSpec)
        # Insert rows into table B, again in chunks.  Each Timespan appears
        # twice, but with different values for the 'key' field (which should
        # still be okay for any exclusion constraint we may have defined).
        bRows = [{"id": n, "key": 1, TimespanReprClass.NAME: t} for n, t in enumerate(bTimespans)]
        offset = len(bRows)
        bRows.extend(
            {"id": n + offset, "key": 2, TimespanReprClass.NAME: t} for n, t in enumerate(bTimespans)
        )
        db.insert(bTable, *[convertRowForInsert(r) for r in bRows[:2]])
        db.insert(bTable, convertRowForInsert(bRows[2]))
        db.insert(bTable, *[convertRowForInsert(r) for r in bRows[3:]])
        # Insert a row with no timespan into table B.  This should invoke the
        # server-side default, which is a timespan over (-∞, ∞).  We set
        # key=3 to avoid upsetting an exclusion constraint that might exist.
        bRows.append({"id": len(bRows), "key": 3})
        db.insert(bTable, bRows[-1])
        bRows[-1][TimespanReprClass.NAME] = Timespan(None, None)
        # Test basic round-trip through database.
        self.assertEqual(
            bRows,
            [
                convertRowFromSelect(row._asdict())
                for row in self.query_list(db, bTable.select().order_by(bTable.columns.id))
            ],
        )
        # Test that we can't insert timespan=None into this table.
        with self.assertRaises(sqlalchemy.exc.IntegrityError):
            db.insert(bTable, convertRowForInsert({"id": len(bRows), "key": 4, TimespanReprClass.NAME: None}))
        # IFF this database supports exclusion constraints, test that they
        # also prevent inserts.
        if TimespanReprClass.hasExclusionConstraint():
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                db.insert(
                    bTable,
                    convertRowForInsert(
                        {"id": len(bRows), "key": 1, TimespanReprClass.NAME: Timespan(None, timestamps[1])}
                    ),
                )
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                db.insert(
                    bTable,
                    convertRowForInsert(
                        {
                            "id": len(bRows),
                            "key": 1,
                            TimespanReprClass.NAME: Timespan(timestamps[0], timestamps[2]),
                        }
                    ),
                )
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                db.insert(
                    bTable,
                    convertRowForInsert(
                        {"id": len(bRows), "key": 1, TimespanReprClass.NAME: Timespan(timestamps[2], None)}
                    ),
                )
        # Test NULL checks in SELECT queries, on both tables.
        aRepr = TimespanReprClass.from_columns(aTable.columns)
        self.assertEqual(
            [row[TimespanReprClass.NAME] is None for row in aRows],
            [
                row.f
                for row in self.query_list(
                    db, sqlalchemy.sql.select(aRepr.isNull().label("f")).order_by(aTable.columns.id)
                )
            ],
        )
        bRepr = TimespanReprClass.from_columns(bTable.columns)
        self.assertEqual(
            [False for row in bRows],
            [
                row.f
                for row in self.query_list(
                    db, sqlalchemy.sql.select(bRepr.isNull().label("f")).order_by(bTable.columns.id)
                )
            ],
        )
        # Test relationships expressions that relate in-database timespans to
        # Python-literal timespans, all from the more complete 'a' set; check
        # that this is consistent with Python-only relationship tests.
        for rhsRow in aRows:
            if rhsRow[TimespanReprClass.NAME] is None:
                continue
            with self.subTest(rhsRow=rhsRow):
                expected = {}
                for lhsRow in aRows:
                    if lhsRow[TimespanReprClass.NAME] is None:
                        # NULL timespans compare as NULL (None) everywhere.
                        expected[lhsRow["id"]] = (None, None, None, None)
                    else:
                        expected[lhsRow["id"]] = (
                            lhsRow[TimespanReprClass.NAME].overlaps(rhsRow[TimespanReprClass.NAME]),
                            lhsRow[TimespanReprClass.NAME].contains(rhsRow[TimespanReprClass.NAME]),
                            lhsRow[TimespanReprClass.NAME] < rhsRow[TimespanReprClass.NAME],
                            lhsRow[TimespanReprClass.NAME] > rhsRow[TimespanReprClass.NAME],
                        )
                rhsRepr = TimespanReprClass.fromLiteral(rhsRow[TimespanReprClass.NAME])
                sql = sqlalchemy.sql.select(
                    aTable.columns.id.label("lhs"),
                    aRepr.overlaps(rhsRepr).label("overlaps"),
                    aRepr.contains(rhsRepr).label("contains"),
                    (aRepr < rhsRepr).label("less_than"),
                    (aRepr > rhsRepr).label("greater_than"),
                ).select_from(aTable)
                queried = {
                    row.lhs: (row.overlaps, row.contains, row.less_than, row.greater_than)
                    for row in self.query_list(db, sql)
                }
                self.assertEqual(expected, queried)
        # Test relationship expressions that relate in-database timespans to
        # each other, all from the more complete 'a' set; check that this is
        # consistent with Python-only relationship tests.
        expected = {}
        for lhs, rhs in itertools.product(aRows, aRows):
            lhsT = lhs[TimespanReprClass.NAME]
            rhsT = rhs[TimespanReprClass.NAME]
            if lhsT is not None and rhsT is not None:
                expected[lhs["id"], rhs["id"]] = (
                    lhsT.overlaps(rhsT),
                    lhsT.contains(rhsT),
                    lhsT < rhsT,
                    lhsT > rhsT,
                )
            else:
                expected[lhs["id"], rhs["id"]] = (None, None, None, None)
        # Self-join table A against itself via a cross join (ON TRUE) to get
        # every (lhs, rhs) pair in one query.
        lhsSubquery = aTable.alias("lhs")
        rhsSubquery = aTable.alias("rhs")
        lhsRepr = TimespanReprClass.from_columns(lhsSubquery.columns)
        rhsRepr = TimespanReprClass.from_columns(rhsSubquery.columns)
        sql = sqlalchemy.sql.select(
            lhsSubquery.columns.id.label("lhs"),
            rhsSubquery.columns.id.label("rhs"),
            lhsRepr.overlaps(rhsRepr).label("overlaps"),
            lhsRepr.contains(rhsRepr).label("contains"),
            (lhsRepr < rhsRepr).label("less_than"),
            (lhsRepr > rhsRepr).label("greater_than"),
        ).select_from(lhsSubquery.join(rhsSubquery, onclause=sqlalchemy.sql.literal(True)))
        queried = {
            (row.lhs, row.rhs): (row.overlaps, row.contains, row.less_than, row.greater_than)
            for row in self.query_list(db, sql)
        }
        self.assertEqual(expected, queried)
        # Test relationship expressions between in-database timespans and
        # Python-literal instantaneous times.
        for t in timestamps:
            with self.subTest(t=t):
                expected = {}
                for lhsRow in aRows:
                    if lhsRow[TimespanReprClass.NAME] is None:
                        expected[lhsRow["id"]] = (None, None, None, None)
                    else:
                        expected[lhsRow["id"]] = (
                            lhsRow[TimespanReprClass.NAME].contains(t),
                            lhsRow[TimespanReprClass.NAME].overlaps(t),
                            lhsRow[TimespanReprClass.NAME] < t,
                            lhsRow[TimespanReprClass.NAME] > t,
                        )
                rhs = sqlalchemy.sql.literal(t, type_=ddl.AstropyTimeNsecTai)
                sql = sqlalchemy.sql.select(
                    aTable.columns.id.label("lhs"),
                    aRepr.contains(rhs).label("contains"),
                    aRepr.overlaps(rhs).label("overlaps_point"),
                    (aRepr < rhs).label("less_than"),
                    (aRepr > rhs).label("greater_than"),
                ).select_from(aTable)
                queried = {
                    row.lhs: (row.contains, row.overlaps_point, row.less_than, row.greater_than)
                    for row in self.query_list(db, sql)
                }
                self.assertEqual(expected, queried)

1119 

1120 def testConstantRows(self): 

1121 """Test Database.constant_rows.""" 

1122 new_db = self.makeEmptyDatabase() 

1123 with new_db.declareStaticTables(create=True) as context: 

1124 static = context.addTableTuple(STATIC_TABLE_SPECS) 

1125 b_ids = new_db.insert( 

1126 static.b, 

1127 {"name": "b1", "value": 11}, 

1128 {"name": "b2", "value": 12}, 

1129 {"name": "b3", "value": 13}, 

1130 returnIds=True, 

1131 ) 

1132 values_spec = ddl.TableSpec( 

1133 [ 

1134 ddl.FieldSpec(name="b", dtype=sqlalchemy.BigInteger), 

1135 ddl.FieldSpec(name="s", dtype=sqlalchemy.String(8)), 

1136 ddl.FieldSpec(name="r", dtype=ddl.Base64Region()), 

1137 ], 

1138 ) 

1139 values_data = [ 

1140 {"b": b_ids[0], "s": "b1", "r": None}, 

1141 {"b": b_ids[2], "s": "b3", "r": Circle.empty()}, 

1142 ] 

1143 values = new_db.constant_rows(values_spec.fields, *values_data) 

1144 select_values_alone = sqlalchemy.sql.select( 

1145 values.columns["b"], values.columns["s"], values.columns["r"] 

1146 ) 

1147 self.assertCountEqual( 

1148 [row._mapping for row in self.query_list(new_db, select_values_alone)], 

1149 values_data, 

1150 ) 

1151 select_values_joined = sqlalchemy.sql.select( 

1152 values.columns["s"].label("name"), static.b.columns["value"].label("value") 

1153 ).select_from(values.join(static.b, onclause=static.b.columns["id"] == values.columns["b"])) 

1154 self.assertCountEqual( 

1155 [row._mapping for row in self.query_list(new_db, select_values_joined)], 

1156 [{"value": 11, "name": "b1"}, {"value": 13, "name": "b3"}], 

1157 )