Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ["DatabaseTests"] 

24 

25from abc import ABC, abstractmethod 

26import asyncio 

27from collections import namedtuple 

28from concurrent.futures import ThreadPoolExecutor 

29import itertools 

30from typing import ContextManager, Iterable, Set, Tuple 

31import warnings 

32 

33import astropy.time 

34import sqlalchemy 

35 

36from lsst.sphgeom import ConvexPolygon, UnitVector3d 

37from ..interfaces import ( 

38 Database, 

39 ReadOnlyDatabaseError, 

40 DatabaseConflictError, 

41 SchemaAlreadyDefinedError 

42) 

43from ...core import ddl, Timespan 

44 

# Named container holding the three static test tables, keyed by table name.
StaticTablesTuple = namedtuple("StaticTablesTuple", ("a", "b", "c"))

46 

# Specifications for the three static tables used throughout these tests:
#  - "a": string primary key plus a nullable Base64Region blob.
#  - "b": autoincrement integer primary key, a unique "name" column, and a
#    nullable "value" column.
#  - "c": compound (autoincrement "id", "origin") primary key, with a
#    nullable foreign key into "b" that is set to NULL when the target row
#    is deleted (exercised by testInsertQueryDelete).
STATIC_TABLE_SPECS = StaticTablesTuple(
    a=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("name", dtype=sqlalchemy.String, length=16, primaryKey=True),
            ddl.FieldSpec("region", dtype=ddl.Base64Region, nbytes=128),
        ]
    ),
    b=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("id", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True),
            ddl.FieldSpec("name", dtype=sqlalchemy.String, length=16, nullable=False),
            ddl.FieldSpec("value", dtype=sqlalchemy.SmallInteger, nullable=True),
        ],
        unique=[("name",)],
    ),
    c=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("id", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True),
            ddl.FieldSpec("origin", dtype=sqlalchemy.BigInteger, primaryKey=True),
            ddl.FieldSpec("b_id", dtype=sqlalchemy.BigInteger, nullable=True),
        ],
        foreignKeys=[
            ddl.ForeignKeySpec("b", source=("b_id",), target=("id",), onDelete="SET NULL"),
        ]
    ),
)

73 

# Specification for a table created on demand via `Database.ensureTableExists`
# (always registered under the name "d" in these tests).  It references both
# static table "c" (via its compound primary key) and static table "a", with
# deletes cascading from either parent (exercised by testInsertQueryDelete).
DYNAMIC_TABLE_SPEC = ddl.TableSpec(
    fields=[
        ddl.FieldSpec("c_id", dtype=sqlalchemy.BigInteger, primaryKey=True),
        ddl.FieldSpec("c_origin", dtype=sqlalchemy.BigInteger, primaryKey=True),
        ddl.FieldSpec("a_name", dtype=sqlalchemy.String, length=16, nullable=False),
    ],
    foreignKeys=[
        ddl.ForeignKeySpec("c", source=("c_id", "c_origin"), target=("id", "origin"), onDelete="CASCADE"),
        ddl.ForeignKeySpec("a", source=("a_name",), target=("name",), onDelete="CASCADE"),
    ]
)

85 

# Specification for session-scoped temporary tables.  It has no foreign keys;
# testTemporaryTables creates tables from it even on read-only connections.
TEMPORARY_TABLE_SPEC = ddl.TableSpec(
    fields=[
        ddl.FieldSpec("a_name", dtype=sqlalchemy.String, length=16, primaryKey=True),
        ddl.FieldSpec("b_id", dtype=sqlalchemy.BigInteger, primaryKey=True),
    ],
)

92 

93 

class DatabaseTests(ABC):
    """Generic tests for the `Database` interface that can be subclassed to
    generate tests for concrete implementations.

    Concrete subclasses must implement the abstract factory methods
    (`makeEmptyDatabase`, `asReadOnly`, `getNewConnection`) and must also
    provide the ``assert*`` methods these tests call (e.g. by also
    inheriting from `unittest.TestCase`).
    """

98 

99 @abstractmethod 

100 def makeEmptyDatabase(self, origin: int = 0) -> Database: 

101 """Return an empty `Database` with the given origin, or an 

102 automatically-generated one if ``origin`` is `None`. 

103 """ 

104 raise NotImplementedError() 

105 

106 @abstractmethod 

107 def asReadOnly(self, database: Database) -> ContextManager[Database]: 

108 """Return a context manager for a read-only connection into the given 

109 database. 

110 

111 The original database should be considered unusable within the context 

112 but safe to use again afterwards (this allows the context manager to 

113 block write access by temporarily changing user permissions to really 

114 guarantee that write operations are not performed). 

115 """ 

116 raise NotImplementedError() 

117 

118 @abstractmethod 

119 def getNewConnection(self, database: Database, *, writeable: bool) -> Database: 

120 """Return a new `Database` instance that points to the same underlying 

121 storage as the given one. 

122 """ 

123 raise NotImplementedError() 

124 

125 def checkTable(self, spec: ddl.TableSpec, table: sqlalchemy.schema.Table): 

