Coverage for python/lsst/daf/butler/registry/tests/_database.py: 8%

Shortcuts on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

467 statements  

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ["DatabaseTests"] 

24 

import asyncio
import itertools
import warnings
from abc import ABC, abstractmethod
from collections import namedtuple
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from typing import ContextManager, Iterable, Iterator, Optional, Set, Tuple

import astropy.time
import sqlalchemy
from lsst.sphgeom import ConvexPolygon, UnitVector3d

from ...core import Timespan, ddl
from ..interfaces import Database, DatabaseConflictError, ReadOnlyDatabaseError, SchemaAlreadyDefinedError

40 

# Lightweight record type holding one entry for each of the three static
# test tables, addressed as ``.a``, ``.b`` and ``.c``.
StaticTablesTuple = namedtuple("StaticTablesTuple", "a b c")

42 

# Specifications for the static schema used by most tests below.
STATIC_TABLE_SPECS = StaticTablesTuple(
    # Table "a": string primary key plus a region column.
    a=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("name", primaryKey=True, dtype=sqlalchemy.String, length=16),
            ddl.FieldSpec("region", nbytes=128, dtype=ddl.Base64Region),
        ]
    ),
    # Table "b": autoincrement integer primary key, with a unique constraint
    # on the (non-nullable) name and a nullable small-integer value.
    b=ddl.TableSpec(
        unique=[("name",)],
        fields=[
            ddl.FieldSpec("id", primaryKey=True, autoincrement=True, dtype=sqlalchemy.BigInteger),
            ddl.FieldSpec("name", nullable=False, dtype=sqlalchemy.String, length=16),
            ddl.FieldSpec("value", nullable=True, dtype=sqlalchemy.SmallInteger),
        ],
    ),
    # Table "c": compound (id, origin) primary key whose "id" part is
    # autoincrement; "b_id" references b.id and is set to NULL when the
    # referenced row is deleted.
    c=ddl.TableSpec(
        foreignKeys=[
            ddl.ForeignKeySpec("b", onDelete="SET NULL", source=("b_id",), target=("id",)),
        ],
        fields=[
            ddl.FieldSpec("id", primaryKey=True, autoincrement=True, dtype=sqlalchemy.BigInteger),
            ddl.FieldSpec("origin", primaryKey=True, dtype=sqlalchemy.BigInteger),
            ddl.FieldSpec("b_id", nullable=True, dtype=sqlalchemy.BigInteger),
        ],
    ),
)

69 

# Specification for the table that tests create on demand under the name "d".
# Its compound primary key references table "c" and "a_name" references table
# "a"; both foreign keys cascade deletes.
DYNAMIC_TABLE_SPEC = ddl.TableSpec(
    foreignKeys=[
        ddl.ForeignKeySpec("c", onDelete="CASCADE", source=("c_id", "c_origin"), target=("id", "origin")),
        ddl.ForeignKeySpec("a", onDelete="CASCADE", source=("a_name",), target=("name",)),
    ],
    fields=[
        ddl.FieldSpec("c_id", primaryKey=True, dtype=sqlalchemy.BigInteger),
        ddl.FieldSpec("c_origin", primaryKey=True, dtype=sqlalchemy.BigInteger),
        ddl.FieldSpec("a_name", nullable=False, dtype=sqlalchemy.String, length=16),
    ],
)

81 

# Specification for the temporary tables made in ``testTemporaryTables``: a
# compound (a_name, b_id) primary key and no foreign keys.
TEMPORARY_TABLE_SPEC = ddl.TableSpec(
    fields=[
        ddl.FieldSpec("a_name", primaryKey=True, dtype=sqlalchemy.String, length=16),
        ddl.FieldSpec("b_id", primaryKey=True, dtype=sqlalchemy.BigInteger),
    ],
)

88 

89 

90@contextmanager 

91def _patch_getExistingTable(db: Database) -> Database: 

92 """Patch getExistingTable method in a database instance to test concurrent 

93 creation of tables. This patch obviously depends on knowning internals of 

94 ``ensureTableExists()`` implementation. 

95 """ 

96 original_method = db.getExistingTable 

97 

98 def _getExistingTable(name: str, spec: ddl.TableSpec) -> Optional[sqlalchemy.schema.Table]: 

99 # Return None on first call, but forward to original method after that 

100 db.getExistingTable = original_method 

101 return None 

102 

103 db.getExistingTable = _getExistingTable 

104 yield db 

105 db.getExistingTable = original_method 

106 

107 

108class DatabaseTests(ABC): 

109 """Generic tests for the `Database` interface that can be subclassed to 

110 generate tests for concrete implementations. 

111 """ 

112 

113 @abstractmethod 

114 def makeEmptyDatabase(self, origin: int = 0) -> Database: 

115 """Return an empty `Database` with the given origin, or an 

116 automatically-generated one if ``origin`` is `None`. 

117 """ 

118 raise NotImplementedError() 

119 

120 @abstractmethod 

121 def asReadOnly(self, database: Database) -> ContextManager[Database]: 

122 """Return a context manager for a read-only connection into the given 

123 database. 

124 

125 The original database should be considered unusable within the context 

126 but safe to use again afterwards (this allows the context manager to 

127 block write access by temporarily changing user permissions to really 

128 guarantee that write operations are not performed). 

129 """ 

130 raise NotImplementedError() 

131 

132 @abstractmethod 

133 def getNewConnection(self, database: Database, *, writeable: bool) -> Database: 

134 """Return a new `Database` instance that points to the same underlying 

135 storage as the given one. 

136 """ 

137 raise NotImplementedError() 

138 

139 def checkTable(self, spec: ddl.TableSpec, table: sqlalchemy.schema.Table): 

140 self.assertCountEqual(spec.fields.names, table.columns.keys()) 

141 # Checking more than this currently seems fragile, as it might restrict 

142 # what Database implementations do; we don't care if the spec is 

143 # actually preserved in terms of types and constraints as long as we 

144 # can use the returned table as if it was. 

145 

146 def checkStaticSchema(self, tables: StaticTablesTuple): 

147 self.checkTable(STATIC_TABLE_SPECS.a, tables.a) 

148 self.checkTable(STATIC_TABLE_SPECS.b, tables.b) 

