Coverage for python/lsst/daf/butler/registry/tests/_database.py: 6%

478 statements  

« prev     ^ index     » next       coverage.py v6.4.4, created at 2022-08-31 10:07 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ["DatabaseTests"] 

24 

25import asyncio 

26import itertools 

27import warnings 

28from abc import ABC, abstractmethod 

29from collections import namedtuple 

30from concurrent.futures import ThreadPoolExecutor 

31from contextlib import contextmanager 

32from typing import ContextManager, Iterable, Optional, Set, Tuple 

33 

34import astropy.time 

35import sqlalchemy 

36from lsst.sphgeom import ConvexPolygon, UnitVector3d 

37 

38from ...core import Timespan, ddl 

39from ..interfaces import Database, DatabaseConflictError, ReadOnlyDatabaseError, SchemaAlreadyDefinedError 

40 

# Named tuple holding the three static test tables; its field names ("a",
# "b", "c") must match the keyword arguments of STATIC_TABLE_SPECS below.
StaticTablesTuple = namedtuple("StaticTablesTuple", ["a", "b", "c"])

# Specifications for the static tables used by most tests in this module:
#  - "a": string primary key plus a nullable Base64Region column.
#  - "b": autoincrement integer primary key, unique non-null "name", and a
#    nullable small-integer "value".
#  - "c": autoincrement integer primary key with a nullable foreign key to
#    "b" that is set to NULL when the referenced row is deleted.
STATIC_TABLE_SPECS = StaticTablesTuple(
    a=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("name", dtype=sqlalchemy.String, length=16, primaryKey=True),
            ddl.FieldSpec("region", dtype=ddl.Base64Region, nbytes=128),
        ]
    ),
    b=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("id", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True),
            ddl.FieldSpec("name", dtype=sqlalchemy.String, length=16, nullable=False),
            ddl.FieldSpec("value", dtype=sqlalchemy.SmallInteger, nullable=True),
        ],
        unique=[("name",)],
    ),
    c=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("id", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True),
            ddl.FieldSpec("b_id", dtype=sqlalchemy.BigInteger, nullable=True),
        ],
        foreignKeys=[
            ddl.ForeignKeySpec("b", source=("b_id",), target=("id",), onDelete="SET NULL"),
        ],
    ),
)

68 

# Specification for the dynamic table (created after the static schema via
# ensureTableExists) used by the dynamic-table tests; rows cascade-delete
# when their referenced row in "a" or "c" is removed.
DYNAMIC_TABLE_SPEC = ddl.TableSpec(
    fields=[
        ddl.FieldSpec("c_id", dtype=sqlalchemy.BigInteger, primaryKey=True),
        ddl.FieldSpec("a_name", dtype=sqlalchemy.String, length=16, nullable=False),
    ],
    foreignKeys=[
        ddl.ForeignKeySpec("c", source=("c_id",), target=("id",), onDelete="CASCADE"),
        ddl.ForeignKeySpec("a", source=("a_name",), target=("name",), onDelete="CASCADE"),
    ],
)

79 

# Specification for temporary tables created within a session in
# testTemporaryTables; a compound primary key and no foreign keys.
TEMPORARY_TABLE_SPEC = ddl.TableSpec(
    fields=[
        ddl.FieldSpec("a_name", dtype=sqlalchemy.String, length=16, primaryKey=True),
        ddl.FieldSpec("b_id", dtype=sqlalchemy.BigInteger, primaryKey=True),
    ],
)

86 

87 

88@contextmanager 

89def _patch_getExistingTable(db: Database) -> Database: 

90 """Patch getExistingTable method in a database instance to test concurrent 

91 creation of tables. This patch obviously depends on knowning internals of 

92 ``ensureTableExists()`` implementation. 

93 """ 

94 original_method = db.getExistingTable 

95 

96 def _getExistingTable(name: str, spec: ddl.TableSpec) -> Optional[sqlalchemy.schema.Table]: 

97 # Return None on first call, but forward to original method after that 

98 db.getExistingTable = original_method 

99 return None 

100 

101 db.getExistingTable = _getExistingTable 

102 yield db 

103 db.getExistingTable = original_method 

104 

105 

106class DatabaseTests(ABC): 

107 """Generic tests for the `Database` interface that can be subclassed to 

108 generate tests for concrete implementations. 

109 """ 

110 

    @abstractmethod
    def makeEmptyDatabase(self, origin: int = 0) -> Database:
        """Return an empty `Database` with the given origin, or an
        automatically-generated one if ``origin`` is `None`.

        Parameters
        ----------
        origin : `int`, optional
            Origin identifier for the new database.

        Returns
        -------
        db : `Database`
            A new, empty database instance.
        """
        raise NotImplementedError()

117 

    @abstractmethod
    def asReadOnly(self, database: Database) -> ContextManager[Database]:
        """Return a context manager for a read-only connection into the given
        database.

        The original database should be considered unusable within the context
        but safe to use again afterwards (this allows the context manager to
        block write access by temporarily changing user permissions to really
        guarantee that write operations are not performed).

        Parameters
        ----------
        database : `Database`
            Database to re-open in read-only mode.

        Returns
        -------
        context : `ContextManager` [ `Database` ]
            Context manager yielding the read-only connection.
        """
        raise NotImplementedError()

129 

    @abstractmethod
    def getNewConnection(self, database: Database, *, writeable: bool) -> Database:
        """Return a new `Database` instance that points to the same underlying
        storage as the given one.

        Parameters
        ----------
        database : `Database`
            Existing database whose storage the new connection should share.
        writeable : `bool`
            Whether the new connection should permit write operations.

        Returns
        -------
        db : `Database`
            A new connection to the same storage.
        """
        raise NotImplementedError()

136 

    def checkTable(self, spec: ddl.TableSpec, table: sqlalchemy.schema.Table):
        """Check that a SQLAlchemy table is consistent with a table spec.

        Only column names are compared.  Checking more than this currently
        seems fragile, as it might restrict what Database implementations do;
        we don't care if the spec is actually preserved in terms of types and
        constraints as long as we can use the returned table as if it was.
        """
        self.assertCountEqual(spec.fields.names, table.columns.keys())

143 

144 def checkStaticSchema(self, tables: StaticTablesTuple): 