126 self.assertCountEqual(spec.fields.names, table.columns.keys()) 

127 # Checking more than this currently seems fragile, as it might restrict 

128 # what Database implementations do; we don't care if the spec is 

129 # actually preserved in terms of types and constraints as long as we 

130 # can use the returned table as if it was. 

131 

132 def checkStaticSchema(self, tables: StaticTablesTuple): 

133 self.checkTable(STATIC_TABLE_SPECS.a, tables.a) 

134 self.checkTable(STATIC_TABLE_SPECS.b, tables.b) 

135 self.checkTable(STATIC_TABLE_SPECS.c, tables.c) 

136 

137 def testDeclareStaticTables(self): 

138 """Tests for `Database.declareStaticSchema` and the methods it 

139 delegates to. 

140 """ 

141 # Create the static schema in a new, empty database. 

142 newDatabase = self.makeEmptyDatabase() 

143 with newDatabase.declareStaticTables(create=True) as context: 

144 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

145 self.checkStaticSchema(tables) 

146 # Check that we can load that schema even from a read-only connection. 

147 with self.asReadOnly(newDatabase) as existingReadOnlyDatabase: 

148 with existingReadOnlyDatabase.declareStaticTables(create=False) as context: 

149 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

150 self.checkStaticSchema(tables) 

151 

152 def testDeclareStaticTablesTwice(self): 

153 """Tests for `Database.declareStaticSchema` being called twice. 

154 """ 

155 # Create the static schema in a new, empty database. 

156 newDatabase = self.makeEmptyDatabase() 

157 with newDatabase.declareStaticTables(create=True) as context: 

158 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

159 self.checkStaticSchema(tables) 

160 # Second time it should raise 

161 with self.assertRaises(SchemaAlreadyDefinedError): 

162 with newDatabase.declareStaticTables(create=True) as context: 

163 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

164 # Check schema, it should still contain all tables, and maybe some 

165 # extra. 

166 with newDatabase.declareStaticTables(create=False) as context: 

167 self.assertLessEqual(frozenset(STATIC_TABLE_SPECS._fields), context._tableNames) 

168 

169 def testRepr(self): 

170 """Test that repr does not return a generic thing.""" 

171 newDatabase = self.makeEmptyDatabase() 

172 rep = repr(newDatabase) 

173 # Check that stringification works and gives us something different 

174 self.assertNotEqual(rep, str(newDatabase)) 

175 self.assertNotIn("object at 0x", rep, "Check default repr was not used") 

176 self.assertIn("://", rep) 

177 

    def testDynamicTables(self):
        """Tests for `Database.ensureTableExists` and
        `Database.getExistingTable`.
        """
        # Need to start with the static schema.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            context.addTableTuple(STATIC_TABLE_SPECS)
        # Try to ensure the dynamic table exists in a read-only version of
        # that database, which should fail because we can't create it.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                context.addTableTuple(STATIC_TABLE_SPECS)
            with self.assertRaises(ReadOnlyDatabaseError):
                existingReadOnlyDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        # Just getting the dynamic table before it exists should return None.
        self.assertIsNone(newDatabase.getExistingTable("d", DYNAMIC_TABLE_SPEC))
        # Ensure the new table exists back in the original database, which
        # should create it.
        table = newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        self.checkTable(DYNAMIC_TABLE_SPEC, table)
        # Ensuring that it exists should just return the exact same table
        # instance again.
        self.assertIs(newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC), table)
        # Try again from the read-only database, now that the table exists.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                context.addTableTuple(STATIC_TABLE_SPECS)
            # Just getting the dynamic table should now work...
            self.assertIsNotNone(existingReadOnlyDatabase.getExistingTable("d", DYNAMIC_TABLE_SPEC))
            # ...as should ensuring that it exists, since it now does.
            existingReadOnlyDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
            self.checkTable(DYNAMIC_TABLE_SPEC, table)
        # Trying to get the table with a different specification (at least
        # in terms of what columns are present) should raise.
        with self.assertRaises(DatabaseConflictError):
            newDatabase.ensureTableExists(
                "d",
                ddl.TableSpec(
                    fields=[ddl.FieldSpec("name", dtype=sqlalchemy.String, length=4, primaryKey=True)]
                )
            )
        # Calling ensureTableExists inside a transaction block is an error,
        # even if it would do nothing.
        with newDatabase.transaction():
            with self.assertRaises(AssertionError):
                newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)