149 self.checkTable(STATIC_TABLE_SPECS.c, tables.c) 

150 

151 def testDeclareStaticTables(self): 

152 """Tests for `Database.declareStaticSchema` and the methods it 

153 delegates to. 

154 """ 

155 # Create the static schema in a new, empty database. 

156 newDatabase = self.makeEmptyDatabase() 

157 with newDatabase.declareStaticTables(create=True) as context: 

158 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

159 self.checkStaticSchema(tables) 

160 # Check that we can load that schema even from a read-only connection. 

161 with self.asReadOnly(newDatabase) as existingReadOnlyDatabase: 

162 with existingReadOnlyDatabase.declareStaticTables(create=False) as context: 

163 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

164 self.checkStaticSchema(tables) 

165 

166 def testDeclareStaticTablesTwice(self): 

167 """Tests for `Database.declareStaticSchema` being called twice.""" 

168 # Create the static schema in a new, empty database. 

169 newDatabase = self.makeEmptyDatabase() 

170 with newDatabase.declareStaticTables(create=True) as context: 

171 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

172 self.checkStaticSchema(tables) 

173 # Second time it should raise 

174 with self.assertRaises(SchemaAlreadyDefinedError): 

175 with newDatabase.declareStaticTables(create=True) as context: 

176 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

177 # Check schema, it should still contain all tables, and maybe some 

178 # extra. 

179 with newDatabase.declareStaticTables(create=False) as context: 

180 self.assertLessEqual(frozenset(STATIC_TABLE_SPECS._fields), context._tableNames) 

181 

182 def testRepr(self): 

183 """Test that repr does not return a generic thing.""" 

184 newDatabase = self.makeEmptyDatabase() 

185 rep = repr(newDatabase) 

186 # Check that stringification works and gives us something different 

187 self.assertNotEqual(rep, str(newDatabase)) 

188 self.assertNotIn("object at 0x", rep, "Check default repr was not used") 

189 self.assertIn("://", rep) 

190 

191 def testDynamicTables(self): 

192 """Tests for `Database.ensureTableExists` and 

193 `Database.getExistingTable`. 

194 """ 

195 # Need to start with the static schema. 

196 newDatabase = self.makeEmptyDatabase() 

197 with newDatabase.declareStaticTables(create=True) as context: 

198 context.addTableTuple(STATIC_TABLE_SPECS) 

199 # Try to ensure the dynamic table exists in a read-only version of that 

200 # database, which should fail because we can't create it. 

201 with self.asReadOnly(newDatabase) as existingReadOnlyDatabase: 

202 with existingReadOnlyDatabase.declareStaticTables(create=False) as context: 

203 context.addTableTuple(STATIC_TABLE_SPECS) 

204 with self.assertRaises(ReadOnlyDatabaseError): 

205 existingReadOnlyDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC) 

206 # Just getting the dynamic table before it exists should return None. 

207 self.assertIsNone(newDatabase.getExistingTable("d", DYNAMIC_TABLE_SPEC)) 

208 # Ensure the new table exists back in the original database, which 

209 # should create it. 

210 table = newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC) 

211 self.checkTable(DYNAMIC_TABLE_SPEC, table) 

212 # Ensuring that it exists should just return the exact same table 

213 # instance again. 

214 self.assertIs(newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC), table) 

215 # Try again from the read-only database. 

216 with self.asReadOnly(newDatabase) as existingReadOnlyDatabase: 

217 with existingReadOnlyDatabase.declareStaticTables(create=False) as context: 

218 context.addTableTuple(STATIC_TABLE_SPECS) 

219 # Just getting the dynamic table should now work... 

220 self.assertIsNotNone(existingReadOnlyDatabase.getExistingTable("d", DYNAMIC_TABLE_SPEC)) 

221 # ...as should ensuring that it exists, since it now does. 

222 existingReadOnlyDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC) 

223 self.checkTable(DYNAMIC_TABLE_SPEC, table) 

224 # Trying to get the table with a different specification (at least 

225 # in terms of what columns are present) should raise. 

226 with self.assertRaises(DatabaseConflictError): 

227 newDatabase.ensureTableExists( 

228 "d", 

229 ddl.TableSpec( 

230 fields=[ddl.FieldSpec("name", dtype=sqlalchemy.String, length=4, primaryKey=True)] 

231 ), 

232 ) 

233 # Calling ensureTableExists inside a transaction block is an error, 

234 # even if it would do nothing. 

235 with newDatabase.transaction(): 

236 with self.assertRaises(AssertionError): 

237 newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC) 

238 

239 def testDynamicTablesConcurrency(self): 

240 """Tests for `Database.ensureTableExists` concurrent use.""" 

241 # We cannot really run things concurrently in a deterministic way, here 

242 # we just simulate a situation when the table is created by other 

243 # process between the call to getExistingTable() and actual table 

244 # creation. 

245 db1 = self.makeEmptyDatabase() 

246 with db1.declareStaticTables(create=True) as context: 

247 context.addTableTuple(STATIC_TABLE_SPECS) 

248 self.assertIsNone(db1.getExistingTable("d", DYNAMIC_TABLE_SPEC)) 

249 

250 # Make a dynamic table using separate connection 

251 db2 = self.getNewConnection(db1, writeable=True) 

252 with db2.declareStaticTables(create=False) as context: 

253 context.addTableTuple(STATIC_TABLE_SPECS) 

254 table = db2.ensureTableExists("d", DYNAMIC_TABLE_SPEC) 

255 self.checkTable(DYNAMIC_TABLE_SPEC, table) 

256 

257 # Call it again but trick it into thinking that table is not there. 

258 # This test depends on knowing implementation of ensureTableExists() 

259 # which initially calls getExistingTable() to check that table may 

260 # exist, the patch intercepts that call and returns None. 

261 with _patch_getExistingTable(db1): 

262 table = db1.ensureTableExists("d", DYNAMIC_TABLE_SPEC) 

263 

264 def testTemporaryTables(self): 

265 """Tests for `Database.makeTemporaryTable`, 

266 `Database.dropTemporaryTable`, and `Database.insert` with 

267 the ``select`` argument. 

268 """ 

