Coverage for python/lsst/daf/butler/registry/tests/_database.py: 6%

493 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-02-08 10:28 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ["DatabaseTests"] 

24 

import asyncio
import itertools
import warnings
from abc import ABC, abstractmethod
from collections import namedtuple
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from typing import Any, ContextManager, Iterable, Iterator, Optional, Set, Tuple

import astropy.time
import sqlalchemy
from lsst.sphgeom import Circle, ConvexPolygon, UnitVector3d

from ...core import Timespan, ddl
from ..interfaces import Database, DatabaseConflictError, ReadOnlyDatabaseError, SchemaAlreadyDefinedError

40 

# Tuple type holding the specs (and later the created tables) of the static
# test schema; the field names ("a", "b", "c") double as the table names.
StaticTablesTuple = namedtuple("StaticTablesTuple", ["a", "b", "c"])

# Static test schema:
#  - "a": string primary key plus an optional Base64Region column;
#  - "b": autoincrement primary key, a unique "name", and a nullable value;
#  - "c": autoincrement primary key and a nullable foreign key into "b"
#    with SET NULL delete propagation.
STATIC_TABLE_SPECS = StaticTablesTuple(
    a=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("name", dtype=sqlalchemy.String, length=16, primaryKey=True),
            ddl.FieldSpec("region", dtype=ddl.Base64Region, nbytes=128),
        ]
    ),
    b=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("id", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True),
            ddl.FieldSpec("name", dtype=sqlalchemy.String, length=16, nullable=False),
            ddl.FieldSpec("value", dtype=sqlalchemy.SmallInteger, nullable=True),
        ],
        unique=[("name",)],
    ),
    c=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("id", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True),
            ddl.FieldSpec("b_id", dtype=sqlalchemy.BigInteger, nullable=True),
        ],
        foreignKeys=[
            ddl.ForeignKeySpec("b", source=("b_id",), target=("id",), onDelete="SET NULL"),
        ],
    ),
)

# Spec for a table created on demand (via ensureTableExists) with CASCADE
# foreign keys into both "c" and "a".
DYNAMIC_TABLE_SPEC = ddl.TableSpec(
    fields=[
        ddl.FieldSpec("c_id", dtype=sqlalchemy.BigInteger, primaryKey=True),
        ddl.FieldSpec("a_name", dtype=sqlalchemy.String, length=16, nullable=False),
    ],
    foreignKeys=[
        ddl.ForeignKeySpec("c", source=("c_id",), target=("id",), onDelete="CASCADE"),
        ddl.ForeignKeySpec("a", source=("a_name",), target=("name",), onDelete="CASCADE"),
    ],
)

# Spec for temporary tables: a compound primary key that mirrors rows in
# "a" and "b", with no formal foreign-key constraints.
TEMPORARY_TABLE_SPEC = ddl.TableSpec(
    fields=[
        ddl.FieldSpec("a_name", dtype=sqlalchemy.String, length=16, primaryKey=True),
        ddl.FieldSpec("b_id", dtype=sqlalchemy.BigInteger, primaryKey=True),
    ],
)

86 

87 

88@contextmanager 

89def _patch_getExistingTable(db: Database) -> Database: 

90 """Patch getExistingTable method in a database instance to test concurrent 

91 creation of tables. This patch obviously depends on knowning internals of 

92 ``ensureTableExists()`` implementation. 

93 """ 

94 original_method = db.getExistingTable 

95 

96 def _getExistingTable(name: str, spec: ddl.TableSpec) -> Optional[sqlalchemy.schema.Table]: 

97 # Return None on first call, but forward to original method after that 

98 db.getExistingTable = original_method 

99 return None 

100 

101 db.getExistingTable = _getExistingTable 

102 yield db 

103 db.getExistingTable = original_method 

104 

105 

106class DatabaseTests(ABC): 

107 """Generic tests for the `Database` interface that can be subclassed to 

108 generate tests for concrete implementations. 

109 """ 

110 

    @abstractmethod
    def makeEmptyDatabase(self, origin: int = 0) -> Database:
        """Return an empty `Database` with the given origin, or an
        automatically-generated one if ``origin`` is `None`.

        Parameters
        ----------
        origin : `int`, optional
            Origin identifier for the new database.

        Returns
        -------
        db : `Database`
            A new, empty database instance.
        """
        raise NotImplementedError()

117 

    @abstractmethod
    def asReadOnly(self, database: Database) -> ContextManager[Database]:
        """Return a context manager for a read-only connection into the given
        database.

        The original database should be considered unusable within the context
        but safe to use again afterwards (this allows the context manager to
        block write access by temporarily changing user permissions to really
        guarantee that write operations are not performed).

        Parameters
        ----------
        database : `Database`
            Database to make a read-only view of.

        Returns
        -------
        context : `ContextManager` [ `Database` ]
            Context manager yielding the read-only connection.
        """
        raise NotImplementedError()

129 

    @abstractmethod
    def getNewConnection(self, database: Database, *, writeable: bool) -> Database:
        """Return a new `Database` instance that points to the same underlying
        storage as the given one.

        Parameters
        ----------
        database : `Database`
            Existing database whose storage the new connection shares.
        writeable : `bool`
            Whether the new connection permits writes.

        Returns
        -------
        db : `Database`
            New connection to the same storage.
        """
        raise NotImplementedError()

136 

137 def query_list( 

138 self, database: Database, executable: sqlalchemy.sql.expression.SelectBase 

139 ) -> list[sqlalchemy.engine.Row]: 

140 """Run a SELECT or other read-only query against the database and 

141 return the results as a list. 

142 

143 This is a thin wrapper around database.query() that just avoids 

144 context-manager boilerplate that is usefully verbose in production code 

145 but just noise in tests. 

146 """ 

147 with database.query(executable) as result: 

148 return result.fetchall() 

149 

150 def query_scalar(self, database: Database, executable: sqlalchemy.sql.expression.SelectBase) -> Any: 

151 """Run a SELECT query that yields a single column and row against the 

152 database and return its value. 

153 

154 This is a thin wrapper around database.query() that just avoids 

155 context-manager boilerplate that is usefully verbose in production code 

156 but just noise in tests. 

157 """ 

158 with database.query(executable) as result: 

159 return result.scalar() 

160 

161 def checkTable(self, spec: ddl.TableSpec, table: sqlalchemy.schema.Table): 

162 self.assertCountEqual(spec.fields.names, table.columns.keys()) 