225 

    def testTemporaryTables(self):
        """Tests for `Database.makeTemporaryTable`,
        `Database.dropTemporaryTable`, and `Database.insert` with
        the ``select`` argument.
        """
        # Need to start with the static schema; also insert some test data.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            static = context.addTableTuple(STATIC_TABLE_SPECS)
        newDatabase.insert(static.a,
                           {"name": "a1", "region": None},
                           {"name": "a2", "region": None})
        bIds = newDatabase.insert(static.b,
                                  {"name": "b1", "value": 11},
                                  {"name": "b2", "value": 12},
                                  {"name": "b3", "value": 13},
                                  returnIds=True)
        # Create the temporary table, with an explicit name ("e1").
        table1 = newDatabase.makeTemporaryTable(TEMPORARY_TABLE_SPEC, "e1")
        self.checkTable(TEMPORARY_TABLE_SPEC, table1)
        # Insert via an INSERT INTO ... SELECT query; the cross join of a
        # and b is filtered down to (a1, b1) and (a1, b2).
        newDatabase.insert(
            table1,
            select=sqlalchemy.sql.select(
                [static.a.columns.name.label("a_name"), static.b.columns.id.label("b_id")]
            ).select_from(
                static.a.join(static.b, onclause=sqlalchemy.sql.literal(True))
            ).where(
                sqlalchemy.sql.and_(
                    static.a.columns.name == "a1",
                    static.b.columns.value <= 12,
                )
            )
        )
        # Check that the inserted rows are present.
        self.assertCountEqual(
            [{"a_name": "a1", "b_id": bId} for bId in bIds[:2]],
            [dict(row) for row in newDatabase.query(table1.select())]
        )
        # Create another one via a read-only connection to the database.
        # We _do_ allow temporary table modifications in read-only databases.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                context.addTableTuple(STATIC_TABLE_SPECS)
            table2 = existingReadOnlyDatabase.makeTemporaryTable(TEMPORARY_TABLE_SPEC)
            self.checkTable(TEMPORARY_TABLE_SPEC, table2)
            # Those tables should not be the same, despite having the same ddl.
            self.assertIsNot(table1, table2)
            # Do a slightly different insert into this table, to check that
            # it works in a read-only database.  This time we pass column
            # names as a kwarg to insert instead of by labeling the columns in
            # the select.
            existingReadOnlyDatabase.insert(
                table2,
                select=sqlalchemy.sql.select(
                    [static.a.columns.name, static.b.columns.id]
                ).select_from(
                    static.a.join(static.b, onclause=sqlalchemy.sql.literal(True))
                ).where(
                    sqlalchemy.sql.and_(
                        static.a.columns.name == "a2",
                        static.b.columns.value >= 12,
                    )
                ),
                names=["a_name", "b_id"],
            )
            # Check that the inserted rows are present.
            self.assertCountEqual(
                [{"a_name": "a2", "b_id": bId} for bId in bIds[1:]],
                [dict(row) for row in existingReadOnlyDatabase.query(table2.select())]
            )
            # Drop the temporary table from the read-only DB.  It's unspecified
            # whether attempting to use it after this point is an error or just
            # never returns any results, so we can't test what it does, only
            # that it's not an error.
            existingReadOnlyDatabase.dropTemporaryTable(table2)
        # Drop the original temporary table.
        newDatabase.dropTemporaryTable(table1)

304 

305 def testSchemaSeparation(self): 

306 """Test that creating two different `Database` instances allows us 

307 to create different tables with the same name in each. 

308 """ 

309 db1 = self.makeEmptyDatabase(origin=1) 

310 with db1.declareStaticTables(create=True) as context: 

311 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

312 self.checkStaticSchema(tables) 

313 

314 db2 = self.makeEmptyDatabase(origin=2) 

315 # Make the DDL here intentionally different so we'll definitely 

316 # notice if db1 and db2 are pointing at the same schema. 

317 spec = ddl.TableSpec(fields=[ddl.FieldSpec("id", dtype=sqlalchemy.Integer, primaryKey=True)]) 

318 with db2.declareStaticTables(create=True) as context: 

319 # Make the DDL here intentionally different so we'll definitely 

320 # notice if db1 and db2 are pointing at the same schema. 

321 table = context.addTable("a", spec) 

322 self.checkTable(spec, table) 