269 # Need to start with the static schema; also insert some test data. 

270 newDatabase = self.makeEmptyDatabase() 

271 with newDatabase.declareStaticTables(create=True) as context: 

272 static = context.addTableTuple(STATIC_TABLE_SPECS) 

273 newDatabase.insert(static.a, {"name": "a1", "region": None}, {"name": "a2", "region": None}) 

274 bIds = newDatabase.insert( 

275 static.b, 

276 {"name": "b1", "value": 11}, 

277 {"name": "b2", "value": 12}, 

278 {"name": "b3", "value": 13}, 

279 returnIds=True, 

280 ) 

281 # Create the table. 

282 with newDatabase.session() as session: 

283 table1 = session.makeTemporaryTable(TEMPORARY_TABLE_SPEC, "e1") 

284 self.checkTable(TEMPORARY_TABLE_SPEC, table1) 

285 # Insert via a INSERT INTO ... SELECT query. 

286 newDatabase.insert( 

287 table1, 

288 select=sqlalchemy.sql.select( 

289 static.a.columns.name.label("a_name"), static.b.columns.id.label("b_id") 

290 ) 

291 .select_from(static.a.join(static.b, onclause=sqlalchemy.sql.literal(True))) 

292 .where( 

293 sqlalchemy.sql.and_( 

294 static.a.columns.name == "a1", 

295 static.b.columns.value <= 12, 

296 ) 

297 ), 

298 ) 

299 # Check that the inserted rows are present. 

300 self.assertCountEqual( 

301 [{"a_name": "a1", "b_id": bId} for bId in bIds[:2]], 

302 [row._asdict() for row in newDatabase.query(table1.select())], 

303 ) 

304 # Create another one via a read-only connection to the database. 

305 # We _do_ allow temporary table modifications in read-only 

306 # databases. 

307 with self.asReadOnly(newDatabase) as existingReadOnlyDatabase: 

308 with existingReadOnlyDatabase.declareStaticTables(create=False) as context: 

309 context.addTableTuple(STATIC_TABLE_SPECS) 

310 with existingReadOnlyDatabase.session() as session2: 

311 table2 = session2.makeTemporaryTable(TEMPORARY_TABLE_SPEC) 

312 self.checkTable(TEMPORARY_TABLE_SPEC, table2) 

313 # Those tables should not be the same, despite having the 

314 # same ddl. 

315 self.assertIsNot(table1, table2) 

316 # Do a slightly different insert into this table, to check 

317 # that it works in a read-only database. This time we 

318 # pass column names as a kwarg to insert instead of by 

319 # labeling the columns in the select. 

320 existingReadOnlyDatabase.insert( 

321 table2, 

322 select=sqlalchemy.sql.select(static.a.columns.name, static.b.columns.id) 

323 .select_from(static.a.join(static.b, onclause=sqlalchemy.sql.literal(True))) 

324 .where( 

325 sqlalchemy.sql.and_( 

326 static.a.columns.name == "a2", 

327 static.b.columns.value >= 12, 

328 ) 

329 ), 

330 names=["a_name", "b_id"], 

331 ) 

332 # Check that the inserted rows are present. 

333 self.assertCountEqual( 

334 [{"a_name": "a2", "b_id": bId} for bId in bIds[1:]], 

335 [row._asdict() for row in existingReadOnlyDatabase.query(table2.select())], 

336 ) 

337 # Drop the temporary table from the read-only DB. It's 

338 # unspecified whether attempting to use it after this 

339 # point is an error or just never returns any results, so 

340 # we can't test what it does, only that it's not an error. 

341 session2.dropTemporaryTable(table2) 

342 # Drop the original temporary table. 

343 session.dropTemporaryTable(table1) 

344 

345 def testSchemaSeparation(self): 

346 """Test that creating two different `Database` instances allows us 

347 to create different tables with the same name in each. 

348 """ 

349 db1 = self.makeEmptyDatabase(origin=1) 

350 with db1.declareStaticTables(create=True) as context: 

351 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

352 self.checkStaticSchema(tables) 

353 

354 db2 = self.makeEmptyDatabase(origin=2) 

355 # Make the DDL here intentionally different so we'll definitely 

356 # notice if db1 and db2 are pointing at the same schema. 

357 spec = ddl.TableSpec(fields=[ddl.FieldSpec("id", dtype=sqlalchemy.Integer, primaryKey=True)]) 

358 with db2.declareStaticTables(create=True) as context: 

359 # Make the DDL here intentionally different so we'll definitely 

360 # notice if db1 and db2 are pointing at the same schema. 

361 table = context.addTable("a", spec) 

362 self.checkTable(spec, table) 

363 

364 def testInsertQueryDelete(self): 

365 """Test the `Database.insert`, `Database.query`, and `Database.delete` 

366 methods, as well as the `Base64Region` type and the ``onDelete`` 

367 argument to `ddl.ForeignKeySpec`. 

368 """ 

369 db = self.makeEmptyDatabase(origin=1) 

370 with db.declareStaticTables(create=True) as context: 

371 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

372 # Insert a single, non-autoincrement row that contains a region and 

373 # query to get it back. 

374 region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1))) 

375 row = {"name": "a1", "region": region} 

376 db.insert(tables.a, row) 

377 self.assertEqual([r._asdict() for r in db.query(tables.a.select())], [row]) 

378 # Insert multiple autoincrement rows but do not try to get the IDs 

379 # back immediately. 

380 db.insert(tables.b, {"name": "b1", "value": 10}, {"name": "b2", "value": 20}) 

381 results = [r._asdict() for r in db.query(tables.b.select().order_by("id"))] 

382 self.assertEqual(len(results), 2) 

383 for row in results: 

384 self.assertIn(row["name"], ("b1", "b2")) 

385 self.assertIsInstance(row["id"], int) 

386 self.assertGreater(results[1]["id"], results[0]["id"]) 

387 # Insert multiple autoincrement rows and get the IDs back from insert. 

388 rows = [{"name": "b3", "value": 30}, {"name": "b4", "value": 40}] 

389 ids = db.insert(tables.b, *rows, returnIds=True) 

390 results = [ 

391 r._asdict() for r in db.query(tables.b.select().where(tables.b.columns.id > results[1]["id"])) 

392 ] 