163 # Checking more than this currently seems fragile, as it might restrict 

164 # what Database implementations do; we don't care if the spec is 

165 # actually preserved in terms of types and constraints as long as we 

166 # can use the returned table as if it was. 

167 

168 def checkStaticSchema(self, tables: StaticTablesTuple): 

169 self.checkTable(STATIC_TABLE_SPECS.a, tables.a) 

170 self.checkTable(STATIC_TABLE_SPECS.b, tables.b) 

171 self.checkTable(STATIC_TABLE_SPECS.c, tables.c) 

172 

173 def testDeclareStaticTables(self): 

174 """Tests for `Database.declareStaticSchema` and the methods it 

175 delegates to. 

176 """ 

177 # Create the static schema in a new, empty database. 

178 newDatabase = self.makeEmptyDatabase() 

179 with newDatabase.declareStaticTables(create=True) as context: 

180 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

181 self.checkStaticSchema(tables) 

182 # Check that we can load that schema even from a read-only connection. 

183 with self.asReadOnly(newDatabase) as existingReadOnlyDatabase: 

184 with existingReadOnlyDatabase.declareStaticTables(create=False) as context: 

185 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

186 self.checkStaticSchema(tables) 

187 

188 def testDeclareStaticTablesTwice(self): 

189 """Tests for `Database.declareStaticSchema` being called twice.""" 

190 # Create the static schema in a new, empty database. 

191 newDatabase = self.makeEmptyDatabase() 

192 with newDatabase.declareStaticTables(create=True) as context: 

193 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

194 self.checkStaticSchema(tables) 

195 # Second time it should raise 

196 with self.assertRaises(SchemaAlreadyDefinedError): 

197 with newDatabase.declareStaticTables(create=True) as context: 

198 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

199 # Check schema, it should still contain all tables, and maybe some 

200 # extra. 

201 with newDatabase.declareStaticTables(create=False) as context: 

202 self.assertLessEqual(frozenset(STATIC_TABLE_SPECS._fields), context._tableNames) 

203 

204 def testRepr(self): 

205 """Test that repr does not return a generic thing.""" 

206 newDatabase = self.makeEmptyDatabase() 

207 rep = repr(newDatabase) 

208 # Check that stringification works and gives us something different 

209 self.assertNotEqual(rep, str(newDatabase)) 

210 self.assertNotIn("object at 0x", rep, "Check default repr was not used") 

211 self.assertIn("://", rep) 

212 

    def testDynamicTables(self) -> None:
        """Tests for `Database.ensureTableExists` and
        `Database.getExistingTable`.
        """
        # Need to start with the static schema.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            context.addTableTuple(STATIC_TABLE_SPECS)
        # Try to ensure the dynamic table exists in a read-only version of
        # that database, which should fail because we can't create it.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                context.addTableTuple(STATIC_TABLE_SPECS)
            with self.assertRaises(ReadOnlyDatabaseError):
                existingReadOnlyDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        # Just getting the dynamic table before it exists should return None.
        self.assertIsNone(newDatabase.getExistingTable("d", DYNAMIC_TABLE_SPEC))
        # Ensure the new table exists back in the original database, which
        # should create it.
        table = newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        self.checkTable(DYNAMIC_TABLE_SPEC, table)
        # Ensuring that it exists should just return the exact same table
        # instance again.
        self.assertIs(newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC), table)
        # Try again from the read-only database.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                context.addTableTuple(STATIC_TABLE_SPECS)
            # Just getting the dynamic table should now work...
            self.assertIsNotNone(existingReadOnlyDatabase.getExistingTable("d", DYNAMIC_TABLE_SPEC))
            # ...as should ensuring that it exists, since it now does.
            existingReadOnlyDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
            self.checkTable(DYNAMIC_TABLE_SPEC, table)
        # Trying to get the table with a different specification (at least
        # in terms of what columns are present) should raise.
        with self.assertRaises(DatabaseConflictError):
            newDatabase.ensureTableExists(
                "d",
                ddl.TableSpec(
                    fields=[ddl.FieldSpec("name", dtype=sqlalchemy.String, length=4, primaryKey=True)]
                ),
            )
        # Calling ensureTableExists inside a transaction block is an error,
        # even if it would do nothing.
        with newDatabase.transaction():
            with self.assertRaises(AssertionError):
                newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)

260 

    def testDynamicTablesConcurrency(self) -> None:
        """Tests for `Database.ensureTableExists` concurrent use.

        This does not run anything truly concurrently; it uses
        `_patch_getExistingTable` to simulate a race in which another
        process creates the table between the existence check and the
        CREATE TABLE statement.
        """
        # We cannot really run things concurrently in a deterministic way,
        # here we just simulate a situation when the table is created by
        # other process between the call to getExistingTable() and actual
        # table creation.
        db1 = self.makeEmptyDatabase()
        with db1.declareStaticTables(create=True) as context:
            context.addTableTuple(STATIC_TABLE_SPECS)
        self.assertIsNone(db1.getExistingTable("d", DYNAMIC_TABLE_SPEC))

        # Make a dynamic table using a separate connection.
        db2 = self.getNewConnection(db1, writeable=True)
        with db2.declareStaticTables(create=False) as context:
            context.addTableTuple(STATIC_TABLE_SPECS)
        table = db2.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        self.checkTable(DYNAMIC_TABLE_SPEC, table)

        # Call it again but trick it into thinking that table is not there.
        # This test depends on knowing implementation of ensureTableExists()
        # which initially calls getExistingTable() to check that table may
        # exist, the patch intercepts that call and returns None.  The call
        # must still succeed (not raise) despite the apparent race.
        with _patch_getExistingTable(db1):
            table = db1.ensureTableExists("d", DYNAMIC_TABLE_SPEC)