145 self.checkTable(STATIC_TABLE_SPECS.a, tables.a) 

146 self.checkTable(STATIC_TABLE_SPECS.b, tables.b) 

147 self.checkTable(STATIC_TABLE_SPECS.c, tables.c) 

148 

149 def testDeclareStaticTables(self): 

150 """Tests for `Database.declareStaticSchema` and the methods it 

151 delegates to. 

152 """ 

153 # Create the static schema in a new, empty database. 

154 newDatabase = self.makeEmptyDatabase() 

155 with newDatabase.declareStaticTables(create=True) as context: 

156 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

157 self.checkStaticSchema(tables) 

158 # Check that we can load that schema even from a read-only connection. 

159 with self.asReadOnly(newDatabase) as existingReadOnlyDatabase: 

160 with existingReadOnlyDatabase.declareStaticTables(create=False) as context: 

161 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

162 self.checkStaticSchema(tables) 

163 

164 def testDeclareStaticTablesTwice(self): 

165 """Tests for `Database.declareStaticSchema` being called twice.""" 

166 # Create the static schema in a new, empty database. 

167 newDatabase = self.makeEmptyDatabase() 

168 with newDatabase.declareStaticTables(create=True) as context: 

169 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

170 self.checkStaticSchema(tables) 

171 # Second time it should raise 

172 with self.assertRaises(SchemaAlreadyDefinedError): 

173 with newDatabase.declareStaticTables(create=True) as context: 

174 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

175 # Check schema, it should still contain all tables, and maybe some 

176 # extra. 

177 with newDatabase.declareStaticTables(create=False) as context: 

178 self.assertLessEqual(frozenset(STATIC_TABLE_SPECS._fields), context._tableNames) 

179 

180 def testRepr(self): 

181 """Test that repr does not return a generic thing.""" 

182 newDatabase = self.makeEmptyDatabase() 

183 rep = repr(newDatabase) 

184 # Check that stringification works and gives us something different 

185 self.assertNotEqual(rep, str(newDatabase)) 

186 self.assertNotIn("object at 0x", rep, "Check default repr was not used") 

187 self.assertIn("://", rep) 

188 

    def testDynamicTables(self):
        """Tests for `Database.ensureTableExists` and
        `Database.getExistingTable`.
        """
        # Need to start with the static schema.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            context.addTableTuple(STATIC_TABLE_SPECS)
        # Try to ensure the dynamic table exists in a read-only version of
        # that database, which should fail because we can't create it.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                context.addTableTuple(STATIC_TABLE_SPECS)
            with self.assertRaises(ReadOnlyDatabaseError):
                existingReadOnlyDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        # Just getting the dynamic table before it exists should return None.
        self.assertIsNone(newDatabase.getExistingTable("d", DYNAMIC_TABLE_SPEC))
        # Ensure the new table exists back in the original database, which
        # should create it.
        table = newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        self.checkTable(DYNAMIC_TABLE_SPEC, table)
        # Ensuring that it exists should just return the exact same table
        # instance again.
        self.assertIs(newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC), table)
        # Try again from the read-only database.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                context.addTableTuple(STATIC_TABLE_SPECS)
            # Just getting the dynamic table should now work...
            self.assertIsNotNone(existingReadOnlyDatabase.getExistingTable("d", DYNAMIC_TABLE_SPEC))
            # ...as should ensuring that it exists, since it now does.
            existingReadOnlyDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
            self.checkTable(DYNAMIC_TABLE_SPEC, table)
        # Trying to get the table with a different specification (at least
        # in terms of what columns are present) should raise.
        with self.assertRaises(DatabaseConflictError):
            newDatabase.ensureTableExists(
                "d",
                ddl.TableSpec(
                    fields=[ddl.FieldSpec("name", dtype=sqlalchemy.String, length=4, primaryKey=True)]
                ),
            )
        # Calling ensureTableExists inside a transaction block is an error,
        # even if it would do nothing.
        with newDatabase.transaction():
            with self.assertRaises(AssertionError):
                newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)

236 

237 def testDynamicTablesConcurrency(self): 

238 """Tests for `Database.ensureTableExists` concurrent use.""" 

239 # We cannot really run things concurrently in a deterministic way, here 

240 # we just simulate a situation when the table is created by other 

241 # process between the call to getExistingTable() and actual table 

242 # creation. 

243 db1 = self.makeEmptyDatabase() 

244 with db1.declareStaticTables(create=True) as context: 

245 context.addTableTuple(STATIC_TABLE_SPECS) 

246 self.assertIsNone(db1.getExistingTable("d", DYNAMIC_TABLE_SPEC)) 

247 

248 # Make a dynamic table using separate connection 

249 db2 = self.getNewConnection(db1, writeable=True) 

250 with db2.declareStaticTables(create=False) as context: 

251 context.addTableTuple(STATIC_TABLE_SPECS) 

252 table = db2.ensureTableExists("d", DYNAMIC_TABLE_SPEC) 

253 self.checkTable(DYNAMIC_TABLE_SPEC, table) 

254 

255 # Call it again but trick it into thinking that table is not there. 

256 # This test depends on knowing implementation of ensureTableExists() 

257 # which initially calls getExistingTable() to check that table may 

258 # exist, the patch intercepts that call and returns None. 

259 with _patch_getExistingTable(db1): 

260 table = db1.ensureTableExists("d", DYNAMIC_TABLE_SPEC) 