393 expected = [dict(row, id=id) for row, id in zip(rows, ids)] 

394 self.assertCountEqual(results, expected) 

395 self.assertTrue(all(result["id"] is not None for result in results)) 

396 # Insert multiple rows into a table with an autoincrement+origin 

397 # primary key, then use the returned IDs to insert into a dynamic 

398 # table. 

399 rows = [{"origin": db.origin, "b_id": results[0]["id"]}, {"origin": db.origin, "b_id": None}] 

400 ids = db.insert(tables.c, *rows, returnIds=True) 

401 results = [r._asdict() for r in db.query(tables.c.select())] 

402 expected = [dict(row, id=id) for row, id in zip(rows, ids)] 

403 self.assertCountEqual(results, expected) 

404 self.assertTrue(all(result["id"] is not None for result in results)) 

405 # Add the dynamic table. 

406 d = db.ensureTableExists("d", DYNAMIC_TABLE_SPEC) 

407 # Insert into it. 

408 rows = [{"c_origin": db.origin, "c_id": id, "a_name": "a1"} for id in ids] 

409 db.insert(d, *rows) 

410 results = [r._asdict() for r in db.query(d.select())] 

411 self.assertCountEqual(rows, results) 

412 # Insert multiple rows into a table with an autoincrement+origin 

413 # primary key (this is especially tricky for SQLite, but good to test 

414 # for all DBs), but pass in a value for the autoincrement key. 

415 # For extra complexity, we re-use the autoincrement value with a 

416 # different value for origin. 

417 rows2 = [ 

418 {"id": 700, "origin": db.origin, "b_id": None}, 

419 {"id": 700, "origin": 60, "b_id": None}, 

420 {"id": 1, "origin": 60, "b_id": None}, 

421 ] 

422 db.insert(tables.c, *rows2) 

423 results = [r._asdict() for r in db.query(tables.c.select())] 

424 self.assertCountEqual(results, expected + rows2) 

425 self.assertTrue(all(result["id"] is not None for result in results)) 

426 

427 # Define 'SELECT COUNT(*)' query for later use. 

428 count = sqlalchemy.sql.select(sqlalchemy.sql.func.count()) 

429 # Get the values we inserted into table b. 

430 bValues = [r._asdict() for r in db.query(tables.b.select())] 

431 # Remove two row from table b by ID. 

432 n = db.delete(tables.b, ["id"], {"id": bValues[0]["id"]}, {"id": bValues[1]["id"]}) 

433 self.assertEqual(n, 2) 

434 # Remove the other two rows from table b by name. 

435 n = db.delete(tables.b, ["name"], {"name": bValues[2]["name"]}, {"name": bValues[3]["name"]}) 

436 self.assertEqual(n, 2) 

437 # There should now be no rows in table b. 

438 self.assertEqual(db.query(count.select_from(tables.b)).scalar(), 0) 

439 # All b_id values in table c should now be NULL, because there's an 

440 # onDelete='SET NULL' foreign key. 

441 self.assertEqual( 

442 db.query(count.select_from(tables.c).where(tables.c.columns.b_id != None)).scalar(), # noqa:E711 

443 0, 

444 ) 

445 # Remove all rows in table a (there's only one); this should remove all 

446 # rows in d due to onDelete='CASCADE'. 

447 n = db.delete(tables.a, []) 

448 self.assertEqual(n, 1) 

449 self.assertEqual(db.query(count.select_from(tables.a)).scalar(), 0) 

450 self.assertEqual(db.query(count.select_from(d)).scalar(), 0) 

451 

452 def testDeleteWhere(self): 

453 """Tests for `Database.deleteWhere`.""" 

454 db = self.makeEmptyDatabase(origin=1) 

455 with db.declareStaticTables(create=True) as context: 

456 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

457 db.insert(tables.b, *[{"id": i, "name": f"b{i}"} for i in range(10)]) 

458 count = sqlalchemy.sql.select(sqlalchemy.sql.func.count()) 

459 

460 n = db.deleteWhere(tables.b, tables.b.columns.id.in_([0, 1, 2])) 

461 self.assertEqual(n, 3) 

462 self.assertEqual(db.query(count.select_from(tables.b)).scalar(), 7) 

463 

464 n = db.deleteWhere( 

465 tables.b, 

466 tables.b.columns.id.in_( 

467 sqlalchemy.sql.select(tables.b.columns.id).where(tables.b.columns.id > 5) 

468 ), 

469 ) 

470 self.assertEqual(n, 4) 

471 self.assertEqual(db.query(count.select_from(tables.b)).scalar(), 3) 

472 

473 n = db.deleteWhere(tables.b, tables.b.columns.name == "b5") 

474 self.assertEqual(n, 1) 

475 self.assertEqual(db.query(count.select_from(tables.b)).scalar(), 2) 

476 

477 n = db.deleteWhere(tables.b, sqlalchemy.sql.literal(True)) 

478 self.assertEqual(n, 2) 

479 self.assertEqual(db.query(count.select_from(tables.b)).scalar(), 0) 

480 

481 def testUpdate(self): 

482 """Tests for `Database.update`.""" 

483 db = self.makeEmptyDatabase(origin=1) 

484 with db.declareStaticTables(create=True) as context: 

485 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

486 # Insert two rows into table a, both without regions. 

487 db.insert(tables.a, {"name": "a1"}, {"name": "a2"}) 

488 # Update one of the rows with a region. 

489 region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1))) 

490 n = db.update(tables.a, {"name": "k"}, {"k": "a2", "region": region}) 

491 self.assertEqual(n, 1) 

492 sql = sqlalchemy.sql.select(tables.a.columns.name, tables.a.columns.region).select_from(tables.a) 

493 self.assertCountEqual( 

494 [r._asdict() for r in db.query(sql)], 

495 [{"name": "a1", "region": None}, {"name": "a2", "region": region}], 

496 ) 

497 

498 def testSync(self): 

499 """Tests for `Database.sync`.""" 

500 db = self.makeEmptyDatabase(origin=1) 

501 with db.declareStaticTables(create=True) as context: 

502 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

503 # Insert a row with sync, because it doesn't exist yet. 