285 

    def testTemporaryTables(self) -> None:
        """Tests for `Database.temporary_table`, and `Database.insert` with
        the ``select`` argument.
        """
        # Need to start with the static schema; also insert some test data.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            static = context.addTableTuple(STATIC_TABLE_SPECS)
        newDatabase.insert(static.a, {"name": "a1", "region": None}, {"name": "a2", "region": None})
        bIds = newDatabase.insert(
            static.b,
            {"name": "b1", "value": 11},
            {"name": "b2", "value": 12},
            {"name": "b3", "value": 13},
            returnIds=True,
        )
        # Create the table.
        with newDatabase.session():
            with newDatabase.temporary_table(TEMPORARY_TABLE_SPEC, "e1") as table1:
                self.checkTable(TEMPORARY_TABLE_SPEC, table1)
                # Insert via a INSERT INTO ... SELECT query.  The
                # literal(True) join condition makes this a full cross join
                # of "a" and "b", filtered by the WHERE clause.
                newDatabase.insert(
                    table1,
                    select=sqlalchemy.sql.select(
                        static.a.columns.name.label("a_name"), static.b.columns.id.label("b_id")
                    )
                    .select_from(static.a.join(static.b, onclause=sqlalchemy.sql.literal(True)))
                    .where(
                        sqlalchemy.sql.and_(
                            static.a.columns.name == "a1",
                            static.b.columns.value <= 12,
                        )
                    ),
                )
                # Check that the inserted rows are present.
                self.assertCountEqual(
                    [{"a_name": "a1", "b_id": bId} for bId in bIds[:2]],
                    [row._asdict() for row in self.query_list(newDatabase, table1.select())],
                )
                # Create another one via a read-only connection to the
                # database. We _do_ allow temporary table modifications in
                # read-only databases.
                with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
                    with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                        context.addTableTuple(STATIC_TABLE_SPECS)
                    with existingReadOnlyDatabase.temporary_table(TEMPORARY_TABLE_SPEC) as table2:
                        self.checkTable(TEMPORARY_TABLE_SPEC, table2)
                        # Those tables should not be the same, despite having
                        # the same ddl.
                        self.assertIsNot(table1, table2)
                        # Do a slightly different insert into this table, to
                        # check that it works in a read-only database. This
                        # time we pass column names as a kwarg to insert
                        # instead of by labeling the columns in the select.
                        existingReadOnlyDatabase.insert(
                            table2,
                            select=sqlalchemy.sql.select(static.a.columns.name, static.b.columns.id)
                            .select_from(static.a.join(static.b, onclause=sqlalchemy.sql.literal(True)))
                            .where(
                                sqlalchemy.sql.and_(
                                    static.a.columns.name == "a2",
                                    static.b.columns.value >= 12,
                                )
                            ),
                            names=["a_name", "b_id"],
                        )
                        # Check that the inserted rows are present.
                        self.assertCountEqual(
                            [{"a_name": "a2", "b_id": bId} for bId in bIds[1:]],
                            [row._asdict() for row in self.query_list(existingReadOnlyDatabase, table2.select())],
                        )
        # Exiting the context managers will drop the temporary tables from
        # the read-only DB. It's unspecified whether attempting to use it
        # after this point is an error or just never returns any results,
        # so we can't test what it does, only that it's not an error.

361 

362 def testSchemaSeparation(self): 

363 """Test that creating two different `Database` instances allows us 

364 to create different tables with the same name in each. 

365 """ 

366 db1 = self.makeEmptyDatabase(origin=1) 

367 with db1.declareStaticTables(create=True) as context: 

368 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

369 self.checkStaticSchema(tables) 

370 

371 db2 = self.makeEmptyDatabase(origin=2) 

372 # Make the DDL here intentionally different so we'll definitely 

373 # notice if db1 and db2 are pointing at the same schema. 

374 spec = ddl.TableSpec(fields=[ddl.FieldSpec("id", dtype=sqlalchemy.Integer, primaryKey=True)]) 

375 with db2.declareStaticTables(create=True) as context: 

376 # Make the DDL here intentionally different so we'll definitely 

377 # notice if db1 and db2 are pointing at the same schema. 

378 table = context.addTable("a", spec) 

379 self.checkTable(spec, table) 