261 

    def testTemporaryTables(self):
        """Tests for `Database.makeTemporaryTable`,
        `Database.dropTemporaryTable`, and `Database.insert` with
        the ``select`` argument.
        """
        # Need to start with the static schema; also insert some test data.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            static = context.addTableTuple(STATIC_TABLE_SPECS)
        newDatabase.insert(static.a, {"name": "a1", "region": None}, {"name": "a2", "region": None})
        bIds = newDatabase.insert(
            static.b,
            {"name": "b1", "value": 11},
            {"name": "b2", "value": 12},
            {"name": "b3", "value": 13},
            returnIds=True,
        )
        # Create the table.
        with newDatabase.session() as session:
            table1 = session.makeTemporaryTable(TEMPORARY_TABLE_SPEC, "e1")
            self.checkTable(TEMPORARY_TABLE_SPEC, table1)
            # Insert via a INSERT INTO ... SELECT query.
            newDatabase.insert(
                table1,
                select=sqlalchemy.sql.select(
                    static.a.columns.name.label("a_name"), static.b.columns.id.label("b_id")
                )
                .select_from(static.a.join(static.b, onclause=sqlalchemy.sql.literal(True)))
                .where(
                    sqlalchemy.sql.and_(
                        static.a.columns.name == "a1",
                        static.b.columns.value <= 12,
                    )
                ),
            )
            # Check that the inserted rows are present.
            self.assertCountEqual(
                [{"a_name": "a1", "b_id": bId} for bId in bIds[:2]],
                [row._asdict() for row in newDatabase.query(table1.select())],
            )
            # Create another one via a read-only connection to the database.
            # We _do_ allow temporary table modifications in read-only
            # databases.
            with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
                with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                    context.addTableTuple(STATIC_TABLE_SPECS)
                with existingReadOnlyDatabase.session() as session2:
                    table2 = session2.makeTemporaryTable(TEMPORARY_TABLE_SPEC)
                    self.checkTable(TEMPORARY_TABLE_SPEC, table2)
                    # Those tables should not be the same, despite having the
                    # same ddl.
                    self.assertIsNot(table1, table2)
                    # Do a slightly different insert into this table, to check
                    # that it works in a read-only database.  This time we
                    # pass column names as a kwarg to insert instead of by
                    # labeling the columns in the select.
                    existingReadOnlyDatabase.insert(
                        table2,
                        select=sqlalchemy.sql.select(static.a.columns.name, static.b.columns.id)
                        .select_from(static.a.join(static.b, onclause=sqlalchemy.sql.literal(True)))
                        .where(
                            sqlalchemy.sql.and_(
                                static.a.columns.name == "a2",
                                static.b.columns.value >= 12,
                            )
                        ),
                        names=["a_name", "b_id"],
                    )
                    # Check that the inserted rows are present.
                    self.assertCountEqual(
                        [{"a_name": "a2", "b_id": bId} for bId in bIds[1:]],
                        [row._asdict() for row in existingReadOnlyDatabase.query(table2.select())],
                    )
                    # Drop the temporary table from the read-only DB.  It's
                    # unspecified whether attempting to use it after this
                    # point is an error or just never returns any results, so
                    # we can't test what it does, only that it's not an error.
                    session2.dropTemporaryTable(table2)
            # Drop the original temporary table.
            session.dropTemporaryTable(table1)

342 

343 def testSchemaSeparation(self): 

344 """Test that creating two different `Database` instances allows us 

345 to create different tables with the same name in each. 

346 """ 

347 db1 = self.makeEmptyDatabase(origin=1) 

348 with db1.declareStaticTables(create=True) as context: 

349 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

350 self.checkStaticSchema(tables) 

351 

352 db2 = self.makeEmptyDatabase(origin=2) 

353 # Make the DDL here intentionally different so we'll definitely 

354 # notice if db1 and db2 are pointing at the same schema. 

355 spec = ddl.TableSpec(fields=[ddl.FieldSpec("id", dtype=sqlalchemy.Integer, primaryKey=True)]) 

356 with db2.declareStaticTables(create=True) as context: 

357 # Make the DDL here intentionally different so we'll definitely 

358 # notice if db1 and db2 are pointing at the same schema. 

359 table = context.addTable("a", spec) 

360 self.checkTable(spec, table) 