504 values, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10}, returning=["id"]) 

505 self.assertTrue(inserted) 

506 self.assertEqual( 

507 [{"id": values["id"], "name": "b1", "value": 10}], 

508 [r._asdict() for r in db.query(tables.b.select())], 

509 ) 

510 # Repeat that operation, which should do nothing but return the 

511 # requested values. 

512 values, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10}, returning=["id"]) 

513 self.assertFalse(inserted) 

514 self.assertEqual( 

515 [{"id": values["id"], "name": "b1", "value": 10}], 

516 [r._asdict() for r in db.query(tables.b.select())], 

517 ) 

518 # Repeat the operation without the 'extra' arg, which should also just 

519 # return the existing row. 

520 values, inserted = db.sync(tables.b, keys={"name": "b1"}, returning=["id"]) 

521 self.assertFalse(inserted) 

522 self.assertEqual( 

523 [{"id": values["id"], "name": "b1", "value": 10}], 

524 [r._asdict() for r in db.query(tables.b.select())], 

525 ) 

526 # Repeat the operation with a different value in 'extra'. That still 

527 # shouldn't be an error, because 'extra' is only used if we really do 

528 # insert. Also drop the 'returning' argument. 

529 _, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 20}) 

530 self.assertFalse(inserted) 

531 self.assertEqual( 

532 [{"id": values["id"], "name": "b1", "value": 10}], 

533 [r._asdict() for r in db.query(tables.b.select())], 

534 ) 

535 # Repeat the operation with the correct value in 'compared' instead of 

536 # 'extra'. 

537 _, inserted = db.sync(tables.b, keys={"name": "b1"}, compared={"value": 10}) 

538 self.assertFalse(inserted) 

539 self.assertEqual( 

540 [{"id": values["id"], "name": "b1", "value": 10}], 

541 [r._asdict() for r in db.query(tables.b.select())], 

542 ) 

543 # Repeat the operation with an incorrect value in 'compared'; this 

544 # should raise. 

545 with self.assertRaises(DatabaseConflictError): 

546 db.sync(tables.b, keys={"name": "b1"}, compared={"value": 20}) 

547 # Try to sync in a read-only database. This should work if and only 

548 # if the matching row already exists. 

549 with self.asReadOnly(db) as rodb: 

550 with rodb.declareStaticTables(create=False) as context: 

551 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

552 _, inserted = rodb.sync(tables.b, keys={"name": "b1"}) 

553 self.assertFalse(inserted) 

554 self.assertEqual( 

555 [{"id": values["id"], "name": "b1", "value": 10}], 

556 [r._asdict() for r in rodb.query(tables.b.select())], 

557 ) 

558 with self.assertRaises(ReadOnlyDatabaseError): 

559 rodb.sync(tables.b, keys={"name": "b2"}, extra={"value": 20}) 

560 # Repeat the operation with a different value in 'compared' and ask to 

561 # update. 

562 _, updated = db.sync(tables.b, keys={"name": "b1"}, compared={"value": 20}, update=True) 

563 self.assertEqual(updated, {"value": 10}) 

564 self.assertEqual( 

565 [{"id": values["id"], "name": "b1", "value": 20}], 

566 [r._asdict() for r in db.query(tables.b.select())], 

567 ) 

568 

569 def testReplace(self): 

570 """Tests for `Database.replace`.""" 

571 db = self.makeEmptyDatabase(origin=1) 

572 with db.declareStaticTables(create=True) as context: 

573 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

574 # Use 'replace' to insert a single row that contains a region and 

575 # query to get it back. 

576 region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1))) 

577 row1 = {"name": "a1", "region": region} 

578 db.replace(tables.a, row1) 

579 self.assertEqual([r._asdict() for r in db.query(tables.a.select())], [row1]) 

580 # Insert another row without a region. 

581 row2 = {"name": "a2", "region": None} 

582 db.replace(tables.a, row2) 

583 self.assertCountEqual([r._asdict() for r in db.query(tables.a.select())], [row1, row2]) 

584 # Use replace to re-insert both of those rows again, which should do 

585 # nothing. 

586 db.replace(tables.a, row1, row2) 

587 self.assertCountEqual([r._asdict() for r in db.query(tables.a.select())], [row1, row2]) 

588 # Replace row1 with a row with no region, while reinserting row2. 

589 row1a = {"name": "a1", "region": None} 

590 db.replace(tables.a, row1a, row2) 

591 self.assertCountEqual([r._asdict() for r in db.query(tables.a.select())], [row1a, row2]) 

592 # Replace both rows, returning row1 to its original state, while adding 

593 # a new one. Pass them in in a different order. 

594 row2a = {"name": "a2", "region": region} 

595 row3 = {"name": "a3", "region": None} 

596 db.replace(tables.a, row3, row2a, row1) 

597 self.assertCountEqual([r._asdict() for r in db.query(tables.a.select())], [row1, row2a, row3]) 

598 

599 def testEnsure(self): 

600 """Tests for `Database.ensure`.""" 

601 db = self.makeEmptyDatabase(origin=1) 

602 with db.declareStaticTables(create=True) as context: 

603 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

604 # Use 'ensure' to insert a single row that contains a region and 

605 # query to get it back. 

606 region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1))) 

607 row1 = {"name": "a1", "region": region} 

608 self.assertEqual(db.ensure(tables.a, row1), 1) 

609 self.assertEqual([r._asdict() for r in db.query(tables.a.select())], [row1]) 

610 # Insert another row without a region. 

611 row2 = {"name": "a2", "region": None} 

612 self.assertEqual(db.ensure(tables.a, row2), 1) 

613 self.assertCountEqual([r._asdict() for r in db.query(tables.a.select())], [row1, row2]) 

614 # Use ensure to re-insert both of those rows again, which should do 

615 # nothing. 

616 self.assertEqual(db.ensure(tables.a, row1, row2), 0) 

617 self.assertCountEqual([r._asdict() for r in db.query(tables.a.select())], [row1, row2]) 

618 # Attempt to insert row1's key with no region, while 

619 # reinserting row2. This should also do nothing. 

620 row1a = {"name": "a1", "region": None} 

621 self.assertEqual(db.ensure(tables.a, row1a, row2), 0) 