323 

    def testInsertQueryDelete(self):
        """Test the `Database.insert`, `Database.query`, and `Database.delete`
        methods, as well as the `Base64Region` type and the ``onDelete``
        argument to `ddl.ForeignKeySpec`.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert a single, non-autoincrement row that contains a region and
        # query to get it back.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row = {"name": "a1", "region": region}
        db.insert(tables.a, row)
        self.assertEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row])
        # Insert multiple autoincrement rows but do not try to get the IDs
        # back immediately.
        db.insert(tables.b, {"name": "b1", "value": 10}, {"name": "b2", "value": 20})
        results = [dict(r) for r in db.query(tables.b.select().order_by("id")).fetchall()]
        self.assertEqual(len(results), 2)
        for row in results:
            self.assertIn(row["name"], ("b1", "b2"))
            self.assertIsInstance(row["id"], int)
        self.assertGreater(results[1]["id"], results[0]["id"])
        # Insert multiple autoincrement rows and get the IDs back from insert.
        rows = [{"name": "b3", "value": 30}, {"name": "b4", "value": 40}]
        ids = db.insert(tables.b, *rows, returnIds=True)
        # Select only the rows we just added (id greater than the largest
        # previously-seen id).
        results = [
            dict(r) for r in db.query(
                tables.b.select().where(tables.b.columns.id > results[1]["id"])
            ).fetchall()
        ]
        expected = [dict(row, id=id) for row, id in zip(rows, ids)]
        self.assertCountEqual(results, expected)
        self.assertTrue(all(result["id"] is not None for result in results))
        # Insert multiple rows into a table with an autoincrement+origin
        # primary key, then use the returned IDs to insert into a dynamic
        # table.
        rows = [{"origin": db.origin, "b_id": results[0]["id"]},
                {"origin": db.origin, "b_id": None}]
        ids = db.insert(tables.c, *rows, returnIds=True)
        results = [dict(r) for r in db.query(tables.c.select()).fetchall()]
        expected = [dict(row, id=id) for row, id in zip(rows, ids)]
        self.assertCountEqual(results, expected)
        self.assertTrue(all(result["id"] is not None for result in results))
        # Add the dynamic table.
        d = db.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        # Insert into it.
        rows = [{"c_origin": db.origin, "c_id": id, "a_name": "a1"} for id in ids]
        db.insert(d, *rows)
        results = [dict(r) for r in db.query(d.select()).fetchall()]
        self.assertCountEqual(rows, results)
        # Insert multiple rows into a table with an autoincrement+origin
        # primary key (this is especially tricky for SQLite, but good to test
        # for all DBs), but pass in a value for the autoincrement key.
        # For extra complexity, we re-use the autoincrement value with a
        # different value for origin.
        rows2 = [{"id": 700, "origin": db.origin, "b_id": None},
                 {"id": 700, "origin": 60, "b_id": None},
                 {"id": 1, "origin": 60, "b_id": None}]
        db.insert(tables.c, *rows2)
        results = [dict(r) for r in db.query(tables.c.select()).fetchall()]
        self.assertCountEqual(results, expected + rows2)
        self.assertTrue(all(result["id"] is not None for result in results))

        # Define 'SELECT COUNT(*)' query for later use.
        count = sqlalchemy.sql.select([sqlalchemy.sql.func.count()])
        # Get the values we inserted into table b.
        bValues = [dict(r) for r in db.query(tables.b.select()).fetchall()]
        # Remove two rows from table b by ID.
        n = db.delete(tables.b, ["id"], {"id": bValues[0]["id"]}, {"id": bValues[1]["id"]})
        self.assertEqual(n, 2)
        # Remove the other two rows from table b by name.
        n = db.delete(tables.b, ["name"], {"name": bValues[2]["name"]}, {"name": bValues[3]["name"]})
        self.assertEqual(n, 2)
        # There should now be no rows in table b.
        self.assertEqual(
            db.query(count.select_from(tables.b)).scalar(),
            0
        )
        # All b_id values in table c should now be NULL, because there's an
        # onDelete='SET NULL' foreign key.
        self.assertEqual(
            db.query(count.select_from(tables.c).where(tables.c.columns.b_id != None)).scalar(),  # noqa:E711
            0
        )
        # Remove all rows in table a (there's only one); this should remove all
        # rows in d due to onDelete='CASCADE'.
        n = db.delete(tables.a, [])
        self.assertEqual(n, 1)
        self.assertEqual(db.query(count.select_from(tables.a)).scalar(), 0)
        self.assertEqual(db.query(count.select_from(d)).scalar(), 0)

415 

416 def testUpdate(self): 

417 """Tests for `Database.update`. 

418 """ 

419 db = self.makeEmptyDatabase(origin=1) 

420 with db.declareStaticTables(create=True) as context: 

421 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

422 # Insert two rows into table a, both without regions. 

423 db.insert(tables.a, {"name": "a1"}, {"name": "a2"}) 

424 # Update one of the rows with a region. 

425 region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1))) 

426 n = db.update(tables.a, {"name": "k"}, {"k": "a2", "region": region}) 

427 self.assertEqual(n, 1) 

428 sql = sqlalchemy.sql.select([tables.a.columns.name, tables.a.columns.region]).select_from(tables.a) 

429 self.assertCountEqual( 

430 [dict(r) for r in db.query(sql).fetchall()], 

431 [{"name": "a1", "region": None}, {"name": "a2", "region": region}] 

432 ) 

433 

    def testSync(self):
        """Tests for `Database.sync`.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert a row with sync, because it doesn't exist yet.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10}, returning=["id"])
        self.assertTrue(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [dict(r) for r in db.query(tables.b.select()).fetchall()])
        # Repeat that operation, which should do nothing but return the
        # requested values.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10}, returning=["id"])
        self.assertFalse(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [dict(r) for r in db.query(tables.b.select()).fetchall()])
        # Repeat the operation without the 'extra' arg, which should also just
        # return the existing row.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, returning=["id"])
        self.assertFalse(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [dict(r) for r in db.query(tables.b.select()).fetchall()])
        # Repeat the operation with a different value in 'extra'.  That still
        # shouldn't be an error, because 'extra' is only used if we really do
        # insert.  Also drop the 'returning' argument.
        _, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 20})
        self.assertFalse(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [dict(r) for r in db.query(tables.b.select()).fetchall()])
        # Repeat the operation with the correct value in 'compared' instead of
        # 'extra'.
        _, inserted = db.sync(tables.b, keys={"name": "b1"}, compared={"value": 10})
        self.assertFalse(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [dict(r) for r in db.query(tables.b.select()).fetchall()])
        # Repeat the operation with an incorrect value in 'compared'; this
        # should raise.
        with self.assertRaises(DatabaseConflictError):
            db.sync(tables.b, keys={"name": "b1"}, compared={"value": 20})
        # Try to sync inside a transaction.  That's always an error, regardless
        # of whether there would be an insertion or not.
        with self.assertRaises(AssertionError):
            with db.transaction():
                db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10})
        with self.assertRaises(AssertionError):
            with db.transaction():
                db.sync(tables.b, keys={"name": "b2"}, extra={"value": 20})
        # Try to sync in a read-only database.  This should work if and only
        # if the matching row already exists.
        with self.asReadOnly(db) as rodb:
            with rodb.declareStaticTables(create=False) as context:
                tables = context.addTableTuple(STATIC_TABLE_SPECS)
            # "b1" exists, so this sync succeeds without writing anything.
            _, inserted = rodb.sync(tables.b, keys={"name": "b1"})
            self.assertFalse(inserted)
            self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                             [dict(r) for r in rodb.query(tables.b.select()).fetchall()])
            # "b2" does not exist, so sync would have to insert — forbidden.
            with self.assertRaises(ReadOnlyDatabaseError):
                rodb.sync(tables.b, keys={"name": "b2"}, extra={"value": 20})