361 

    def testInsertQueryDelete(self):
        """Test the `Database.insert`, `Database.query`, and `Database.delete`
        methods, as well as the `Base64Region` type and the ``onDelete``
        argument to `ddl.ForeignKeySpec`.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert a single, non-autoincrement row that contains a region and
        # query to get it back.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row = {"name": "a1", "region": region}
        db.insert(tables.a, row)
        self.assertEqual([r._asdict() for r in db.query(tables.a.select())], [row])
        # Insert multiple autoincrement rows but do not try to get the IDs
        # back immediately.
        db.insert(tables.b, {"name": "b1", "value": 10}, {"name": "b2", "value": 20})
        results = [r._asdict() for r in db.query(tables.b.select().order_by("id"))]
        self.assertEqual(len(results), 2)
        for row in results:
            self.assertIn(row["name"], ("b1", "b2"))
            self.assertIsInstance(row["id"], int)
        self.assertGreater(results[1]["id"], results[0]["id"])
        # Insert multiple autoincrement rows and get the IDs back from insert.
        rows = [{"name": "b3", "value": 30}, {"name": "b4", "value": 40}]
        ids = db.insert(tables.b, *rows, returnIds=True)
        results = [
            r._asdict() for r in db.query(tables.b.select().where(tables.b.columns.id > results[1]["id"]))
        ]
        expected = [dict(row, id=id) for row, id in zip(rows, ids)]
        self.assertCountEqual(results, expected)
        self.assertTrue(all(result["id"] is not None for result in results))
        # Insert multiple rows into a table with an autoincrement primary key,
        # then use the returned IDs to insert into a dynamic table.
        rows = [{"b_id": results[0]["id"]}, {"b_id": None}]
        ids = db.insert(tables.c, *rows, returnIds=True)
        results = [r._asdict() for r in db.query(tables.c.select())]
        expected = [dict(row, id=id) for row, id in zip(rows, ids)]
        self.assertCountEqual(results, expected)
        self.assertTrue(all(result["id"] is not None for result in results))
        # Add the dynamic table.
        d = db.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        # Insert into it.
        rows = [{"c_id": id, "a_name": "a1"} for id in ids]
        db.insert(d, *rows)
        results = [r._asdict() for r in db.query(d.select())]
        self.assertCountEqual(rows, results)
        # Insert multiple rows into a table with an autoincrement primary key,
        # but pass in a value for the autoincrement key.
        rows2 = [
            {"id": 700, "b_id": None},
            {"id": 701, "b_id": None},
        ]
        db.insert(tables.c, *rows2)
        results = [r._asdict() for r in db.query(tables.c.select())]
        self.assertCountEqual(results, expected + rows2)
        self.assertTrue(all(result["id"] is not None for result in results))

        # Define 'SELECT COUNT(*)' query for later use.
        count = sqlalchemy.sql.select(sqlalchemy.sql.func.count())
        # Get the values we inserted into table b.
        bValues = [r._asdict() for r in db.query(tables.b.select())]
        # Remove two rows from table b by ID.
        n = db.delete(tables.b, ["id"], {"id": bValues[0]["id"]}, {"id": bValues[1]["id"]})
        self.assertEqual(n, 2)
        # Remove the other two rows from table b by name.
        n = db.delete(tables.b, ["name"], {"name": bValues[2]["name"]}, {"name": bValues[3]["name"]})
        self.assertEqual(n, 2)
        # There should now be no rows in table b.
        self.assertEqual(db.query(count.select_from(tables.b)).scalar(), 0)
        # All b_id values in table c should now be NULL, because there's an
        # onDelete='SET NULL' foreign key.
        self.assertEqual(
            db.query(count.select_from(tables.c).where(tables.c.columns.b_id != None)).scalar(),  # noqa:E711
            0,
        )
        # Remove all rows in table a (there's only one); this should remove all
        # rows in d due to onDelete='CASCADE'.
        n = db.delete(tables.a, [])
        self.assertEqual(n, 1)
        self.assertEqual(db.query(count.select_from(tables.a)).scalar(), 0)
        self.assertEqual(db.query(count.select_from(d)).scalar(), 0)

444 

445 def testDeleteWhere(self): 

446 """Tests for `Database.deleteWhere`.""" 

447 db = self.makeEmptyDatabase(origin=1) 

448 with db.declareStaticTables(create=True) as context: 

449 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

450 db.insert(tables.b, *[{"id": i, "name": f"b{i}"} for i in range(10)]) 

451 count = sqlalchemy.sql.select(sqlalchemy.sql.func.count()) 

452 

453 n = db.deleteWhere(tables.b, tables.b.columns.id.in_([0, 1, 2])) 

454 self.assertEqual(n, 3) 

455 self.assertEqual(db.query(count.select_from(tables.b)).scalar(), 7) 

456 

457 n = db.deleteWhere( 

458 tables.b, 

459 tables.b.columns.id.in_( 

460 sqlalchemy.sql.select(tables.b.columns.id).where(tables.b.columns.id > 5) 

461 ), 

462 ) 

463 self.assertEqual(n, 4) 

464 self.assertEqual(db.query(count.select_from(tables.b)).scalar(), 3) 

465 

466 n = db.deleteWhere(tables.b, tables.b.columns.name == "b5") 

467 self.assertEqual(n, 1) 

468 self.assertEqual(db.query(count.select_from(tables.b)).scalar(), 2) 

469 

470 n = db.deleteWhere(tables.b, sqlalchemy.sql.literal(True)) 

471 self.assertEqual(n, 2) 

472 self.assertEqual(db.query(count.select_from(tables.b)).scalar(), 0) 

473 

474 def testUpdate(self): 

475 """Tests for `Database.update`.""" 

476 db = self.makeEmptyDatabase(origin=1) 

477 with db.declareStaticTables(create=True) as context: 

478 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

479 # Insert two rows into table a, both without regions. 

480 db.insert(tables.a, {"name": "a1"}, {"name": "a2"}) 

481 # Update one of the rows with a region. 

482 region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1))) 

483 n = db.update(tables.a, {"name": "k"}, {"k": "a2", "region": region}) 

484 self.assertEqual(n, 1) 

485 sql = sqlalchemy.sql.select(tables.a.columns.name, tables.a.columns.region).select_from(tables.a) 

486 self.assertCountEqual( 

487 [r._asdict() for r in db.query(sql)], 

488 [{"name": "a1", "region": None}, {"name": "a2", "region": region}], 

489 ) 

490 

    def testSync(self):
        """Tests for `Database.sync`."""
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert a row with sync, because it doesn't exist yet.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10}, returning=["id"])
        self.assertTrue(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in db.query(tables.b.select())],
        )
        # Repeat that operation, which should do nothing but return the
        # requested values.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10}, returning=["id"])
        self.assertFalse(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in db.query(tables.b.select())],
        )
        # Repeat the operation without the 'extra' arg, which should also just
        # return the existing row.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, returning=["id"])
        self.assertFalse(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in db.query(tables.b.select())],
        )
        # Repeat the operation with a different value in 'extra'.  That still
        # shouldn't be an error, because 'extra' is only used if we really do
        # insert.  Also drop the 'returning' argument.
        _, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 20})
        self.assertFalse(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in db.query(tables.b.select())],
        )
        # Repeat the operation with the correct value in 'compared' instead of
        # 'extra'.
        _, inserted = db.sync(tables.b, keys={"name": "b1"}, compared={"value": 10})
        self.assertFalse(inserted)
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 10}],
            [r._asdict() for r in db.query(tables.b.select())],
        )
        # Repeat the operation with an incorrect value in 'compared'; this
        # should raise.
        with self.assertRaises(DatabaseConflictError):
            db.sync(tables.b, keys={"name": "b1"}, compared={"value": 20})
        # Try to sync in a read-only database.  This should work if and only
        # if the matching row already exists.
        with self.asReadOnly(db) as rodb:
            with rodb.declareStaticTables(create=False) as context:
                tables = context.addTableTuple(STATIC_TABLE_SPECS)
            _, inserted = rodb.sync(tables.b, keys={"name": "b1"})
            self.assertFalse(inserted)
            self.assertEqual(
                [{"id": values["id"], "name": "b1", "value": 10}],
                [r._asdict() for r in rodb.query(tables.b.select())],
            )
            with self.assertRaises(ReadOnlyDatabaseError):
                rodb.sync(tables.b, keys={"name": "b2"}, extra={"value": 20})
        # Repeat the operation with a different value in 'compared' and ask to
        # update.
        _, updated = db.sync(tables.b, keys={"name": "b1"}, compared={"value": 20}, update=True)
        self.assertEqual(updated, {"value": 10})
        self.assertEqual(
            [{"id": values["id"], "name": "b1", "value": 20}],
            [r._asdict() for r in db.query(tables.b.select())],
        )

561 

562 def testReplace(self): 

563 """Tests for `Database.replace`.""" 

564 db = self.makeEmptyDatabase(origin=1) 

565 with db.declareStaticTables(create=True) as context: 

566 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

567 # Use 'replace' to insert a single row that contains a region and 

568 # query to get it back. 

569 region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1))) 

570 row1 = {"name": "a1", "region": region} 

571 db.replace(tables.a, row1) 

572 self.assertEqual([r._asdict() for r in db.query(tables.a.select())], [row1]) 

573 # Insert another row without a region. 

574 row2 = {"name": "a2", "region": None} 

575 db.replace(tables.a, row2) 

576 self.assertCountEqual([r._asdict() for r in db.query(tables.a.select())], [row1, row2]) 

577 # Use replace to re-insert both of those rows again, which should do 

578 # nothing. 

579 db.replace(tables.a, row1, row2) 

580 self.assertCountEqual([r._asdict() for r in db.query(tables.a.select())], [row1, row2]) 

581 # Replace row1 with a row with no region, while reinserting row2. 

582 row1a = {"name": "a1", "region": None} 

583 db.replace(tables.a, row1a, row2) 

584 self.assertCountEqual([r._asdict() for r in db.query(tables.a.select())], [row1a, row2]) 

585 # Replace both rows, returning row1 to its original state, while adding 

586 # a new one. Pass them in in a different order. 

587 row2a = {"name": "a2", "region": region} 

588 row3 = {"name": "a3", "region": None} 

589 db.replace(tables.a, row3, row2a, row1) 

590 self.assertCountEqual([r._asdict() for r in db.query(tables.a.select())], [row1, row2a, row3]) 

591 

    def testEnsure(self):
        """Tests for `Database.ensure`."""
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Use 'ensure' to insert a single row that contains a region and
        # query to get it back.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row1 = {"name": "a1", "region": region}
        self.assertEqual(db.ensure(tables.a, row1), 1)
        self.assertEqual([r._asdict() for r in db.query(tables.a.select())], [row1])
        # Insert another row without a region.
        row2 = {"name": "a2", "region": None}
        self.assertEqual(db.ensure(tables.a, row2), 1)
        self.assertCountEqual([r._asdict() for r in db.query(tables.a.select())], [row1, row2])
        # Use ensure to re-insert both of those rows again, which should do
        # nothing.
        self.assertEqual(db.ensure(tables.a, row1, row2), 0)
        self.assertCountEqual([r._asdict() for r in db.query(tables.a.select())], [row1, row2])
        # Attempt to insert row1's key with no region, while
        # reinserting row2.  This should also do nothing.
        row1a = {"name": "a1", "region": None}
        self.assertEqual(db.ensure(tables.a, row1a, row2), 0)
        self.assertCountEqual([r._asdict() for r in db.query(tables.a.select())], [row1, row2])
        # Attempt to insert new rows for both existing keys, this time also
        # adding a new row.  Pass them in in a different order.  Only the new
        # row should be added.
        row2a = {"name": "a2", "region": region}
        row3 = {"name": "a3", "region": None}
        self.assertEqual(db.ensure(tables.a, row3, row2a, row1a), 1)
        self.assertCountEqual([r._asdict() for r in db.query(tables.a.select())], [row1, row2, row3])
        # Add some data to a table with both a primary key and a different
        # unique constraint.
        row_b = {"id": 5, "name": "five", "value": 50}
        db.insert(tables.b, row_b)
        # Attempt ensure with primary_key_only=False and a conflict for the
        # non-PK constraint.  This should do nothing.
        db.ensure(tables.b, {"id": 10, "name": "five", "value": 200})
        self.assertEqual([r._asdict() for r in db.query(tables.b.select())], [row_b])
        # Now use primary_key_only=True with conflict in only the non-PK field.
        # This should be an integrity error and nothing should change.
        with self.assertRaises(sqlalchemy.exc.IntegrityError):
            db.ensure(tables.b, {"id": 10, "name": "five", "value": 200}, primary_key_only=True)
        self.assertEqual([r._asdict() for r in db.query(tables.b.select())], [row_b])
        # With primary_key_only=True a conflict in the primary key is ignored
        # regardless of whether there is a conflict elsewhere.
        db.ensure(tables.b, {"id": 5, "name": "ten", "value": 100}, primary_key_only=True)
        self.assertEqual([r._asdict() for r in db.query(tables.b.select())], [row_b])
        db.ensure(tables.b, {"id": 5, "name": "five", "value": 100}, primary_key_only=True)
        self.assertEqual([r._asdict() for r in db.query(tables.b.select())], [row_b])

642 

    def testTransactionNesting(self):
        """Test that transactions can be nested with the behavior in the
        presence of exceptions working as documented.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert one row so we can trigger integrity errors by trying to insert
        # a duplicate of it below.
        db.insert(tables.a, {"name": "a1"})
        # First test: error recovery via explicit savepoint=True in the inner
        # transaction.
        with db.transaction():
            # This insert should succeed, and should not be rolled back because
            # the assertRaises context should catch any exception before it
            # propagates up to the outer transaction.
            db.insert(tables.a, {"name": "a2"})
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                with db.transaction(savepoint=True):
                    # This insert should succeed, but should be rolled back.
                    db.insert(tables.a, {"name": "a4"})
                    # This insert should fail (duplicate primary key), raising
                    # an exception.
                    db.insert(tables.a, {"name": "a1"})
        self.assertCountEqual(
            [r._asdict() for r in db.query(tables.a.select())],
            [{"name": "a1", "region": None}, {"name": "a2", "region": None}],
        )
        # Second test: error recovery via implicit savepoint=True, when the
        # innermost transaction is inside a savepoint=True transaction.
        with db.transaction():
            # This insert should succeed, and should not be rolled back
            # because the assertRaises context should catch any
            # exception before it propagates up to the outer
            # transaction.
            db.insert(tables.a, {"name": "a3"})
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                with db.transaction(savepoint=True):
                    # This insert should succeed, but should be rolled back.
                    db.insert(tables.a, {"name": "a4"})
                    with db.transaction():
                        # This insert should succeed, but should be rolled
                        # back.
                        db.insert(tables.a, {"name": "a5"})
                        # This insert should fail (duplicate primary key),
                        # raising an exception.
                        db.insert(tables.a, {"name": "a1"})
        self.assertCountEqual(
            [r._asdict() for r in db.query(tables.a.select())],
            [{"name": "a1", "region": None}, {"name": "a2", "region": None}, {"name": "a3", "region": None}],
        )