622 self.assertCountEqual([r._asdict() for r in db.query(tables.a.select())], [row1, row2]) 

623 # Attempt to insert new rows for both existing keys, this time also 

624 # adding a new row. Pass them in in a different order. Only the new 

625 # row should be added. 

626 row2a = {"name": "a2", "region": region} 

627 row3 = {"name": "a3", "region": None} 

628 self.assertEqual(db.ensure(tables.a, row3, row2a, row1a), 1) 

629 self.assertCountEqual([r._asdict() for r in db.query(tables.a.select())], [row1, row2, row3]) 

630 

def testTransactionNesting(self):
    """Check nested transactions and their rollback behavior on errors.

    A failing inner ``savepoint=True`` transaction must undo only its own
    inserts (and those of transactions nested inside it), while inserts
    made earlier in the enclosing transaction are kept because the
    exception is caught before it reaches that enclosing transaction.
    """
    database = self.makeEmptyDatabase(origin=1)
    with database.declareStaticTables(create=True) as ctx:
        static = ctx.addTableTuple(STATIC_TABLE_SPECS)
    table_a = static.a

    def put(name):
        """Insert a row with the given name into table ``a``."""
        database.insert(table_a, {"name": name})

    def stored():
        """Return all rows currently in table ``a`` as dictionaries."""
        return [r._asdict() for r in database.query(table_a.select())]

    def rows_named(*names):
        """Build the expected row dictionaries for the given names."""
        return [{"name": n, "region": None} for n in names]

    # Seed one row so that later inserts can collide with its primary key
    # and trigger integrity errors.
    put("a1")

    # Scenario 1: recovery when the inner transaction explicitly requests
    # savepoint=True.
    with database.transaction():
        # Kept: assertRaises swallows the exception before it can reach
        # this outer transaction.
        put("a2")
        with self.assertRaises(sqlalchemy.exc.IntegrityError), database.transaction(savepoint=True):
            # Succeeds at first, but is rolled back with the savepoint.
            put("a4")
            # Duplicate primary key: this is the insert that raises.
            put("a1")
    self.assertCountEqual(stored(), rows_named("a1", "a2"))

    # Scenario 2: recovery via an implicit savepoint, where the innermost
    # transaction sits inside a savepoint=True transaction.
    with database.transaction():
        # Kept, for the same reason as "a2" above.
        put("a3")
        with self.assertRaises(sqlalchemy.exc.IntegrityError), database.transaction(savepoint=True):
            # Succeeds at first, but is rolled back.
            put("a4")
            with database.transaction():
                # Succeeds at first, but is rolled back as well.
                put("a5")
                # Duplicate primary key: this is the insert that raises.
                put("a1")
    self.assertCountEqual(stored(), rows_named("a1", "a2", "a3"))

682 

def testTransactionLocking(self):
    """Test that `Database.transaction` can be used to acquire a lock
    that prohibits concurrent writes.

    Two scenarios are run one after the other, each pairing a "side1"
    coroutine (SELECT, wait, INSERT, SELECT inside one transaction) with a
    "side2" coroutine that inserts from a second connection in a worker
    thread: first without a lock, then with table ``a`` locked by side1.
    Because the outcome depends on scheduling, unexpected orderings only
    produce warnings rather than failures.
    """
    db1 = self.makeEmptyDatabase(origin=1)
    with db1.declareStaticTables(create=True) as context:
        tables1 = context.addTableTuple(STATIC_TABLE_SPECS)

    async def side1(lock: Iterable[sqlalchemy.schema.Table] = ()) -> Tuple[Set[str], Set[str]]:
        """One side of the concurrent locking test.

        This optionally locks the table (and maybe the whole database),
        does a select for its contents, inserts a new row, and then selects
        again, with some waiting in between to make sure the other side has
        a chance to _attempt_ to insert in between.  If the locking is
        enabled and works, the difference between the selects should just
        be the insert done on this thread.

        Parameters
        ----------
        lock : `Iterable` [ `sqlalchemy.schema.Table` ], optional
            Tables to lock for the duration of the transaction; passed
            straight through to `Database.transaction`.  Empty (the
            default) means no locking.

        Returns
        -------
        names1 : `set` [ `str` ]
            Names found by the SELECT issued before this side's INSERT.
        names2 : `set` [ `str` ]
            Names found by the SELECT issued after this side's INSERT.
        """
        # Start the transaction one second in, i.e. before side2 (which
        # sleeps for two seconds) attempts its insert.
        await asyncio.sleep(1.0)
        with db1.transaction(lock=lock):
            names1 = {row.name for row in db1.query(tables1.a.select())}
            # Give Side2 a chance to insert (which will be blocked if
            # we've acquired a lock).
            await asyncio.sleep(2.0)
            db1.insert(tables1.a, {"name": "a1"})
            names2 = {row.name for row in db1.query(tables1.a.select())}
        return names1, names2

    async def side2() -> None:
        """The other side of the concurrent locking test.

        This side just waits a bit and then tries to insert a row into the
        table that the other side is trying to lock.  Hopefully that
        waiting is enough to give the other side a chance to acquire the
        lock and thus make this side block until the lock is released.  If
        this side manages to do the insert before side1 acquires the lock,
        we'll just warn about not succeeding at testing the locking,
        because we can only make that unlikely, not impossible.
        """

        def toRunInThread() -> None:
            """SQLite locking isn't asyncio-friendly unless we actually
            run it in another thread.  And SQLite gets very unhappy if
            we try to use a connection from multiple threads, so we have
            to create the new connection here instead of out in the main
            body of the test function.
            """
            db2 = self.getNewConnection(db1, writeable=True)
            with db2.declareStaticTables(create=False) as context:
                tables2 = context.addTableTuple(STATIC_TABLE_SPECS)
            with db2.transaction():
                db2.insert(tables2.a, {"name": "a2"})

        await asyncio.sleep(2.0)
        loop = asyncio.get_running_loop()
        # Run the blocking insert in a worker thread so that the event
        # loop (and hence side1) keeps running while it waits.
        with ThreadPoolExecutor() as pool:
            await loop.run_in_executor(pool, toRunInThread)

    async def testProblemsWithNoLocking() -> None:
        """Run side1 and side2 with no locking, attempting to demonstrate
        the problem that locking is supposed to solve.  If we get unlucky
        with scheduling, side2 will just happen to insert after side1 is
        done, and we won't have anything definitive.  We just warn in that
        case because we really don't want spurious test failures.
        """
        task1 = asyncio.create_task(side1())
        task2 = asyncio.create_task(side2())

        names1, names2 = await task1
        await task2
        if "a2" in names1:
            warnings.warn(
                "Unlucky scheduling in no-locking test: concurrent INSERT "
                "happened before first SELECT."
            )
            self.assertEqual(names1, {"a2"})
            self.assertEqual(names2, {"a1", "a2"})
        elif "a2" not in names2:
            warnings.warn(
                "Unlucky scheduling in no-locking test: concurrent INSERT "
                "happened after second SELECT even without locking."
            )
            self.assertEqual(names1, set())
            self.assertEqual(names2, {"a1"})
        else:
            # This is the expected case: both INSERTS happen between the
            # two SELECTS.  If we don't get this almost all of the time we
            # should adjust the sleep amounts.
            self.assertEqual(names1, set())
            self.assertEqual(names2, {"a1", "a2"})

    asyncio.run(testProblemsWithNoLocking())

    # Clean up after first test.
    db1.delete(tables1.a, ["name"], {"name": "a1"}, {"name": "a2"})

    async def testSolutionWithLocking() -> None:
        """Run side1 and side2 with locking, which should make side2 block
        its insert until side1 releases its lock.
        """
        task1 = asyncio.create_task(side1(lock=[tables1.a]))
        task2 = asyncio.create_task(side2())

        names1, names2 = await task1
        await task2
        if "a2" in names1:
            warnings.warn(
                "Unlucky scheduling in locking test: concurrent INSERT happened before first SELECT."
            )
            self.assertEqual(names1, {"a2"})
            self.assertEqual(names2, {"a1", "a2"})
        else:
            # This is the expected case: the side2 INSERT happens after the
            # last SELECT on side1.  This can also happen due to unlucky
            # scheduling, and we have no way to detect that here, but the
            # similar "no-locking" test has at least some chance of being
            # affected by the same problem and warning about it.
            self.assertEqual(names1, set())
            self.assertEqual(names2, {"a1"})

    asyncio.run(testSolutionWithLocking())