493 

494 def testReplace(self): 

495 """Tests for `Database.replace`. 

496 """ 

497 db = self.makeEmptyDatabase(origin=1) 

498 with db.declareStaticTables(create=True) as context: 

499 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

500 # Use 'replace' to insert a single row that contains a region and 

501 # query to get it back. 

502 region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1))) 

503 row1 = {"name": "a1", "region": region} 

504 db.replace(tables.a, row1) 

505 self.assertEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1]) 

506 # Insert another row without a region. 

507 row2 = {"name": "a2", "region": None} 

508 db.replace(tables.a, row2) 

509 self.assertCountEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1, row2]) 

510 # Use replace to re-insert both of those rows again, which should do 

511 # nothing. 

512 db.replace(tables.a, row1, row2) 

513 self.assertCountEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1, row2]) 

514 # Replace row1 with a row with no region, while reinserting row2. 

515 row1a = {"name": "a1", "region": None} 

516 db.replace(tables.a, row1a, row2) 

517 self.assertCountEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1a, row2]) 

518 # Replace both rows, returning row1 to its original state, while adding 

519 # a new one. Pass them in in a different order. 

520 row2a = {"name": "a2", "region": region} 

521 row3 = {"name": "a3", "region": None} 

522 db.replace(tables.a, row3, row2a, row1) 

523 self.assertCountEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1, row2a, row3]) 