694 

    def testTransactionLocking(self):
        """Test that `Database.transaction` can be used to acquire a lock
        that prohibits concurrent writes.

        Two asyncio "sides" race against each other: side1 (optionally)
        locks table ``a`` and does SELECT / sleep / INSERT / SELECT, while
        side2 concurrently INSERTs into the same table from a second
        connection running in a worker thread.  Because the interleaving is
        driven by wall-clock sleeps, unlucky scheduling is detected and only
        warned about rather than failed on.
        """
        db1 = self.makeEmptyDatabase(origin=1)
        with db1.declareStaticTables(create=True) as context:
            tables1 = context.addTableTuple(STATIC_TABLE_SPECS)

        async def side1(lock: Iterable[str] = ()) -> Tuple[Set[str], Set[str]]:
            """One side of the concurrent locking test.

            This optionally locks the table (and maybe the whole database),
            does a select for its contents, inserts a new row, and then selects
            again, with some waiting in between to make sure the other side has
            a chance to _attempt_ to insert in between.  If the locking is
            enabled and works, the difference between the selects should just
            be the insert done on this thread.
            """
            # Give Side2 a chance to create a connection
            await asyncio.sleep(1.0)
            with db1.transaction(lock=lock):
                names1 = {row.name for row in db1.query(tables1.a.select())}
                # Give Side2 a chance to insert (which will be blocked if
                # we've acquired a lock).
                await asyncio.sleep(2.0)
                db1.insert(tables1.a, {"name": "a1"})
                names2 = {row.name for row in db1.query(tables1.a.select())}
            return names1, names2

        async def side2() -> None:
            """The other side of the concurrent locking test.

            This side just waits a bit and then tries to insert a row into the
            table that the other side is trying to lock.  Hopefully that
            waiting is enough to give the other side a chance to acquire the
            lock and thus make this side block until the lock is released.  If
            this side manages to do the insert before side1 acquires the lock,
            we'll just warn about not succeeding at testing the locking,
            because we can only make that unlikely, not impossible.
            """

            def toRunInThread():
                """SQLite locking isn't asyncio-friendly unless we actually
                run it in another thread.  And SQLite gets very unhappy if
                we try to use a connection from multiple threads, so we have
                to create the new connection here instead of out in the main
                body of the test function.
                """
                db2 = self.getNewConnection(db1, writeable=True)
                with db2.declareStaticTables(create=False) as context:
                    tables2 = context.addTableTuple(STATIC_TABLE_SPECS)
                with db2.transaction():
                    db2.insert(tables2.a, {"name": "a2"})

            # Sleep on the event loop (not in the thread) so side1 can make
            # progress before we attempt the competing INSERT.
            await asyncio.sleep(2.0)
            loop = asyncio.get_running_loop()
            with ThreadPoolExecutor() as pool:
                await loop.run_in_executor(pool, toRunInThread)

        async def testProblemsWithNoLocking() -> None:
            """Run side1 and side2 with no locking, attempting to demonstrate
            the problem that locking is supposed to solve.  If we get unlucky
            with scheduling, side2 will just happen to insert after side1 is
            done, and we won't have anything definitive.  We just warn in that
            case because we really don't want spurious test failures.
            """
            task1 = asyncio.create_task(side1())
            task2 = asyncio.create_task(side2())

            names1, names2 = await task1
            await task2
            if "a2" in names1:
                # side2's INSERT landed before side1's first SELECT; we can't
                # demonstrate anything, so just check consistency and warn.
                warnings.warn(
                    "Unlucky scheduling in no-locking test: concurrent INSERT "
                    "happened before first SELECT."
                )
                self.assertEqual(names1, {"a2"})
                self.assertEqual(names2, {"a1", "a2"})
            elif "a2" not in names2:
                # side2's INSERT landed after side1's second SELECT; also
                # inconclusive.
                warnings.warn(
                    "Unlucky scheduling in no-locking test: concurrent INSERT "
                    "happened after second SELECT even without locking."
                )
                self.assertEqual(names1, set())
                self.assertEqual(names2, {"a1"})
            else:
                # This is the expected case: both INSERTS happen between the
                # two SELECTS.  If we don't get this almost all of the time we
                # should adjust the sleep amounts.
                self.assertEqual(names1, set())
                self.assertEqual(names2, {"a1", "a2"})

        asyncio.run(testProblemsWithNoLocking())

        # Clean up after first test.
        db1.delete(tables1.a, ["name"], {"name": "a1"}, {"name": "a2"})

        async def testSolutionWithLocking() -> None:
            """Run side1 and side2 with locking, which should make side2 block
            its insert until side1 releases its lock.
            """
            task1 = asyncio.create_task(side1(lock=[tables1.a]))
            task2 = asyncio.create_task(side2())

            names1, names2 = await task1
            await task2
            if "a2" in names1:
                # side2's INSERT beat side1 to the lock; inconclusive, warn.
                warnings.warn(
                    "Unlucky scheduling in locking test: concurrent INSERT happened before first SELECT."
                )
                self.assertEqual(names1, {"a2"})
                self.assertEqual(names2, {"a1", "a2"})
            else:
                # This is the expected case: the side2 INSERT happens after the
                # last SELECT on side1.  This can also happen due to unlucky
                # scheduling, and we have no way to detect that here, but the
                # similar "no-locking" test has at least some chance of being
                # affected by the same problem and warning about it.
                self.assertEqual(names1, set())
                self.assertEqual(names2, {"a1"})

        asyncio.run(testSolutionWithLocking())