805 

def testTimespanDatabaseRepresentation(self):
    """Exercise `TimespanDatabaseRepresentation` together with the
    `Database` methods that work with it.
    """
    # Build test timespans covering the full suite of topological
    # relationships.
    t0 = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
    step = astropy.time.TimeDelta(60, format="sec")
    timestamps = [t0 + step * n for n in range(3)]
    aTimespans = [
        Timespan(begin=None, end=None),
        *(Timespan(begin=None, end=t) for t in timestamps),
        *(Timespan(begin=t, end=None) for t in timestamps),
        *(Timespan.fromInstant(t) for t in timestamps),
        Timespan.makeEmpty(),
        *(Timespan(begin=t1, end=t2) for t1, t2 in itertools.combinations(timestamps, 2)),
    ]
    # A second list that spans the full range without overlaps; it is a
    # subset of the list above.
    bTimespans = [
        Timespan(begin=None, end=timestamps[0]),
        *(Timespan(begin=t1, end=t2) for t1, t2 in zip(timestamps[:-1], timestamps[1:])),
        Timespan(begin=timestamps[-1], end=None),
    ]
    # Make a database and a table using that database's timespan
    # representation: no exclusion constraint, nullable timespan.
    db = self.makeEmptyDatabase(origin=1)
    reprCls = db.getTimespanRepresentation()
    spanKey = reprCls.NAME
    aSpec = ddl.TableSpec(
        fields=[ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True)],
    )
    for fieldSpec in reprCls.makeFieldSpecs(nullable=True):
        aSpec.fields.add(fieldSpec)
    with db.declareStaticTables(create=True) as context:
        aTable = context.addTable("a", aSpec)
    self.maxDiff = None

    def toDbRow(row: dict) -> dict:
        """Return a copy of ``row`` with its Timespan entry replaced by
        whatever the representation class writes for insertion.
        """
        flat = row.copy()
        span = flat.pop(spanKey)
        return reprCls.update(span, result=flat)

    def fromDbRow(row: dict) -> dict:
        """Return a copy of a selected ``row`` with the representation's
        own fields replaced by a single extracted Timespan entry.
        """
        unpacked = row.copy()
        span = reprCls.extract(unpacked)
        for column in reprCls.getFieldNames():
            del unpacked[column]
        unpacked[spanKey] = span
        return unpacked

    def fetchById(table) -> list:
        """Select all rows of ``table`` ordered by id, converted back to
        rows holding Timespan instances.
        """
        ordered = table.select().order_by(table.columns.id)
        return [fromDbRow(r._asdict()) for r in db.query(ordered)]

    def nullFlags(tsRepr, table) -> list:
        """Query the representation's NULL check for every row of
        ``table``, ordered by id.
        """
        query = sqlalchemy.sql.select(tsRepr.isNull().label("f")).order_by(table.columns.id)
        return [r.f for r in db.query(query)]

    def spanRelations(lhs, rhs) -> tuple:
        """Compute (overlaps, contains, <, >) in Python, or a tuple of
        `None` when either operand is `None`.
        """
        if lhs is None or rhs is None:
            return (None, None, None, None)
        return (lhs.overlaps(rhs), lhs.contains(rhs), lhs < rhs, lhs > rhs)

    # Insert rows into table A in chunks, including one with a NULL
    # timespan.
    aRows = [{"id": n, spanKey: t} for n, t in enumerate(aTimespans)]
    aRows.append({"id": len(aRows), spanKey: None})
    db.insert(aTable, toDbRow(aRows[0]))
    db.insert(aTable, *[toDbRow(r) for r in aRows[1:3]])
    db.insert(aTable, *[toDbRow(r) for r in aRows[3:]])
    # One more NULL timespan, this time obtained from the server-side
    # default by leaving the timespan out of the insert.
    defaulted = {"id": len(aRows)}
    aRows.append(defaulted)
    db.insert(aTable, defaulted)
    defaulted[spanKey] = None
    # Basic round-trip through the database.
    self.assertEqual(aRows, fetchById(aTable))

    # Table B has a not-null timespan and, if the database supports it,
    # an exclusion constraint.  ensureTableExists is used here to check
    # that mode of table creation with timespans.
    bSpec = ddl.TableSpec(
        fields=[
            ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
            ddl.FieldSpec(name="key", dtype=sqlalchemy.Integer, nullable=False),
        ],
    )
    for fieldSpec in reprCls.makeFieldSpecs(nullable=False):
        bSpec.fields.add(fieldSpec)
    if reprCls.hasExclusionConstraint():
        bSpec.exclusion.add(("key", reprCls))
    bTable = db.ensureTableExists("b", bSpec)
    # Insert rows into table B, again in chunks.  Each Timespan appears
    # twice with different 'key' values, which any exclusion constraint
    # should still accept.
    bRows = [{"id": n, "key": 1, spanKey: t} for n, t in enumerate(bTimespans)]
    idOffset = len(bRows)
    bRows.extend({"id": n + idOffset, "key": 2, spanKey: t} for n, t in enumerate(bTimespans))
    db.insert(bTable, *[toDbRow(r) for r in bRows[:2]])
    db.insert(bTable, toDbRow(bRows[2]))
    db.insert(bTable, *[toDbRow(r) for r in bRows[3:]])
    # A row without a timespan should pick up the server-side default,
    # a timespan over (-∞, ∞).  key=3 avoids upsetting an exclusion
    # constraint that might exist.
    unbounded = {"id": len(bRows), "key": 3}
    bRows.append(unbounded)
    db.insert(bTable, unbounded)
    unbounded[spanKey] = Timespan(None, None)
    # Basic round-trip through the database.
    self.assertEqual(bRows, fetchById(bTable))
    # timespan=None must be rejected by this table.
    with self.assertRaises(sqlalchemy.exc.IntegrityError):
        db.insert(bTable, toDbRow({"id": len(bRows), "key": 4, spanKey: None}))
    # IFF this database supports exclusion constraints, check that they
    # also prevent inserts.
    if reprCls.hasExclusionConstraint():
        conflicting = (
            Timespan(None, timestamps[1]),
            Timespan(timestamps[0], timestamps[2]),
            Timespan(timestamps[2], None),
        )
        for span in conflicting:
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                db.insert(bTable, toDbRow({"id": len(bRows), "key": 1, spanKey: span}))

    # NULL checks in SELECT queries, on both tables.
    aRepr = reprCls.fromSelectable(aTable)
    self.assertEqual([row[spanKey] is None for row in aRows], nullFlags(aRepr, aTable))
    bRepr = reprCls.fromSelectable(bTable)
    self.assertEqual([False] * len(bRows), nullFlags(bRepr, bTable))

    # Relationship expressions between in-database timespans and
    # Python-literal timespans, all from the more complete 'a' set; these
    # must agree with the Python-only relationship tests.
    for rhsRow in aRows:
        literalSpan = rhsRow[spanKey]
        if literalSpan is None:
            continue
        with self.subTest(rhsRow=rhsRow):
            expected = {lhsRow["id"]: spanRelations(lhsRow[spanKey], literalSpan) for lhsRow in aRows}
            rhsRepr = reprCls.fromLiteral(literalSpan)
            sql = sqlalchemy.sql.select(
                aTable.columns.id.label("lhs"),
                aRepr.overlaps(rhsRepr).label("overlaps"),
                aRepr.contains(rhsRepr).label("contains"),
                (aRepr < rhsRepr).label("less_than"),
                (aRepr > rhsRepr).label("greater_than"),
            ).select_from(aTable)
            queried = {
                r.lhs: (r.overlaps, r.contains, r.less_than, r.greater_than) for r in db.query(sql)
            }
            self.assertEqual(expected, queried)

    # Relationship expressions between pairs of in-database timespans,
    # all from the more complete 'a' set; these must also agree with the
    # Python-only relationship tests.
    expected = {
        (lhs["id"], rhs["id"]): spanRelations(lhs[spanKey], rhs[spanKey])
        for lhs, rhs in itertools.product(aRows, aRows)
    }
    lhsSubquery = aTable.alias("lhs")
    rhsSubquery = aTable.alias("rhs")
    lhsRepr = reprCls.fromSelectable(lhsSubquery)
    rhsRepr = reprCls.fromSelectable(rhsSubquery)
    # Joining on a literal TRUE pairs every lhs row with every rhs row.
    everyPair = lhsSubquery.join(rhsSubquery, onclause=sqlalchemy.sql.literal(True))
    sql = sqlalchemy.sql.select(
        lhsSubquery.columns.id.label("lhs"),
        rhsSubquery.columns.id.label("rhs"),
        lhsRepr.overlaps(rhsRepr).label("overlaps"),
        lhsRepr.contains(rhsRepr).label("contains"),
        (lhsRepr < rhsRepr).label("less_than"),
        (lhsRepr > rhsRepr).label("greater_than"),
    ).select_from(everyPair)
    queried = {
        (r.lhs, r.rhs): (r.overlaps, r.contains, r.less_than, r.greater_than) for r in db.query(sql)
    }
    self.assertEqual(expected, queried)

    # Relationship expressions between in-database timespans and
    # Python-literal instantaneous times.
    for t in timestamps:
        with self.subTest(t=t):
            expected = {}
            for lhsRow in aRows:
                span = lhsRow[spanKey]
                if span is None:
                    relations = (None, None, None)
                else:
                    relations = (span.contains(t), span < t, span > t)
                expected[lhsRow["id"]] = relations
            rhs = sqlalchemy.sql.literal(t, type_=ddl.AstropyTimeNsecTai)
            sql = sqlalchemy.sql.select(
                aTable.columns.id.label("lhs"),
                aRepr.contains(rhs).label("contains"),
                (aRepr < rhs).label("less_than"),
                (aRepr > rhs).label("greater_than"),
            ).select_from(aTable)
            queried = {r.lhs: (r.contains, r.less_than, r.greater_than) for r in db.query(sql)}
            self.assertEqual(expected, queried)