524 

    def testTransactionNesting(self):
        """Test that transactions can be nested with the behavior in the
        presence of exceptions working as documented.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert one row so we can trigger integrity errors by trying to insert
        # a duplicate of it below.
        db.insert(tables.a, {"name": "a1"})
        # First test: error recovery via explicit savepoint=True in the inner
        # transaction.
        with db.transaction():
            # This insert should succeed, and should not be rolled back because
            # the assertRaises context should catch any exception before it
            # propagates up to the outer transaction.
            db.insert(tables.a, {"name": "a2"})
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                with db.transaction(savepoint=True):
                    # This insert should succeed, but should be rolled back.
                    db.insert(tables.a, {"name": "a4"})
                    # This insert should fail (duplicate primary key), raising
                    # an exception.
                    db.insert(tables.a, {"name": "a1"})
        # Only the pre-existing row and the insert outside the savepoint
        # should remain.
        self.assertCountEqual(
            [dict(r) for r in db.query(tables.a.select()).fetchall()],
            [{"name": "a1", "region": None}, {"name": "a2", "region": None}],
        )
        # Second test: error recovery via implicit savepoint=True, when the
        # innermost transaction is inside a savepoint=True transaction.
        with db.transaction():
            # This insert should succeed, and should not be rolled back
            # because the assertRaises context should catch any
            # exception before it propagates up to the outer
            # transaction.
            db.insert(tables.a, {"name": "a3"})
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                with db.transaction(savepoint=True):
                    # This insert should succeed, but should be rolled back.
                    db.insert(tables.a, {"name": "a4"})
                    with db.transaction():
                        # This insert should succeed, but should be rolled
                        # back.
                        db.insert(tables.a, {"name": "a5"})
                        # This insert should fail (duplicate primary key),
                        # raising an exception.
                        db.insert(tables.a, {"name": "a1"})
        # Everything inside the savepoint (a4, a5) was rolled back; a3 stays.
        self.assertCountEqual(
            [dict(r) for r in db.query(tables.a.select()).fetchall()],
            [{"name": "a1", "region": None}, {"name": "a2", "region": None}, {"name": "a3", "region": None}],
        )

576 

    def testTransactionLocking(self):
        """Test that `Database.transaction` can be used to acquire a lock
        that prohibits concurrent writes.
        """
        db1 = self.makeEmptyDatabase(origin=1)
        with db1.declareStaticTables(create=True) as context:
            tables1 = context.addTableTuple(STATIC_TABLE_SPECS)

        async def side1(lock: Iterable[str] = ()) -> Tuple[Set[str], Set[str]]:
            """One side of the concurrent locking test.

            This optionally locks the table (and maybe the whole database),
            does a select for its contents, inserts a new row, and then selects
            again, with some waiting in between to make sure the other side has
            a chance to _attempt_ to insert in between.  If the locking is
            enabled and works, the difference between the selects should just
            be the insert done on this thread.
            """
            # Give Side2 a chance to create a connection
            await asyncio.sleep(1.0)
            with db1.transaction(lock=lock):
                names1 = {row["name"] for row in db1.query(tables1.a.select()).fetchall()}
                # Give Side2 a chance to insert (which will be blocked if
                # we've acquired a lock).
                await asyncio.sleep(2.0)
                db1.insert(tables1.a, {"name": "a1"})
                names2 = {row["name"] for row in db1.query(tables1.a.select()).fetchall()}
            return names1, names2

        async def side2() -> None:
            """The other side of the concurrent locking test.

            This side just waits a bit and then tries to insert a row into the
            table that the other side is trying to lock.  Hopefully that
            waiting is enough to give the other side a chance to acquire the
            lock and thus make this side block until the lock is released.  If
            this side manages to do the insert before side1 acquires the lock,
            we'll just warn about not succeeding at testing the locking,
            because we can only make that unlikely, not impossible.
            """
            def toRunInThread():
                """SQLite locking isn't asyncio-friendly unless we actually
                run it in another thread.  And SQLite gets very unhappy if
                we try to use a connection from multiple threads, so we have
                to create the new connection here instead of out in the main
                body of the test function.
                """
                db2 = self.getNewConnection(db1, writeable=True)
                with db2.declareStaticTables(create=False) as context:
                    tables2 = context.addTableTuple(STATIC_TABLE_SPECS)
                with db2.transaction():
                    db2.insert(tables2.a, {"name": "a2"})

            await asyncio.sleep(2.0)
            loop = asyncio.get_running_loop()
            with ThreadPoolExecutor() as pool:
                await loop.run_in_executor(pool, toRunInThread)

        async def testProblemsWithNoLocking() -> None:
            """Run side1 and side2 with no locking, attempting to demonstrate
            the problem that locking is supposed to solve.  If we get unlucky
            with scheduling, side2 will just happen to insert after side1 is
            done, and we won't have anything definitive.  We just warn in that
            case because we really don't want spurious test failures.
            """
            task1 = asyncio.create_task(side1())
            task2 = asyncio.create_task(side2())

            names1, names2 = await task1
            await task2
            if "a2" in names1:
                warnings.warn("Unlucky scheduling in no-locking test: concurrent INSERT "
                              "happened before first SELECT.")
                self.assertEqual(names1, {"a2"})
                self.assertEqual(names2, {"a1", "a2"})
            elif "a2" not in names2:
                warnings.warn("Unlucky scheduling in no-locking test: concurrent INSERT "
                              "happened after second SELECT even without locking.")
                self.assertEqual(names1, set())
                self.assertEqual(names2, {"a1"})
            else:
                # This is the expected case: both INSERTS happen between the
                # two SELECTS.  If we don't get this almost all of the time we
                # should adjust the sleep amounts.
                self.assertEqual(names1, set())
                self.assertEqual(names2, {"a1", "a2"})

        asyncio.run(testProblemsWithNoLocking())

        # Clean up after first test.
        db1.delete(tables1.a, ["name"], {"name": "a1"}, {"name": "a2"})

        async def testSolutionWithLocking() -> None:
            """Run side1 and side2 with locking, which should make side2 block
            its insert until side1 releases its lock.
            """
            task1 = asyncio.create_task(side1(lock=[tables1.a]))
            task2 = asyncio.create_task(side2())

            names1, names2 = await task1
            await task2
            if "a2" in names1:
                warnings.warn("Unlucky scheduling in locking test: concurrent INSERT "
                              "happened before first SELECT.")
                self.assertEqual(names1, {"a2"})
                self.assertEqual(names2, {"a1", "a2"})
            else:
                # This is the expected case: the side2 INSERT happens after the
                # last SELECT on side1.  This can also happen due to unlucky
                # scheduling, and we have no way to detect that here, but the
                # similar "no-locking" test has at least some chance of being
                # affected by the same problem and warning about it.
                self.assertEqual(names1, set())
                self.assertEqual(names2, {"a1"})

        asyncio.run(testSolutionWithLocking())

693 

    def testDatabaseTimespanRepresentation(self):
        """Tests for `DatabaseTimespanRepresentation` and the `Database`
        methods that interact with it.

        Exercises, in order:

        - round-tripping `Timespan` values (including ``None`` and a
          server-side default) through a nullable timespan column;
        - the same through a non-nullable column created via
          ``ensureTableExists``, with an exclusion constraint when the
          database supports one;
        - integrity-error behavior for NULL and overlapping inserts;
        - ``isNull`` and ``overlaps`` SQL expressions, checked against
          pure-Python `Timespan.overlaps` results.
        """
        # Make some test timespans to play with, with the full suite of
        # topological relationships.
        start = astropy.time.Time('2020-01-01T00:00:00', format="isot", scale="tai")
        offset = astropy.time.TimeDelta(60, format="sec")
        # Three instants, one minute apart; all timespans below are built
        # from these so every begin/end combination is represented.
        timestamps = [start + offset*n for n in range(3)]
        aTimespans = [Timespan(begin=None, end=None)]
        aTimespans.extend(Timespan(begin=None, end=t) for t in timestamps)
        aTimespans.extend(Timespan(begin=t, end=None) for t in timestamps)
        aTimespans.extend(Timespan(begin=t1, end=t2) for t1, t2 in itertools.combinations(timestamps, 2))
        # Make another list of timespans that span the full range but don't
        # overlap.  This is a subset of the previous list.
        bTimespans = [Timespan(begin=None, end=timestamps[0])]
        bTimespans.extend(Timespan(begin=t1, end=t2) for t1, t2 in zip(timestamps[:-1], timestamps[1:]))
        bTimespans.append(Timespan(begin=timestamps[-1], end=None))
        # Make a database and create a table with that database's timespan
        # representation.  This one will have no exclusion constraint and
        # a nullable timespan.
        db = self.makeEmptyDatabase(origin=1)
        tsRepr = db.getTimespanRepresentation()
        aSpec = ddl.TableSpec(
            fields=[
                ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
            ],
        )
        # The timespan may be represented by one or more real columns;
        # the representation object supplies their specs.
        for fieldSpec in tsRepr.makeFieldSpecs(nullable=True):
            aSpec.fields.add(fieldSpec)
        with db.declareStaticTables(create=True) as context:
            aTable = context.addTable("a", aSpec)

        def convertRowForInsert(row: dict) -> dict:
            """Convert a row containing a Timespan instance into one suitable
            for insertion into the database.

            Replaces the single ``tsRepr.NAME`` entry with whatever concrete
            column values the representation uses.
            """
            result = row.copy()
            ts = result.pop(tsRepr.NAME)
            return tsRepr.update(ts, result=result)

        def convertRowFromSelect(row: dict) -> dict:
            """Convert a row from the database into one containing a Timespan.

            Inverse of `convertRowForInsert`: collapses the representation's
            columns back into a single ``tsRepr.NAME`` entry.
            """
            result = row.copy()
            timespan = tsRepr.extract(result)
            for name in tsRepr.getFieldNames():
                del result[name]
            result[tsRepr.NAME] = timespan
            return result

        # Insert rows into table A, in chunks just to make things interesting.
        # Include one with a NULL timespan.
        aRows = [{"id": n, tsRepr.NAME: t} for n, t in enumerate(aTimespans)]
        aRows.append({"id": len(aRows), tsRepr.NAME: None})
        db.insert(aTable, convertRowForInsert(aRows[0]))
        db.insert(aTable, *[convertRowForInsert(r) for r in aRows[1:3]])
        db.insert(aTable, *[convertRowForInsert(r) for r in aRows[3:]])
        # Add another one with a NULL timespan, but this time by invoking
        # the server-side default (i.e. insert the raw row without the
        # timespan key at all).
        aRows.append({"id": len(aRows)})
        db.insert(aTable, aRows[-1])
        # Patch the in-memory row to match what we expect to read back,
        # so the round-trip comparison below works.
        aRows[-1][tsRepr.NAME] = None
        # Test basic round-trip through database.
        self.assertEqual(
            aRows,
            [convertRowFromSelect(dict(row))
             for row in db.query(aTable.select().order_by(aTable.columns.id)).fetchall()]
        )
        # Create another table B with a not-null timespan and (if the database
        # supports it), an exclusion constraint.  Use ensureTableExists this
        # time to check that mode of table creation vs. timespans.
        bSpec = ddl.TableSpec(
            fields=[
                ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True),
                ddl.FieldSpec(name="key", dtype=sqlalchemy.Integer, nullable=False),
            ],
        )
        for fieldSpec in tsRepr.makeFieldSpecs(nullable=False):
            bSpec.fields.add(fieldSpec)
        if tsRepr.hasExclusionConstraint():
            # No two rows with the same "key" may have overlapping timespans.
            bSpec.exclusion.add(("key", tsRepr))
        bTable = db.ensureTableExists("b", bSpec)
        # Insert rows into table B, again in chunks.  Each Timespan appears
        # twice, but with different values for the 'key' field (which should
        # still be okay for any exclusion constraint we may have defined).
        bRows = [{"id": n, "key": 1, tsRepr.NAME: t} for n, t in enumerate(bTimespans)]
        # NOTE: rebinds ``offset`` (previously an astropy TimeDelta) as an
        # int; the TimeDelta is not used again after this point.
        offset = len(bRows)
        bRows.extend({"id": n + offset, "key": 2, tsRepr.NAME: t} for n, t in enumerate(bTimespans))
        db.insert(bTable, *[convertRowForInsert(r) for r in bRows[:2]])
        db.insert(bTable, convertRowForInsert(bRows[2]))
        db.insert(bTable, *[convertRowForInsert(r) for r in bRows[3:]])
        # Insert a row with no timespan into table B.  This should invoke the
        # server-side default, which is a timespan over (-∞, ∞).  We set
        # key=3 to avoid upsetting an exclusion constraint that might exist.
        bRows.append({"id": len(bRows), "key": 3})
        db.insert(bTable, bRows[-1])
        bRows[-1][tsRepr.NAME] = Timespan(None, None)
        # Test basic round-trip through database.
        self.assertEqual(
            bRows,
            [convertRowFromSelect(dict(row))
             for row in db.query(bTable.select().order_by(bTable.columns.id)).fetchall()]
        )
        # Test that we can't insert timespan=None into this table (the
        # timespan columns were declared NOT NULL).
        with self.assertRaises(sqlalchemy.exc.IntegrityError):
            db.insert(bTable, convertRowForInsert({"id": len(bRows), "key": 4, tsRepr.NAME: None}))
        # IFF this database supports exclusion constraints, test that they
        # also prevent inserts.  Each candidate overlaps an existing key=1
        # row in a different way (open begin, fully contained range, open
        # end).
        if tsRepr.hasExclusionConstraint():
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                db.insert(bTable, convertRowForInsert({"id": len(bRows), "key": 1,
                                                       tsRepr.NAME: Timespan(None, timestamps[1])}))
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                db.insert(bTable, convertRowForInsert({"id": len(bRows), "key": 1,
                                                       tsRepr.NAME: Timespan(timestamps[0], timestamps[2])}))
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                db.insert(bTable, convertRowForInsert({"id": len(bRows), "key": 1,
                                                       tsRepr.NAME: Timespan(timestamps[2], None)}))
        # Test NULL checks in SELECT queries, on both tables.  Table A has
        # exactly the rows we marked with timespan=None; table B's timespans
        # are NOT NULL, so isNull() must be False for every row.
        aRepr = tsRepr.fromSelectable(aTable)
        self.assertEqual(
            [row[tsRepr.NAME] is None for row in aRows],
            [
                row["f"] for row in db.query(
                    sqlalchemy.sql.select(
                        [aRepr.isNull().label("f")]
                    ).order_by(
                        aTable.columns.id
                    )
                ).fetchall()
            ]
        )
        bRepr = tsRepr.fromSelectable(bTable)
        self.assertEqual(
            [False for row in bRows],
            [
                row["f"] for row in db.query(
                    sqlalchemy.sql.select(
                        [bRepr.isNull().label("f")]
                    ).order_by(
                        bTable.columns.id
                    )
                ).fetchall()
            ]
        )
        # Test overlap expressions that relate in-database A timespans to
        # Python-literal B timespans; check that this is consistent with
        # Python-only overlap tests.  A NULL timespan is expected to yield
        # NULL (None) rather than True/False from the SQL expression.
        for bRow in bRows:
            with self.subTest(bRow=bRow):
                expected = {}
                for aRow in aRows:
                    if aRow[tsRepr.NAME] is None:
                        expected[aRow["id"]] = None
                    else:
                        expected[aRow["id"]] = aRow[tsRepr.NAME].overlaps(bRow[tsRepr.NAME])
                sql = sqlalchemy.sql.select(
                    [aTable.columns.id.label("a"), aRepr.overlaps(bRow[tsRepr.NAME]).label("f")]
                ).select_from(aTable)
                queried = {row["a"]: row["f"] for row in db.query(sql).fetchall()}
                self.assertEqual(expected, queried)
        # Test overlap expressions that relate in-database A timespans to
        # in-database B timespans; check that this is consistent with
        # Python-only overlap tests.  The literal-True join clause produces
        # a full cross join of the two tables.
        expected = {
            (aRow["id"], bRow["id"]): (aRow[tsRepr.NAME].overlaps(bRow[tsRepr.NAME])
                                       if aRow[tsRepr.NAME] is not None else None)
            for aRow, bRow in itertools.product(aRows, bRows)
        }
        sql = sqlalchemy.sql.select(
            [
                aTable.columns.id.label("a"),
                bTable.columns.id.label("b"),
                aRepr.overlaps(bRepr).label("f")
            ]
        ).select_from(aTable.join(bTable, onclause=sqlalchemy.sql.literal(True)))
        queried = {(row["a"], row["b"]): row["f"] for row in db.query(sql).fetchall()}
        self.assertEqual(expected, queried)