817 

    def testTimespanDatabaseRepresentation(self):
        """Tests for `TimespanDatabaseRepresentation` and the `Database`
        methods that interact with it.

        Covers round-tripping timespans (including NULLs and server-side
        defaults), exclusion constraints where supported, and SQL
        overlaps/contains/< /> relationship expressions against both
        literal and in-database timespans and instantaneous times.
        """
        # Make some test timespans to play with, with the full suite of
        # topological relationships.
        start = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
        offset = astropy.time.TimeDelta(60, format="sec")
        timestamps = [start + offset * n for n in range(3)]
        # 'a' set: unbounded, half-bounded, instantaneous, empty, and all
        # bounded combinations of the three timestamps.
        aTimespans = [Timespan(begin=None, end=None)]
        aTimespans.extend(Timespan(begin=None, end=t) for t in timestamps)
        aTimespans.extend(Timespan(begin=t, end=None) for t in timestamps)
        aTimespans.extend(Timespan.fromInstant(t) for t in timestamps)
        aTimespans.append(Timespan.makeEmpty())
        aTimespans.extend(Timespan(begin=t1, end=t2) for t1, t2 in itertools.combinations(timestamps, 2))
        # Make another list of timespans that span the full range but don't
        # overlap.  This is a subset of the previous list.
        bTimespans = [Timespan(begin=None, end=timestamps[0])]
        bTimespans.extend(Timespan(begin=t1, end=t2) for t1, t2 in zip(timestamps[:-1], timestamps[1:]))
        bTimespans.append(Timespan(begin=timestamps[-1], end=None))
        # Make a database and create a table with that database's timespan
        # representation.  This one will have no exclusion constraint and
        # a nullable timespan.
        db = self.makeEmptyDatabase(origin=1)
        TimespanReprClass = db.getTimespanRepresentation()
        aSpec = ddl.TableSpec(
            fields=[
                ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
            ],
        )
        # The representation decides how many columns a timespan needs.
        for fieldSpec in TimespanReprClass.makeFieldSpecs(nullable=True):
            aSpec.fields.add(fieldSpec)
        with db.declareStaticTables(create=True) as context:
            aTable = context.addTable("a", aSpec)
        # Show full diffs on the large assertEqual comparisons below.
        self.maxDiff = None

        def convertRowForInsert(row: dict) -> dict:
            """Convert a row containing a Timespan instance into one suitable
            for insertion into the database.

            Parameters
            ----------
            row : `dict`
                Row with a `Timespan` (or `None`) under the
                ``TimespanReprClass.NAME`` key.

            Returns
            -------
            row : `dict`
                Copy of ``row`` with the timespan expanded into the
                representation's database columns.
            """
            result = row.copy()
            ts = result.pop(TimespanReprClass.NAME)
            return TimespanReprClass.update(ts, result=result)

        def convertRowFromSelect(row: dict) -> dict:
            """Convert a row from the database into one containing a Timespan.

            Parameters
            ----------
            row : `dict`
                Original row.

            Returns
            -------
            row : `dict`
                The updated row.
            """
            result = row.copy()
            timespan = TimespanReprClass.extract(result)
            # Drop the representation's raw columns in favor of the single
            # Timespan entry.
            for name in TimespanReprClass.getFieldNames():
                del result[name]
            result[TimespanReprClass.NAME] = timespan
            return result

        # Insert rows into table A, in chunks just to make things interesting.
        # Include one with a NULL timespan.
        aRows = [{"id": n, TimespanReprClass.NAME: t} for n, t in enumerate(aTimespans)]
        aRows.append({"id": len(aRows), TimespanReprClass.NAME: None})
        db.insert(aTable, convertRowForInsert(aRows[0]))
        db.insert(aTable, *[convertRowForInsert(r) for r in aRows[1:3]])
        db.insert(aTable, *[convertRowForInsert(r) for r in aRows[3:]])
        # Add another one with a NULL timespan, but this time by invoking
        # the server-side default.
        aRows.append({"id": len(aRows)})
        db.insert(aTable, aRows[-1])
        # Patch the local copy to match what we expect to read back.
        aRows[-1][TimespanReprClass.NAME] = None
        # Test basic round-trip through database.
        self.assertEqual(
            aRows,
            [
                convertRowFromSelect(row._asdict())
                for row in db.query(aTable.select().order_by(aTable.columns.id))
            ],
        )
        # Create another table B with a not-null timespan and (if the database
        # supports it), an exclusion constraint.  Use ensureTableExists this
        # time to check that mode of table creation vs. timespans.
        bSpec = ddl.TableSpec(
            fields=[
                ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
                ddl.FieldSpec(name="key", dtype=sqlalchemy.Integer, nullable=False),
            ],
        )
        for fieldSpec in TimespanReprClass.makeFieldSpecs(nullable=False):
            bSpec.fields.add(fieldSpec)
        if TimespanReprClass.hasExclusionConstraint():
            bSpec.exclusion.add(("key", TimespanReprClass))
        bTable = db.ensureTableExists("b", bSpec)
        # Insert rows into table B, again in chunks.  Each Timespan appears
        # twice, but with different values for the 'key' field (which should
        # still be okay for any exclusion constraint we may have defined).
        bRows = [{"id": n, "key": 1, TimespanReprClass.NAME: t} for n, t in enumerate(bTimespans)]
        offset = len(bRows)
        bRows.extend(
            {"id": n + offset, "key": 2, TimespanReprClass.NAME: t} for n, t in enumerate(bTimespans)
        )
        db.insert(bTable, *[convertRowForInsert(r) for r in bRows[:2]])
        db.insert(bTable, convertRowForInsert(bRows[2]))
        db.insert(bTable, *[convertRowForInsert(r) for r in bRows[3:]])
        # Insert a row with no timespan into table B.  This should invoke the
        # server-side default, which is a timespan over (-∞, ∞).  We set
        # key=3 to avoid upsetting an exclusion constraint that might exist.
        bRows.append({"id": len(bRows), "key": 3})
        db.insert(bTable, bRows[-1])
        bRows[-1][TimespanReprClass.NAME] = Timespan(None, None)
        # Test basic round-trip through database.
        self.assertEqual(
            bRows,
            [
                convertRowFromSelect(row._asdict())
                for row in db.query(bTable.select().order_by(bTable.columns.id))
            ],
        )
        # Test that we can't insert timespan=None into this table.
        with self.assertRaises(sqlalchemy.exc.IntegrityError):
            db.insert(bTable, convertRowForInsert({"id": len(bRows), "key": 4, TimespanReprClass.NAME: None}))
        # IFF this database supports exclusion constraints, test that they
        # also prevent inserts.  All three inserts below reuse key=1 with a
        # timespan that overlaps an existing key=1 row.
        if TimespanReprClass.hasExclusionConstraint():
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                db.insert(
                    bTable,
                    convertRowForInsert(
                        {"id": len(bRows), "key": 1, TimespanReprClass.NAME: Timespan(None, timestamps[1])}
                    ),
                )
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                db.insert(
                    bTable,
                    convertRowForInsert(
                        {
                            "id": len(bRows),
                            "key": 1,
                            TimespanReprClass.NAME: Timespan(timestamps[0], timestamps[2]),
                        }
                    ),
                )
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                db.insert(
                    bTable,
                    convertRowForInsert(
                        {"id": len(bRows), "key": 1, TimespanReprClass.NAME: Timespan(timestamps[2], None)}
                    ),
                )
        # Test NULL checks in SELECT queries, on both tables.
        aRepr = TimespanReprClass.fromSelectable(aTable)
        self.assertEqual(
            [row[TimespanReprClass.NAME] is None for row in aRows],
            [
                row.f
                for row in db.query(
                    sqlalchemy.sql.select(aRepr.isNull().label("f")).order_by(aTable.columns.id)
                )
            ],
        )
        bRepr = TimespanReprClass.fromSelectable(bTable)
        # Table B's timespans are NOT NULL, so every isNull() must be False.
        self.assertEqual(
            [False for row in bRows],
            [
                row.f
                for row in db.query(
                    sqlalchemy.sql.select(bRepr.isNull().label("f")).order_by(bTable.columns.id)
                )
            ],
        )
        # Test relationships expressions that relate in-database timespans to
        # Python-literal timespans, all from the more complete 'a' set; check
        # that this is consistent with Python-only relationship tests.
        for rhsRow in aRows:
            if rhsRow[TimespanReprClass.NAME] is None:
                continue
            with self.subTest(rhsRow=rhsRow):
                # Expected results: NULL timespans yield NULL (None) for all
                # four relationship operators.
                expected = {}
                for lhsRow in aRows:
                    if lhsRow[TimespanReprClass.NAME] is None:
                        expected[lhsRow["id"]] = (None, None, None, None)
                    else:
                        expected[lhsRow["id"]] = (
                            lhsRow[TimespanReprClass.NAME].overlaps(rhsRow[TimespanReprClass.NAME]),
                            lhsRow[TimespanReprClass.NAME].contains(rhsRow[TimespanReprClass.NAME]),
                            lhsRow[TimespanReprClass.NAME] < rhsRow[TimespanReprClass.NAME],
                            lhsRow[TimespanReprClass.NAME] > rhsRow[TimespanReprClass.NAME],
                        )
                rhsRepr = TimespanReprClass.fromLiteral(rhsRow[TimespanReprClass.NAME])
                sql = sqlalchemy.sql.select(
                    aTable.columns.id.label("lhs"),
                    aRepr.overlaps(rhsRepr).label("overlaps"),
                    aRepr.contains(rhsRepr).label("contains"),
                    (aRepr < rhsRepr).label("less_than"),
                    (aRepr > rhsRepr).label("greater_than"),
                ).select_from(aTable)
                queried = {
                    row.lhs: (row.overlaps, row.contains, row.less_than, row.greater_than)
                    for row in db.query(sql)
                }
                self.assertEqual(expected, queried)
        # Test relationship expressions that relate in-database timespans to
        # each other, all from the more complete 'a' set; check that this is
        # consistent with Python-only relationship tests.
        expected = {}
        for lhs, rhs in itertools.product(aRows, aRows):
            lhsT = lhs[TimespanReprClass.NAME]
            rhsT = rhs[TimespanReprClass.NAME]
            if lhsT is not None and rhsT is not None:
                expected[lhs["id"], rhs["id"]] = (
                    lhsT.overlaps(rhsT),
                    lhsT.contains(rhsT),
                    lhsT < rhsT,
                    lhsT > rhsT,
                )
            else:
                expected[lhs["id"], rhs["id"]] = (None, None, None, None)
        # Cross-join the table with itself to compare every pair of rows.
        lhsSubquery = aTable.alias("lhs")
        rhsSubquery = aTable.alias("rhs")
        lhsRepr = TimespanReprClass.fromSelectable(lhsSubquery)
        rhsRepr = TimespanReprClass.fromSelectable(rhsSubquery)
        sql = sqlalchemy.sql.select(
            lhsSubquery.columns.id.label("lhs"),
            rhsSubquery.columns.id.label("rhs"),
            lhsRepr.overlaps(rhsRepr).label("overlaps"),
            lhsRepr.contains(rhsRepr).label("contains"),
            (lhsRepr < rhsRepr).label("less_than"),
            (lhsRepr > rhsRepr).label("greater_than"),
        ).select_from(lhsSubquery.join(rhsSubquery, onclause=sqlalchemy.sql.literal(True)))
        queried = {
            (row.lhs, row.rhs): (row.overlaps, row.contains, row.less_than, row.greater_than)
            for row in db.query(sql)
        }
        self.assertEqual(expected, queried)
        # Test relationship expressions between in-database timespans and
        # Python-literal instantaneous times.
        for t in timestamps:
            with self.subTest(t=t):
                expected = {}
                for lhsRow in aRows:
                    if lhsRow[TimespanReprClass.NAME] is None:
                        expected[lhsRow["id"]] = (None, None, None)
                    else:
                        expected[lhsRow["id"]] = (
                            lhsRow[TimespanReprClass.NAME].contains(t),
                            lhsRow[TimespanReprClass.NAME] < t,
                            lhsRow[TimespanReprClass.NAME] > t,
                        )
                # Bind the astropy Time as a literal with the nanosecond-TAI
                # column type so the comparison happens in the database.
                rhs = sqlalchemy.sql.literal(t, type_=ddl.AstropyTimeNsecTai)
                sql = sqlalchemy.sql.select(
                    aTable.columns.id.label("lhs"),
                    aRepr.contains(rhs).label("contains"),
                    (aRepr < rhs).label("less_than"),
                    (aRepr > rhs).label("greater_than"),
                ).select_from(aTable)
                queried = {row.lhs: (row.contains, row.less_than, row.greater_than) for row in db.query(sql)}
                self.assertEqual(expected, queried)