380 

    def testInsertQueryDelete(self) -> None:
        """Test the `Database.insert`, `Database.query`, and
        `Database.delete` methods, as well as the `Base64Region` type and
        the ``onDelete`` argument to `ddl.ForeignKeySpec`.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert a single, non-autoincrement row that contains a region and
        # query to get it back.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row = {"name": "a1", "region": region}
        db.insert(tables.a, row)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row])
        # Insert multiple autoincrement rows but do not try to get the IDs
        # back immediately.
        db.insert(tables.b, {"name": "b1", "value": 10}, {"name": "b2", "value": 20})
        results = [r._asdict() for r in self.query_list(db, tables.b.select().order_by("id"))]
        self.assertEqual(len(results), 2)
        for row in results:
            self.assertIn(row["name"], ("b1", "b2"))
            self.assertIsInstance(row["id"], int)
        self.assertGreater(results[1]["id"], results[0]["id"])
        # Insert multiple autoincrement rows and get the IDs back from
        # insert.
        rows = [{"name": "b3", "value": 30}, {"name": "b4", "value": 40}]
        ids = db.insert(tables.b, *rows, returnIds=True)
        results = [
            r._asdict()
            for r in self.query_list(db, tables.b.select().where(tables.b.columns.id > results[1]["id"]))
        ]
        expected = [dict(row, id=id) for row, id in zip(rows, ids)]
        self.assertCountEqual(results, expected)
        self.assertTrue(all(result["id"] is not None for result in results))
        # Insert multiple rows into a table with an autoincrement primary
        # key, then use the returned IDs to insert into a dynamic table.
        rows = [{"b_id": results[0]["id"]}, {"b_id": None}]
        ids = db.insert(tables.c, *rows, returnIds=True)
        results = [r._asdict() for r in self.query_list(db, tables.c.select())]
        expected = [dict(row, id=id) for row, id in zip(rows, ids)]
        self.assertCountEqual(results, expected)
        self.assertTrue(all(result["id"] is not None for result in results))
        # Add the dynamic table.
        d = db.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        # Insert into it, using the IDs returned for table c above.
        rows = [{"c_id": id, "a_name": "a1"} for id in ids]
        db.insert(d, *rows)
        results = [r._asdict() for r in self.query_list(db, d.select())]
        self.assertCountEqual(rows, results)
        # Insert multiple rows into a table with an autoincrement primary
        # key, but pass in a value for the autoincrement key.
        rows2 = [
            {"id": 700, "b_id": None},
            {"id": 701, "b_id": None},
        ]
        db.insert(tables.c, *rows2)
        results = [r._asdict() for r in self.query_list(db, tables.c.select())]
        self.assertCountEqual(results, expected + rows2)
        self.assertTrue(all(result["id"] is not None for result in results))

        # Define 'SELECT COUNT(*)' query for later use.
        count = sqlalchemy.sql.select(sqlalchemy.sql.func.count())
        # Get the values we inserted into table b.
        bValues = [r._asdict() for r in self.query_list(db, tables.b.select())]
        # Remove two rows from table b by ID.
        n = db.delete(tables.b, ["id"], {"id": bValues[0]["id"]}, {"id": bValues[1]["id"]})
        self.assertEqual(n, 2)
        # Remove the other two rows from table b by name.
        n = db.delete(tables.b, ["name"], {"name": bValues[2]["name"]}, {"name": bValues[3]["name"]})
        self.assertEqual(n, 2)
        # There should now be no rows in table b.
        self.assertEqual(self.query_scalar(db, count.select_from(tables.b)), 0)
        # All b_id values in table c should now be NULL, because there's an
        # onDelete='SET NULL' foreign key.
        self.assertEqual(
            self.query_scalar(
                db,
                count.select_from(tables.c).where(tables.c.columns.b_id != None),  # noqa:E711
            ),
            0,
        )
        # Remove all rows in table a (there's only one); this should remove
        # all rows in d due to onDelete='CASCADE'.
        n = db.delete(tables.a, [])
        self.assertEqual(n, 1)
        self.assertEqual(self.query_scalar(db, count.select_from(tables.a)), 0)
        self.assertEqual(self.query_scalar(db, count.select_from(d)), 0)

467 

468 def testDeleteWhere(self): 

469 """Tests for `Database.deleteWhere`.""" 

470 db = self.makeEmptyDatabase(origin=1) 

471 with db.declareStaticTables(create=True) as context: 

472 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

473 db.insert(tables.b, *[{"id": i, "name": f"b{i}"} for i in range(10)]) 

474 count = sqlalchemy.sql.select(sqlalchemy.sql.func.count()) 

475 

476 n = db.deleteWhere(tables.b, tables.b.columns.id.in_([0, 1, 2])) 

477 self.assertEqual(n, 3) 

478 self.assertEqual(self.query_scalar(db, count.select_from(tables.b)), 7) 

479 

480 n = db.deleteWhere( 

481 tables.b, 

482 tables.b.columns.id.in_( 

483 sqlalchemy.sql.select(tables.b.columns.id).where(tables.b.columns.id > 5) 

484 ), 

485 ) 

486 self.assertEqual(n, 4) 

487 self.assertEqual(self.query_scalar(db, count.select_from(tables.b)), 3) 

488 

489 n = db.deleteWhere(tables.b, tables.b.columns.name == "b5") 

490 self.assertEqual(n, 1) 

491 self.assertEqual(self.query_scalar(db, count.select_from(tables.b)), 2) 

492 

493 n = db.deleteWhere(tables.b, sqlalchemy.sql.literal(True)) 

494 self.assertEqual(n, 2) 

495 self.assertEqual(self.query_scalar(db, count.select_from(tables.b)), 0) 

496 

497 def testUpdate(self): 

498 """Tests for `Database.update`.""" 

499 db = self.makeEmptyDatabase(origin=1) 

500 with db.declareStaticTables(create=True) as context: 

501 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

502 # Insert two rows into table a, both without regions. 

503 db.insert(tables.a, {"name": "a1"}, {"name": "a2"}) 

504 # Update one of the rows with a region. 

505 region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1))) 

506 n = db.update(tables.a, {"name": "k"}, {"k": "a2", "region": region}) 

507 self.assertEqual(n, 1) 

508 sql = sqlalchemy.sql.select(tables.a.columns.name, tables.a.columns.region).select_from(tables.a) 

509 self.assertCountEqual( 

510 [r._asdict() for r in self.query_list(db, sql)], 

511 [{"name": "a1", "region": None}, {"name": "a2", "region": region}], 

512 ) 

513 

    def testSync(self) -> None:
        """Tests for `Database.sync`."""
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert a row with sync, because it doesn't exist yet.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10}, returning=["id"])
        self.assertTrue(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )
        # Repeat that operation, which should do nothing but return the
        # requested values.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10}, returning=["id"])
        self.assertFalse(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )
        # Repeat the operation without the 'extra' arg, which should also
        # just return the existing row.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, returning=["id"])
        self.assertFalse(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )
        # Repeat the operation with a different value in 'extra'. That still
        # shouldn't be an error, because 'extra' is only used if we really
        # do insert. Also drop the 'returning' argument.
        _, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 20})
        self.assertFalse(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )
        # Repeat the operation with the correct value in 'compared' instead
        # of 'extra'.
        _, inserted = db.sync(tables.b, keys={"name": "b1"}, compared={"value": 10})
        self.assertFalse(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )
        # Repeat the operation with an incorrect value in 'compared'; this
        # should raise.
        with self.assertRaises(DatabaseConflictError):
            db.sync(tables.b, keys={"name": "b1"}, compared={"value": 20})
        # Try to sync in a read-only database. This should work if and only
        # if the matching row already exists.
        with self.asReadOnly(db) as rodb:
            with rodb.declareStaticTables(create=False) as context:
                tables = context.addTableTuple(STATIC_TABLE_SPECS)
            _, inserted = rodb.sync(tables.b, keys={"name": "b1"})
            self.assertFalse(inserted)
            self.assertEqual(
                [{"id": values["id"], "name": "b1", "value": 10}],
                [r._asdict() for r in self.query_list(rodb, tables.b.select())],
            )
            # A sync that would have to insert must fail in read-only mode.
            with self.assertRaises(ReadOnlyDatabaseError):
                rodb.sync(tables.b, keys={"name": "b2"}, extra={"value": 20})
        # Repeat the operation with a different value in 'compared' and ask
        # to update; the returned dict holds the values that were replaced.
        _, updated = db.sync(tables.b, keys={"name": "b1"}, compared={"value": 20}, update=True)
        self.assertEqual(updated, {"value": 10})
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 20}],
            [r._asdict() for r in self.query_list(db, tables.b.select())],
        )

584 

585 def testReplace(self): 

586 """Tests for `Database.replace`.""" 

587 db = self.makeEmptyDatabase(origin=1) 

588 with db.declareStaticTables(create=True) as context: 

589 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

590 # Use 'replace' to insert a single row that contains a region and 

591 # query to get it back. 

592 region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1))) 

593 row1 = {"name": "a1", "region": region} 

594 db.replace(tables.a, row1) 

595 self.assertEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1]) 

596 # Insert another row without a region. 

597 row2 = {"name": "a2", "region": None} 

598 db.replace(tables.a, row2) 

599 self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2]) 

600 # Use replace to re-insert both of those rows again, which should do 

601 # nothing. 

602 db.replace(tables.a, row1, row2) 

603 self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2]) 

604 # Replace row1 with a row with no region, while reinserting row2. 

605 row1a = {"name": "a1", "region": None} 

606 db.replace(tables.a, row1a, row2) 

607 self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1a, row2]) 

608 # Replace both rows, returning row1 to its original state, while adding 

609 # a new one. Pass them in in a different order. 

610 row2a = {"name": "a2", "region": region} 

611 row3 = {"name": "a3", "region": None} 

612 db.replace(tables.a, row3, row2a, row1) 

613 self.assertCountEqual( 

614 [r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2a, row3] 

615 ) 

616 

    def testEnsure(self) -> None:
        """Tests for `Database.ensure`."""
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Use 'ensure' to insert a single row that contains a region and
        # query to get it back; the return value is the number of rows
        # actually inserted.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row1 = {"name": "a1", "region": region}
        self.assertEqual(db.ensure(tables.a, row1), 1)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1])
        # Insert another row without a region.
        row2 = {"name": "a2", "region": None}
        self.assertEqual(db.ensure(tables.a, row2), 1)
        self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2])
        # Use ensure to re-insert both of those rows again, which should do
        # nothing.
        self.assertEqual(db.ensure(tables.a, row1, row2), 0)
        self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2])
        # Attempt to insert row1's key with no region, while
        # reinserting row2. This should also do nothing.
        row1a = {"name": "a1", "region": None}
        self.assertEqual(db.ensure(tables.a, row1a, row2), 0)
        self.assertCountEqual([r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2])
        # Attempt to insert new rows for both existing keys, this time also
        # adding a new row. Pass them in in a different order. Only the new
        # row should be added.
        row2a = {"name": "a2", "region": region}
        row3 = {"name": "a3", "region": None}
        self.assertEqual(db.ensure(tables.a, row3, row2a, row1a), 1)
        self.assertCountEqual(
            [r._asdict() for r in self.query_list(db, tables.a.select())], [row1, row2, row3]
        )
        # Add some data to a table with both a primary key and a different
        # unique constraint.
        row_b = {"id": 5, "name": "five", "value": 50}
        db.insert(tables.b, row_b)
        # Attempt ensure with primary_key_only=False and a conflict for the
        # non-PK constraint. This should do nothing.
        db.ensure(tables.b, {"id": 10, "name": "five", "value": 200})
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.b.select())], [row_b])
        # Now use primary_key_only=True with conflict in only the non-PK
        # field. This should be an integrity error and nothing should
        # change.
        with self.assertRaises(sqlalchemy.exc.IntegrityError):
            db.ensure(tables.b, {"id": 10, "name": "five", "value": 200}, primary_key_only=True)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.b.select())], [row_b])
        # With primary_key_only=True a conflict in the primary key is
        # ignored regardless of whether there is a conflict elsewhere.
        db.ensure(tables.b, {"id": 5, "name": "ten", "value": 100}, primary_key_only=True)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.b.select())], [row_b])
        db.ensure(tables.b, {"id": 5, "name": "five", "value": 100}, primary_key_only=True)
        self.assertEqual([r._asdict() for r in self.query_list(db, tables.b.select())], [row_b])

669 

670 def testTransactionNesting(self): 

671 """Test that transactions can be nested with the behavior in the 

672 presence of exceptions working as documented. 

673 """ 

674 db = self.makeEmptyDatabase(origin=1) 

675 with db.declareStaticTables(create=True) as context: 

676 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

677 # Insert one row so we can trigger integrity errors by trying to insert 

678 # a duplicate of it below. 

679 db.insert(tables.a, {"name": "a1"}) 

680 # First test: error recovery via explicit savepoint=True in the inner 

681 # transaction. 

682 with db.transaction(): 

683 # This insert should succeed, and should not be rolled back because 

684 # the assertRaises context should catch any exception before it 

685 # propagates up to the outer transaction. 

686 db.insert(tables.a, {"name": "a2"}) 

687 with self.assertRaises(sqlalchemy.exc.IntegrityError): 

688 with db.transaction(savepoint=True): 

689 # This insert should succeed, but should be rolled back. 

690 db.insert(tables.a, {"name": "a4"}) 

691 # This insert should fail (duplicate primary key), raising 

692 # an exception. 

693 db.insert(tables.a, {"name": "a1"}) 

694 self.assertCountEqual( 

695 [r._asdict() for r in self.query_list(db, tables.a.select())], 

696 [{"name": "a1", "region": None}, {"name": "a2", "region": None}], 

697 ) 

698 # Second test: error recovery via implicit savepoint=True, when the 

699 # innermost transaction is inside a savepoint=True transaction. 

700 with db.transaction(): 

701 # This insert should succeed, and should not be rolled back 

702 # because the assertRaises context should catch any 

703 # exception before it propagates up to the outer 

704 # transaction. 

705 db.insert(tables.a, {"name": "a3"}) 

706 with self.assertRaises(sqlalchemy.exc.IntegrityError): 

707 with db.transaction(savepoint=True): 

708 # This insert should succeed, but should be rolled back. 

709 db.insert(tables.a, {"name": "a4"}) 

710 with db.transaction(): 

711 # This insert should succeed, but should be rolled 

712 # back. 

713 db.insert(tables.a, {"name": "a5"}) 

714 # This insert should fail (duplicate primary key), 

715 # raising an exception. 

716 db.insert(tables.a, {"name": "a1"}) 

717 self.assertCountEqual( 

718 [r._asdict() for r in self.query_list(db, tables.a.select())], 

719 [{"name": "a1", "region": None}, {"name": "a2", "region": None}, {"name": "a3", "region": None}], 

720 ) 

721 

    def testTransactionLocking(self):
        """Test that `Database.transaction` can be used to acquire a lock
        that prohibits concurrent writes.
        """
        db1 = self.makeEmptyDatabase(origin=1)
        with db1.declareStaticTables(create=True) as context:
            tables1 = context.addTableTuple(STATIC_TABLE_SPECS)

        async def side1(lock: Iterable[str] = ()) -> Tuple[Set[str], Set[str]]:
            """One side of the concurrent locking test.

            This optionally locks the table (and maybe the whole database),
            does a select for its contents, inserts a new row, and then selects
            again, with some waiting in between to make sure the other side has
            a chance to _attempt_ to insert in between. If the locking is
            enabled and works, the difference between the selects should just
            be the insert done on this thread.

            Parameters
            ----------
            lock : `Iterable` [ `str` ]
                Tables to lock for the duration of the transaction; empty
                (the default) means no locking.

            Returns
            -------
            names : `tuple` [ `set` [ `str` ], `set` [ `str` ] ]
                The 'name' values seen before and after this side's insert.
            """
            # Give Side2 a chance to create a connection
            await asyncio.sleep(1.0)
            with db1.transaction(lock=lock):
                names1 = {row.name for row in self.query_list(db1, tables1.a.select())}
                # Give Side2 a chance to insert (which will be blocked if
                # we've acquired a lock).
                await asyncio.sleep(2.0)
                db1.insert(tables1.a, {"name": "a1"})
                names2 = {row.name for row in self.query_list(db1, tables1.a.select())}
            return names1, names2

        async def side2() -> None:
            """The other side of the concurrent locking test.

            This side just waits a bit and then tries to insert a row into the
            table that the other side is trying to lock. Hopefully that
            waiting is enough to give the other side a chance to acquire the
            lock and thus make this side block until the lock is released. If
            this side manages to do the insert before side1 acquires the lock,
            we'll just warn about not succeeding at testing the locking,
            because we can only make that unlikely, not impossible.
            """

            def toRunInThread():
                """SQLite locking isn't asyncio-friendly unless we actually
                run it in another thread. And SQLite gets very unhappy if
                we try to use a connection from multiple threads, so we have
                to create the new connection here instead of out in the main
                body of the test function.
                """
                db2 = self.getNewConnection(db1, writeable=True)
                with db2.declareStaticTables(create=False) as context:
                    tables2 = context.addTableTuple(STATIC_TABLE_SPECS)
                with db2.transaction():
                    db2.insert(tables2.a, {"name": "a2"})

            # Wait long enough for side1 to have a chance to take its lock.
            await asyncio.sleep(2.0)
            loop = asyncio.get_running_loop()
            # Run the (potentially blocking) insert in a worker thread so it
            # doesn't stall the event loop while waiting on the lock.
            with ThreadPoolExecutor() as pool:
                await loop.run_in_executor(pool, toRunInThread)

        async def testProblemsWithNoLocking() -> None:
            """Run side1 and side2 with no locking, attempting to demonstrate
            the problem that locking is supposed to solve. If we get unlucky
            with scheduling, side2 will just happen to insert after side1 is
            done, and we won't have anything definitive. We just warn in that
            case because we really don't want spurious test failures.
            """
            task1 = asyncio.create_task(side1())
            task2 = asyncio.create_task(side2())

            names1, names2 = await task1
            await task2
            if "a2" in names1:
                warnings.warn(
                    "Unlucky scheduling in no-locking test: concurrent INSERT happened before first SELECT."
                )
                self.assertEqual(names1, {"a2"})
                self.assertEqual(names2, {"a1", "a2"})
            elif "a2" not in names2:
                warnings.warn(
                    "Unlucky scheduling in no-locking test: concurrent INSERT "
                    "happened after second SELECT even without locking."
                )
                self.assertEqual(names1, set())
                self.assertEqual(names2, {"a1"})
            else:
                # This is the expected case: both INSERTS happen between the
                # two SELECTS. If we don't get this almost all of the time we
                # should adjust the sleep amounts.
                self.assertEqual(names1, set())
                self.assertEqual(names2, {"a1", "a2"})

        asyncio.run(testProblemsWithNoLocking())

        # Clean up after first test.
        db1.delete(tables1.a, ["name"], {"name": "a1"}, {"name": "a2"})

        async def testSolutionWithLocking() -> None:
            """Run side1 and side2 with locking, which should make side2 block
            its insert until side1 releases its lock.
            """
            task1 = asyncio.create_task(side1(lock=[tables1.a]))
            task2 = asyncio.create_task(side2())

            names1, names2 = await task1
            await task2
            if "a2" in names1:
                warnings.warn(
                    "Unlucky scheduling in locking test: concurrent INSERT happened before first SELECT."
                )
                self.assertEqual(names1, {"a2"})
                self.assertEqual(names2, {"a1", "a2"})
            else:
                # This is the expected case: the side2 INSERT happens after the
                # last SELECT on side1. This can also happen due to unlucky
                # scheduling, and we have no way to detect that here, but the
                # similar "no-locking" test has at least some chance of being
                # affected by the same problem and warning about it.
                self.assertEqual(names1, set())
                self.assertEqual(names2, {"a1"})

        asyncio.run(testSolutionWithLocking())

843 

    def testTimespanDatabaseRepresentation(self):
        """Tests for `TimespanDatabaseRepresentation` and the `Database`
        methods that interact with it.
        """
        # Make some test timespans to play with, with the full suite of
        # topological relationships.
        start = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
        offset = astropy.time.TimeDelta(60, format="sec")
        timestamps = [start + offset * n for n in range(3)]
        aTimespans = [Timespan(begin=None, end=None)]
        aTimespans.extend(Timespan(begin=None, end=t) for t in timestamps)
        aTimespans.extend(Timespan(begin=t, end=None) for t in timestamps)
        aTimespans.extend(Timespan.fromInstant(t) for t in timestamps)
        aTimespans.append(Timespan.makeEmpty())
        aTimespans.extend(Timespan(begin=t1, end=t2) for t1, t2 in itertools.combinations(timestamps, 2))
        # Make another list of timespans that span the full range but don't
        # overlap. This is a subset of the previous list.
        bTimespans = [Timespan(begin=None, end=timestamps[0])]
        bTimespans.extend(Timespan(begin=t1, end=t2) for t1, t2 in zip(timestamps[:-1], timestamps[1:]))
        bTimespans.append(Timespan(begin=timestamps[-1], end=None))
        # Make a database and create a table with that database's timespan
        # representation. This one will have no exclusion constraint and
        # a nullable timespan.
        db = self.makeEmptyDatabase(origin=1)
        TimespanReprClass = db.getTimespanRepresentation()
        aSpec = ddl.TableSpec(
            fields=[
                ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
            ],
        )
        for fieldSpec in TimespanReprClass.makeFieldSpecs(nullable=True):
            aSpec.fields.add(fieldSpec)
        with db.declareStaticTables(create=True) as context:
            aTable = context.addTable("a", aSpec)
        # Show full diffs on assertion failures; the row lists compared below
        # can be long.
        self.maxDiff = None

        def convertRowForInsert(row: dict) -> dict:
            """Convert a row containing a Timespan instance into one suitable
            for insertion into the database.

            Parameters
            ----------
            row : `dict`
                Row with a `Timespan` (or `None`) under the key
                ``TimespanReprClass.NAME``.

            Returns
            -------
            row : `dict`
                Copy of ``row`` with the timespan expanded into the
                representation's database columns.
            """
            result = row.copy()
            ts = result.pop(TimespanReprClass.NAME)
            return TimespanReprClass.update(ts, result=result)

        def convertRowFromSelect(row: dict) -> dict:
            """Convert a row from the database into one containing a Timespan.

            Parameters
            ----------
            row : `dict`
                Original row.

            Returns
            -------
            row : `dict`
                The updated row.
            """
            result = row.copy()
            timespan = TimespanReprClass.extract(result)
            # Drop the representation's raw columns, leaving only the
            # reconstructed Timespan object.
            for name in TimespanReprClass.getFieldNames():
                del result[name]
            result[TimespanReprClass.NAME] = timespan
            return result

        # Insert rows into table A, in chunks just to make things interesting.
        # Include one with a NULL timespan.
        aRows = [{"id": n, TimespanReprClass.NAME: t} for n, t in enumerate(aTimespans)]
        aRows.append({"id": len(aRows), TimespanReprClass.NAME: None})
        db.insert(aTable, convertRowForInsert(aRows[0]))
        db.insert(aTable, *[convertRowForInsert(r) for r in aRows[1:3]])
        db.insert(aTable, *[convertRowForInsert(r) for r in aRows[3:]])
        # Add another one with a NULL timespan, but this time by invoking
        # the server-side default.
        aRows.append({"id": len(aRows)})
        db.insert(aTable, aRows[-1])
        aRows[-1][TimespanReprClass.NAME] = None
        # Test basic round-trip through database.
        self.assertEqual(
            aRows,
            [
                convertRowFromSelect(row._asdict())
                for row in self.query_list(db, aTable.select().order_by(aTable.columns.id))
            ],
        )
        # Create another table B with a not-null timespan and (if the database
        # supports it), an exclusion constraint. Use ensureTableExists this
        # time to check that mode of table creation vs. timespans.
        bSpec = ddl.TableSpec(
            fields=[
                ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
                ddl.FieldSpec(name="key", dtype=sqlalchemy.Integer, nullable=False),
            ],
        )
        for fieldSpec in TimespanReprClass.makeFieldSpecs(nullable=False):
            bSpec.fields.add(fieldSpec)
        if TimespanReprClass.hasExclusionConstraint():
            bSpec.exclusion.add(("key", TimespanReprClass))
        bTable = db.ensureTableExists("b", bSpec)
        # Insert rows into table B, again in chunks. Each Timespan appears
        # twice, but with different values for the 'key' field (which should
        # still be okay for any exclusion constraint we may have defined).
        bRows = [{"id": n, "key": 1, TimespanReprClass.NAME: t} for n, t in enumerate(bTimespans)]
        offset = len(bRows)
        bRows.extend(
            {"id": n + offset, "key": 2, TimespanReprClass.NAME: t} for n, t in enumerate(bTimespans)
        )
        db.insert(bTable, *[convertRowForInsert(r) for r in bRows[:2]])
        db.insert(bTable, convertRowForInsert(bRows[2]))
        db.insert(bTable, *[convertRowForInsert(r) for r in bRows[3:]])
        # Insert a row with no timespan into table B. This should invoke the
        # server-side default, which is a timespan over (-∞, ∞). We set
        # key=3 to avoid upsetting an exclusion constraint that might exist.
        bRows.append({"id": len(bRows), "key": 3})
        db.insert(bTable, bRows[-1])
        bRows[-1][TimespanReprClass.NAME] = Timespan(None, None)
        # Test basic round-trip through database.
        self.assertEqual(
            bRows,
            [
                convertRowFromSelect(row._asdict())
                for row in self.query_list(db, bTable.select().order_by(bTable.columns.id))
            ],
        )
        # Test that we can't insert timespan=None into this table.
        with self.assertRaises(sqlalchemy.exc.IntegrityError):
            db.insert(bTable, convertRowForInsert({"id": len(bRows), "key": 4, TimespanReprClass.NAME: None}))
        # IFF this database supports exclusion constraints, test that they
        # also prevent inserts.
        if TimespanReprClass.hasExclusionConstraint():
            # Overlaps the existing (None, timestamps[0]) row with key=1.
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                db.insert(
                    bTable,
                    convertRowForInsert(
                        {"id": len(bRows), "key": 1, TimespanReprClass.NAME: Timespan(None, timestamps[1])}
                    ),
                )
            # Overlaps the interior key=1 intervals.
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                db.insert(
                    bTable,
                    convertRowForInsert(
                        {
                            "id": len(bRows),
                            "key": 1,
                            TimespanReprClass.NAME: Timespan(timestamps[0], timestamps[2]),
                        }
                    ),
                )
            # Overlaps the existing (timestamps[-1], None) row with key=1.
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                db.insert(
                    bTable,
                    convertRowForInsert(
                        {"id": len(bRows), "key": 1, TimespanReprClass.NAME: Timespan(timestamps[2], None)}
                    ),
                )
        # Test NULL checks in SELECT queries, on both tables.
        aRepr = TimespanReprClass.from_columns(aTable.columns)
        self.assertEqual(
            [row[TimespanReprClass.NAME] is None for row in aRows],
            [
                row.f
                for row in self.query_list(
                    db, sqlalchemy.sql.select(aRepr.isNull().label("f")).order_by(aTable.columns.id)
                )
            ],
        )
        # Table B's timespan is NOT NULL, so isNull() must be False for every
        # row.
        bRepr = TimespanReprClass.from_columns(bTable.columns)
        self.assertEqual(
            [False for row in bRows],
            [
                row.f
                for row in self.query_list(
                    db, sqlalchemy.sql.select(bRepr.isNull().label("f")).order_by(bTable.columns.id)
                )
            ],
        )
        # Test relationships expressions that relate in-database timespans to
        # Python-literal timespans, all from the more complete 'a' set; check
        # that this is consistent with Python-only relationship tests.
        for rhsRow in aRows:
            if rhsRow[TimespanReprClass.NAME] is None:
                continue
            with self.subTest(rhsRow=rhsRow):
                expected = {}
                for lhsRow in aRows:
                    if lhsRow[TimespanReprClass.NAME] is None:
                        # NULL timespans compare as NULL (None) in SQL.
                        expected[lhsRow["id"]] = (None, None, None, None)
                    else:
                        expected[lhsRow["id"]] = (
                            lhsRow[TimespanReprClass.NAME].overlaps(rhsRow[TimespanReprClass.NAME]),
                            lhsRow[TimespanReprClass.NAME].contains(rhsRow[TimespanReprClass.NAME]),
                            lhsRow[TimespanReprClass.NAME] < rhsRow[TimespanReprClass.NAME],
                            lhsRow[TimespanReprClass.NAME] > rhsRow[TimespanReprClass.NAME],
                        )
                rhsRepr = TimespanReprClass.fromLiteral(rhsRow[TimespanReprClass.NAME])
                sql = sqlalchemy.sql.select(
                    aTable.columns.id.label("lhs"),
                    aRepr.overlaps(rhsRepr).label("overlaps"),
                    aRepr.contains(rhsRepr).label("contains"),
                    (aRepr < rhsRepr).label("less_than"),
                    (aRepr > rhsRepr).label("greater_than"),
                ).select_from(aTable)
                queried = {
                    row.lhs: (row.overlaps, row.contains, row.less_than, row.greater_than)
                    for row in self.query_list(db, sql)
                }
                self.assertEqual(expected, queried)
        # Test relationship expressions that relate in-database timespans to
        # each other, all from the more complete 'a' set; check that this is
        # consistent with Python-only relationship tests.
        expected = {}
        for lhs, rhs in itertools.product(aRows, aRows):
            lhsT = lhs[TimespanReprClass.NAME]
            rhsT = rhs[TimespanReprClass.NAME]
            if lhsT is not None and rhsT is not None:
                expected[lhs["id"], rhs["id"]] = (
                    lhsT.overlaps(rhsT),
                    lhsT.contains(rhsT),
                    lhsT < rhsT,
                    lhsT > rhsT,
                )
            else:
                expected[lhs["id"], rhs["id"]] = (None, None, None, None)
        # Self-join table 'a' against itself with a cross-join (always-true ON
        # clause) to compare every pair of rows.
        lhsSubquery = aTable.alias("lhs")
        rhsSubquery = aTable.alias("rhs")
        lhsRepr = TimespanReprClass.from_columns(lhsSubquery.columns)
        rhsRepr = TimespanReprClass.from_columns(rhsSubquery.columns)
        sql = sqlalchemy.sql.select(
            lhsSubquery.columns.id.label("lhs"),
            rhsSubquery.columns.id.label("rhs"),
            lhsRepr.overlaps(rhsRepr).label("overlaps"),
            lhsRepr.contains(rhsRepr).label("contains"),
            (lhsRepr < rhsRepr).label("less_than"),
            (lhsRepr > rhsRepr).label("greater_than"),
        ).select_from(lhsSubquery.join(rhsSubquery, onclause=sqlalchemy.sql.literal(True)))
        queried = {
            (row.lhs, row.rhs): (row.overlaps, row.contains, row.less_than, row.greater_than)
            for row in self.query_list(db, sql)
        }
        self.assertEqual(expected, queried)
        # Test relationship expressions between in-database timespans and
        # Python-literal instantaneous times.
        for t in timestamps:
            with self.subTest(t=t):
                expected = {}
                for lhsRow in aRows:
                    if lhsRow[TimespanReprClass.NAME] is None:
                        expected[lhsRow["id"]] = (None, None, None, None)
                    else:
                        expected[lhsRow["id"]] = (
                            lhsRow[TimespanReprClass.NAME].contains(t),
                            lhsRow[TimespanReprClass.NAME].overlaps(t),
                            lhsRow[TimespanReprClass.NAME] < t,
                            lhsRow[TimespanReprClass.NAME] > t,
                        )
                rhs = sqlalchemy.sql.literal(t, type_=ddl.AstropyTimeNsecTai)
                sql = sqlalchemy.sql.select(
                    aTable.columns.id.label("lhs"),
                    aRepr.contains(rhs).label("contains"),
                    aRepr.overlaps(rhs).label("overlaps_point"),
                    (aRepr < rhs).label("less_than"),
                    (aRepr > rhs).label("greater_than"),
                ).select_from(aTable)
                queried = {
                    row.lhs: (row.contains, row.overlaps_point, row.less_than, row.greater_than)
                    for row in self.query_list(db, sql)
                }
                self.assertEqual(expected, queried)

1111 

    def testConstantRows(self):
        """Test Database.constant_rows."""
        new_db = self.makeEmptyDatabase()
        with new_db.declareStaticTables(create=True) as context:
            static = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert real rows into table 'b' so the constant-rows construct has
        # something to join against; capture the autoincrement ids.
        b_ids = new_db.insert(
            static.b,
            {"name": "b1", "value": 11},
            {"name": "b2", "value": 12},
            {"name": "b3", "value": 13},
            returnIds=True,
        )
        # Field spec for the literal "table" of constant rows, exercising an
        # integer, a string, and a custom (Base64Region) column type.
        values_spec = ddl.TableSpec(
            [
                ddl.FieldSpec(name="b", dtype=sqlalchemy.BigInteger),
                ddl.FieldSpec(name="s", dtype=sqlalchemy.String(8)),
                ddl.FieldSpec(name="r", dtype=ddl.Base64Region()),
            ],
        )
        values_data = [
            {"b": b_ids[0], "s": "b1", "r": None},
            {"b": b_ids[2], "s": "b3", "r": Circle.empty()},
        ]
        values = new_db.constant_rows(values_spec.fields, *values_data)
        # Select directly from the constant rows (no real table involved) and
        # check that all column types round-trip.
        select_values_alone = sqlalchemy.sql.select(
            values.columns["b"], values.columns["s"], values.columns["r"]
        )
        self.assertCountEqual(
            [dict(row) for row in self.query_list(new_db, select_values_alone)],
            values_data,
        )
        # Join the constant rows against a real table and check the result.
        select_values_joined = sqlalchemy.sql.select(
            values.columns["s"].label("name"), static.b.columns["value"].label("value")
        ).select_from(values.join(static.b, onclause=static.b.columns["id"] == values.columns["b"]))
        self.assertCountEqual(
            [dict(row) for row in self.query_list(new_db, select_values_joined)],
            [{"value": 11, "name": "b1"}, {"value": 13, "name": "b3"}],
        )