Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ["DatabaseTests"] 

24 

25from abc import ABC, abstractmethod 

26import asyncio 

27from collections import namedtuple 

28from concurrent.futures import ThreadPoolExecutor 

29import itertools 

30from typing import ContextManager, Iterable, Set, Tuple 

31import warnings 

32 

33import astropy.time 

34import sqlalchemy 

35 

36from lsst.sphgeom import ConvexPolygon, UnitVector3d 

37from ..interfaces import ( 

38 Database, 

39 ReadOnlyDatabaseError, 

40 DatabaseConflictError, 

41 SchemaAlreadyDefinedError 

42) 

43from ...core import ddl, Timespan 

44 

# Named container for the three static test tables; matches the tuple
# returned by StaticTablesContext.addTableTuple(STATIC_TABLE_SPECS).
StaticTablesTuple = namedtuple("StaticTablesTuple", ["a", "b", "c"])

46 

# Specifications for the static schema shared by all tests:
#  - "a": string primary key plus a nullable Base64Region column.
#  - "b": autoincrement primary key, unique non-null "name", nullable "value".
#  - "c": compound (autoincrement "id", "origin") primary key with a nullable
#    foreign key into "b" that nulls out on delete.
STATIC_TABLE_SPECS = StaticTablesTuple(
    a=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("name", dtype=sqlalchemy.String, length=16, primaryKey=True),
            ddl.FieldSpec("region", dtype=ddl.Base64Region, nbytes=128),
        ]
    ),
    b=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("id", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True),
            ddl.FieldSpec("name", dtype=sqlalchemy.String, length=16, nullable=False),
            ddl.FieldSpec("value", dtype=sqlalchemy.SmallInteger, nullable=True),
        ],
        unique=[("name",)],
    ),
    c=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("id", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True),
            ddl.FieldSpec("origin", dtype=sqlalchemy.BigInteger, primaryKey=True),
            ddl.FieldSpec("b_id", dtype=sqlalchemy.BigInteger, nullable=True),
        ],
        foreignKeys=[
            ddl.ForeignKeySpec("b", source=("b_id",), target=("id",), onDelete="SET NULL"),
        ]
    ),
)

73 

# Specification for a table created on demand via Database.ensureTableExists;
# it references static tables "c" (via its compound key) and "a", with
# cascading deletes from both.
DYNAMIC_TABLE_SPEC = ddl.TableSpec(
    fields=[
        ddl.FieldSpec("c_id", dtype=sqlalchemy.BigInteger, primaryKey=True),
        ddl.FieldSpec("c_origin", dtype=sqlalchemy.BigInteger, primaryKey=True),
        ddl.FieldSpec("a_name", dtype=sqlalchemy.String, length=16, nullable=False),
    ],
    foreignKeys=[
        ddl.ForeignKeySpec("c", source=("c_id", "c_origin"), target=("id", "origin"), onDelete="CASCADE"),
        ddl.ForeignKeySpec("a", source=("a_name",), target=("name",), onDelete="CASCADE"),
    ]
)

85 

# Specification for tables created via Database.makeTemporaryTable in the
# temporary-table tests; no foreign keys, so it can be dropped independently.
TEMPORARY_TABLE_SPEC = ddl.TableSpec(
    fields=[
        ddl.FieldSpec("a_name", dtype=sqlalchemy.String, length=16, primaryKey=True),
        ddl.FieldSpec("b_id", dtype=sqlalchemy.BigInteger, primaryKey=True),
    ],
)

92 

93 

94class DatabaseTests(ABC): 

95 """Generic tests for the `Database` interface that can be subclassed to 

96 generate tests for concrete implementations. 

97 """ 

98 

99 @abstractmethod 

100 def makeEmptyDatabase(self, origin: int = 0) -> Database: 

101 """Return an empty `Database` with the given origin, or an 

102 automatically-generated one if ``origin`` is `None`. 

103 """ 

104 raise NotImplementedError() 

105 

106 @abstractmethod 

107 def asReadOnly(self, database: Database) -> ContextManager[Database]: 

108 """Return a context manager for a read-only connection into the given 

109 database. 

110 

111 The original database should be considered unusable within the context 

112 but safe to use again afterwards (this allows the context manager to 

113 block write access by temporarily changing user permissions to really 

114 guarantee that write operations are not performed). 

115 """ 

116 raise NotImplementedError() 

117 

118 @abstractmethod 

119 def getNewConnection(self, database: Database, *, writeable: bool) -> Database: 

120 """Return a new `Database` instance that points to the same underlying 

121 storage as the given one. 

122 """ 

123 raise NotImplementedError() 

124 

125 def checkTable(self, spec: ddl.TableSpec, table: sqlalchemy.schema.Table): 

126 self.assertCountEqual(spec.fields.names, table.columns.keys()) 

127 # Checking more than this currently seems fragile, as it might restrict 

128 # what Database implementations do; we don't care if the spec is 

129 # actually preserved in terms of types and constraints as long as we 

130 # can use the returned table as if it was. 

131 

132 def checkStaticSchema(self, tables: StaticTablesTuple): 

133 self.checkTable(STATIC_TABLE_SPECS.a, tables.a) 

134 self.checkTable(STATIC_TABLE_SPECS.b, tables.b) 

135 self.checkTable(STATIC_TABLE_SPECS.c, tables.c) 

136 

137 def testDeclareStaticTables(self): 

138 """Tests for `Database.declareStaticSchema` and the methods it 

139 delegates to. 

140 """ 

141 # Create the static schema in a new, empty database. 

142 newDatabase = self.makeEmptyDatabase() 

143 with newDatabase.declareStaticTables(create=True) as context: 

144 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

145 self.checkStaticSchema(tables) 

146 # Check that we can load that schema even from a read-only connection. 

147 with self.asReadOnly(newDatabase) as existingReadOnlyDatabase: 

148 with existingReadOnlyDatabase.declareStaticTables(create=False) as context: 

149 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

150 self.checkStaticSchema(tables) 

151 

152 def testDeclareStaticTablesTwice(self): 

153 """Tests for `Database.declareStaticSchema` being called twice. 

154 """ 

155 # Create the static schema in a new, empty database. 

156 newDatabase = self.makeEmptyDatabase() 

157 with newDatabase.declareStaticTables(create=True) as context: 

158 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

159 self.checkStaticSchema(tables) 

160 # Second time it should raise 

161 with self.assertRaises(SchemaAlreadyDefinedError): 

162 with newDatabase.declareStaticTables(create=True) as context: 

163 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

164 # Check schema, it should still contain all tables, and maybe some 

165 # extra. 

166 with newDatabase.declareStaticTables(create=False) as context: 

167 self.assertLessEqual(frozenset(STATIC_TABLE_SPECS._fields), context._tableNames) 

168 

169 def testRepr(self): 

170 """Test that repr does not return a generic thing.""" 

171 newDatabase = self.makeEmptyDatabase() 

172 rep = repr(newDatabase) 

173 # Check that stringification works and gives us something different 

174 self.assertNotEqual(rep, str(newDatabase)) 

175 self.assertNotIn("object at 0x", rep, "Check default repr was not used") 

176 self.assertIn("://", rep) 

177 

    def testDynamicTables(self):
        """Tests for `Database.ensureTableExists` and
        `Database.getExistingTable`.
        """
        # Need to start with the static schema.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            context.addTableTuple(STATIC_TABLE_SPECS)
        # Try to ensure the dynamic table exists in a read-only version of
        # that database, which should fail because we can't create it.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                context.addTableTuple(STATIC_TABLE_SPECS)
            with self.assertRaises(ReadOnlyDatabaseError):
                existingReadOnlyDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        # Just getting the dynamic table before it exists should return None.
        self.assertIsNone(newDatabase.getExistingTable("d", DYNAMIC_TABLE_SPEC))
        # Ensure the new table exists back in the original database, which
        # should create it.
        table = newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        self.checkTable(DYNAMIC_TABLE_SPEC, table)
        # Ensuring that it exists should just return the exact same table
        # instance again.
        self.assertIs(newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC), table)
        # Try again from the read-only database.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                context.addTableTuple(STATIC_TABLE_SPECS)
            # Just getting the dynamic table should now work...
            self.assertIsNotNone(existingReadOnlyDatabase.getExistingTable("d", DYNAMIC_TABLE_SPEC))
            # ...as should ensuring that it exists, since it now does.
            existingReadOnlyDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
            self.checkTable(DYNAMIC_TABLE_SPEC, table)
        # Trying to get the table with a different specification (at least
        # in terms of what columns are present) should raise.
        with self.assertRaises(DatabaseConflictError):
            newDatabase.ensureTableExists(
                "d",
                ddl.TableSpec(
                    fields=[ddl.FieldSpec("name", dtype=sqlalchemy.String, length=4, primaryKey=True)]
                )
            )
        # Calling ensureTableExists inside a transaction block is an error,
        # even if it would do nothing.
        with newDatabase.transaction():
            with self.assertRaises(AssertionError):
                newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)

225 

    def testTemporaryTables(self):
        """Tests for `Database.makeTemporaryTable`,
        `Database.dropTemporaryTable`, and `Database.insert` with
        the ``select`` argument.
        """
        # Need to start with the static schema; also insert some test data.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            static = context.addTableTuple(STATIC_TABLE_SPECS)
        newDatabase.insert(static.a,
                           {"name": "a1", "region": None},
                           {"name": "a2", "region": None})
        bIds = newDatabase.insert(static.b,
                                  {"name": "b1", "value": 11},
                                  {"name": "b2", "value": 12},
                                  {"name": "b3", "value": 13},
                                  returnIds=True)
        # Create the temporary table with an explicit name ("e1").
        table1 = newDatabase.makeTemporaryTable(TEMPORARY_TABLE_SPEC, "e1")
        self.checkTable(TEMPORARY_TABLE_SPEC, table1)
        # Insert via an INSERT INTO ... SELECT query; the cross join with a
        # literal-True onclause pairs every 'a' row with every 'b' row.
        newDatabase.insert(
            table1,
            select=sqlalchemy.sql.select(
                [static.a.columns.name.label("a_name"), static.b.columns.id.label("b_id")]
            ).select_from(
                static.a.join(static.b, onclause=sqlalchemy.sql.literal(True))
            ).where(
                sqlalchemy.sql.and_(
                    static.a.columns.name == "a1",
                    static.b.columns.value <= 12,
                )
            )
        )
        # Check that the inserted rows are present.
        self.assertCountEqual(
            [{"a_name": "a1", "b_id": bId} for bId in bIds[:2]],
            [dict(row) for row in newDatabase.query(table1.select())]
        )
        # Create another one via a read-only connection to the database.
        # We _do_ allow temporary table modifications in read-only databases.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                context.addTableTuple(STATIC_TABLE_SPECS)
            table2 = existingReadOnlyDatabase.makeTemporaryTable(TEMPORARY_TABLE_SPEC)
            self.checkTable(TEMPORARY_TABLE_SPEC, table2)
            # Those tables should not be the same, despite having the same
            # ddl.
            self.assertIsNot(table1, table2)
            # Do a slightly different insert into this table, to check that
            # it works in a read-only database.  This time we pass column
            # names as a kwarg to insert instead of by labeling the columns
            # in the select.
            existingReadOnlyDatabase.insert(
                table2,
                select=sqlalchemy.sql.select(
                    [static.a.columns.name, static.b.columns.id]
                ).select_from(
                    static.a.join(static.b, onclause=sqlalchemy.sql.literal(True))
                ).where(
                    sqlalchemy.sql.and_(
                        static.a.columns.name == "a2",
                        static.b.columns.value >= 12,
                    )
                ),
                names=["a_name", "b_id"],
            )
            # Check that the inserted rows are present.
            self.assertCountEqual(
                [{"a_name": "a2", "b_id": bId} for bId in bIds[1:]],
                [dict(row) for row in existingReadOnlyDatabase.query(table2.select())]
            )
            # Drop the temporary table from the read-only DB.  It's
            # unspecified whether attempting to use it after this point is an
            # error or just never returns any results, so we can't test what
            # it does, only that it's not an error.
            existingReadOnlyDatabase.dropTemporaryTable(table2)
        # Drop the original temporary table.
        newDatabase.dropTemporaryTable(table1)

304 

305 def testSchemaSeparation(self): 

306 """Test that creating two different `Database` instances allows us 

307 to create different tables with the same name in each. 

308 """ 

309 db1 = self.makeEmptyDatabase(origin=1) 

310 with db1.declareStaticTables(create=True) as context: 

311 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

312 self.checkStaticSchema(tables) 

313 

314 db2 = self.makeEmptyDatabase(origin=2) 

315 # Make the DDL here intentionally different so we'll definitely 

316 # notice if db1 and db2 are pointing at the same schema. 

317 spec = ddl.TableSpec(fields=[ddl.FieldSpec("id", dtype=sqlalchemy.Integer, primaryKey=True)]) 

318 with db2.declareStaticTables(create=True) as context: 

319 # Make the DDL here intentionally different so we'll definitely 

320 # notice if db1 and db2 are pointing at the same schema. 

321 table = context.addTable("a", spec) 

322 self.checkTable(spec, table) 

323 

    def testInsertQueryDelete(self):
        """Test the `Database.insert`, `Database.query`, and `Database.delete`
        methods, as well as the `Base64Region` type and the ``onDelete``
        argument to `ddl.ForeignKeySpec`.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert a single, non-autoincrement row that contains a region and
        # query to get it back.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row = {"name": "a1", "region": region}
        db.insert(tables.a, row)
        self.assertEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row])
        # Insert multiple autoincrement rows but do not try to get the IDs
        # back immediately.
        db.insert(tables.b, {"name": "b1", "value": 10}, {"name": "b2", "value": 20})
        results = [dict(r) for r in db.query(tables.b.select().order_by("id")).fetchall()]
        self.assertEqual(len(results), 2)
        for row in results:
            self.assertIn(row["name"], ("b1", "b2"))
            self.assertIsInstance(row["id"], int)
        self.assertGreater(results[1]["id"], results[0]["id"])
        # Insert multiple autoincrement rows and get the IDs back from insert.
        rows = [{"name": "b3", "value": 30}, {"name": "b4", "value": 40}]
        ids = db.insert(tables.b, *rows, returnIds=True)
        results = [
            dict(r) for r in db.query(
                tables.b.select().where(tables.b.columns.id > results[1]["id"])
            ).fetchall()
        ]
        expected = [dict(row, id=id) for row, id in zip(rows, ids)]
        self.assertCountEqual(results, expected)
        self.assertTrue(all(result["id"] is not None for result in results))
        # Insert multiple rows into a table with an autoincrement+origin
        # primary key, then use the returned IDs to insert into a dynamic
        # table.
        rows = [{"origin": db.origin, "b_id": results[0]["id"]},
                {"origin": db.origin, "b_id": None}]
        ids = db.insert(tables.c, *rows, returnIds=True)
        results = [dict(r) for r in db.query(tables.c.select()).fetchall()]
        expected = [dict(row, id=id) for row, id in zip(rows, ids)]
        self.assertCountEqual(results, expected)
        self.assertTrue(all(result["id"] is not None for result in results))
        # Add the dynamic table.
        d = db.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        # Insert into it.
        rows = [{"c_origin": db.origin, "c_id": id, "a_name": "a1"} for id in ids]
        db.insert(d, *rows)
        results = [dict(r) for r in db.query(d.select()).fetchall()]
        self.assertCountEqual(rows, results)
        # Insert multiple rows into a table with an autoincrement+origin
        # primary key (this is especially tricky for SQLite, but good to test
        # for all DBs), but pass in a value for the autoincrement key.
        # For extra complexity, we re-use the autoincrement value with a
        # different value for origin.
        rows2 = [{"id": 700, "origin": db.origin, "b_id": None},
                 {"id": 700, "origin": 60, "b_id": None},
                 {"id": 1, "origin": 60, "b_id": None}]
        db.insert(tables.c, *rows2)
        results = [dict(r) for r in db.query(tables.c.select()).fetchall()]
        self.assertCountEqual(results, expected + rows2)
        self.assertTrue(all(result["id"] is not None for result in results))

        # Define 'SELECT COUNT(*)' query for later use.
        count = sqlalchemy.sql.select([sqlalchemy.sql.func.count()])
        # Get the values we inserted into table b.
        bValues = [dict(r) for r in db.query(tables.b.select()).fetchall()]
        # Remove two rows from table b by ID.
        n = db.delete(tables.b, ["id"], {"id": bValues[0]["id"]}, {"id": bValues[1]["id"]})
        self.assertEqual(n, 2)
        # Remove the other two rows from table b by name.
        n = db.delete(tables.b, ["name"], {"name": bValues[2]["name"]}, {"name": bValues[3]["name"]})
        self.assertEqual(n, 2)
        # There should now be no rows in table b.
        self.assertEqual(
            db.query(count.select_from(tables.b)).scalar(),
            0
        )
        # All b_id values in table c should now be NULL, because there's an
        # onDelete='SET NULL' foreign key.
        self.assertEqual(
            db.query(count.select_from(tables.c).where(tables.c.columns.b_id != None)).scalar(),  # noqa:E711
            0
        )
        # Remove all rows in table a (there's only one); this should remove
        # all rows in d due to onDelete='CASCADE'.
        n = db.delete(tables.a, [])
        self.assertEqual(n, 1)
        self.assertEqual(db.query(count.select_from(tables.a)).scalar(), 0)
        self.assertEqual(db.query(count.select_from(d)).scalar(), 0)

415 

416 def testUpdate(self): 

417 """Tests for `Database.update`. 

418 """ 

419 db = self.makeEmptyDatabase(origin=1) 

420 with db.declareStaticTables(create=True) as context: 

421 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

422 # Insert two rows into table a, both without regions. 

423 db.insert(tables.a, {"name": "a1"}, {"name": "a2"}) 

424 # Update one of the rows with a region. 

425 region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1))) 

426 n = db.update(tables.a, {"name": "k"}, {"k": "a2", "region": region}) 

427 self.assertEqual(n, 1) 

428 sql = sqlalchemy.sql.select([tables.a.columns.name, tables.a.columns.region]).select_from(tables.a) 

429 self.assertCountEqual( 

430 [dict(r) for r in db.query(sql).fetchall()], 

431 [{"name": "a1", "region": None}, {"name": "a2", "region": region}] 

432 ) 

433 

    def testSync(self):
        """Tests for `Database.sync`.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert a row with sync, because it doesn't exist yet.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10}, returning=["id"])
        self.assertTrue(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [dict(r) for r in db.query(tables.b.select()).fetchall()])
        # Repeat that operation, which should do nothing but return the
        # requested values.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10}, returning=["id"])
        self.assertFalse(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [dict(r) for r in db.query(tables.b.select()).fetchall()])
        # Repeat the operation without the 'extra' arg, which should also
        # just return the existing row.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, returning=["id"])
        self.assertFalse(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [dict(r) for r in db.query(tables.b.select()).fetchall()])
        # Repeat the operation with a different value in 'extra'.  That still
        # shouldn't be an error, because 'extra' is only used if we really do
        # insert.  Also drop the 'returning' argument.
        _, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 20})
        self.assertFalse(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [dict(r) for r in db.query(tables.b.select()).fetchall()])
        # Repeat the operation with the correct value in 'compared' instead
        # of 'extra'.
        _, inserted = db.sync(tables.b, keys={"name": "b1"}, compared={"value": 10})
        self.assertFalse(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [dict(r) for r in db.query(tables.b.select()).fetchall()])
        # Repeat the operation with an incorrect value in 'compared'; this
        # should raise.
        with self.assertRaises(DatabaseConflictError):
            db.sync(tables.b, keys={"name": "b1"}, compared={"value": 20})
        # Try to sync in a read-only database.  This should work if and only
        # if the matching row already exists.
        with self.asReadOnly(db) as rodb:
            with rodb.declareStaticTables(create=False) as context:
                tables = context.addTableTuple(STATIC_TABLE_SPECS)
            _, inserted = rodb.sync(tables.b, keys={"name": "b1"})
            self.assertFalse(inserted)
            self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                             [dict(r) for r in rodb.query(tables.b.select()).fetchall()])
            with self.assertRaises(ReadOnlyDatabaseError):
                rodb.sync(tables.b, keys={"name": "b2"}, extra={"value": 20})

485 

486 def testReplace(self): 

487 """Tests for `Database.replace`. 

488 """ 

489 db = self.makeEmptyDatabase(origin=1) 

490 with db.declareStaticTables(create=True) as context: 

491 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

492 # Use 'replace' to insert a single row that contains a region and 

493 # query to get it back. 

494 region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1))) 

495 row1 = {"name": "a1", "region": region} 

496 db.replace(tables.a, row1) 

497 self.assertEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1]) 

498 # Insert another row without a region. 

499 row2 = {"name": "a2", "region": None} 

500 db.replace(tables.a, row2) 

501 self.assertCountEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1, row2]) 

502 # Use replace to re-insert both of those rows again, which should do 

503 # nothing. 

504 db.replace(tables.a, row1, row2) 

505 self.assertCountEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1, row2]) 

506 # Replace row1 with a row with no region, while reinserting row2. 

507 row1a = {"name": "a1", "region": None} 

508 db.replace(tables.a, row1a, row2) 

509 self.assertCountEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1a, row2]) 

510 # Replace both rows, returning row1 to its original state, while adding 

511 # a new one. Pass them in in a different order. 

512 row2a = {"name": "a2", "region": region} 

513 row3 = {"name": "a3", "region": None} 

514 db.replace(tables.a, row3, row2a, row1) 

515 self.assertCountEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1, row2a, row3]) 

516 

517 def testEnsure(self): 

518 """Tests for `Database.ensure`. 

519 """ 

520 db = self.makeEmptyDatabase(origin=1) 

521 with db.declareStaticTables(create=True) as context: 

522 tables = context.addTableTuple(STATIC_TABLE_SPECS) 

523 # Use 'ensure' to insert a single row that contains a region and 

524 # query to get it back. 

525 region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1))) 

526 row1 = {"name": "a1", "region": region} 

527 self.assertEqual(db.ensure(tables.a, row1), 1) 

528 self.assertEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1]) 

529 # Insert another row without a region. 

530 row2 = {"name": "a2", "region": None} 

531 self.assertEqual(db.ensure(tables.a, row2), 1) 

532 self.assertCountEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1, row2]) 

533 # Use ensure to re-insert both of those rows again, which should do 

534 # nothing. 

535 self.assertEqual(db.ensure(tables.a, row1, row2), 0) 

536 self.assertCountEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1, row2]) 

537 # Attempt to insert row1's key with no region, while 

538 # reinserting row2. This should also do nothing. 

539 row1a = {"name": "a1", "region": None} 

540 self.assertEqual(db.ensure(tables.a, row1a, row2), 0) 

541 self.assertCountEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1, row2]) 

542 # Attempt to insert new rows for both existing keys, this time also 

543 # adding a new row. Pass them in in a different order. Only the new 

544 # row should be added. 

545 row2a = {"name": "a2", "region": region} 

546 row3 = {"name": "a3", "region": None} 

547 self.assertEqual(db.ensure(tables.a, row3, row2a, row1a), 1) 

548 self.assertCountEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1, row2, row3]) 

549 

    def testTransactionNesting(self):
        """Test that transactions can be nested with the behavior in the
        presence of exceptions working as documented.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert one row so we can trigger integrity errors by trying to
        # insert a duplicate of it below.
        db.insert(tables.a, {"name": "a1"})
        # First test: error recovery via explicit savepoint=True in the inner
        # transaction.
        with db.transaction():
            # This insert should succeed, and should not be rolled back
            # because the assertRaises context should catch any exception
            # before it propagates up to the outer transaction.
            db.insert(tables.a, {"name": "a2"})
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                with db.transaction(savepoint=True):
                    # This insert should succeed, but should be rolled back.
                    db.insert(tables.a, {"name": "a4"})
                    # This insert should fail (duplicate primary key),
                    # raising an exception.
                    db.insert(tables.a, {"name": "a1"})
        # Only a1 and a2 survive: the savepoint rolled back a4.
        self.assertCountEqual(
            [dict(r) for r in db.query(tables.a.select()).fetchall()],
            [{"name": "a1", "region": None}, {"name": "a2", "region": None}],
        )
        # Second test: error recovery via implicit savepoint=True, when the
        # innermost transaction is inside a savepoint=True transaction.
        with db.transaction():
            # This insert should succeed, and should not be rolled back
            # because the assertRaises context should catch any
            # exception before it propagates up to the outer
            # transaction.
            db.insert(tables.a, {"name": "a3"})
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                with db.transaction(savepoint=True):
                    # This insert should succeed, but should be rolled back.
                    db.insert(tables.a, {"name": "a4"})
                    with db.transaction():
                        # This insert should succeed, but should be rolled
                        # back.
                        db.insert(tables.a, {"name": "a5"})
                        # This insert should fail (duplicate primary key),
                        # raising an exception.
                        db.insert(tables.a, {"name": "a1"})
        # Only a3 was added: a4 and a5 were rolled back with the savepoint.
        self.assertCountEqual(
            [dict(r) for r in db.query(tables.a.select()).fetchall()],
            [{"name": "a1", "region": None}, {"name": "a2", "region": None}, {"name": "a3", "region": None}],
        )

601 

602 def testTransactionLocking(self): 

603 """Test that `Database.transaction` can be used to acquire a lock 

604 that prohibits concurrent writes. 

605 """ 

606 db1 = self.makeEmptyDatabase(origin=1) 

607 with db1.declareStaticTables(create=True) as context: 

608 tables1 = context.addTableTuple(STATIC_TABLE_SPECS) 

609 

610 async def side1(lock: Iterable[str] = ()) -> Tuple[Set[str], Set[str]]: 

611 """One side of the concurrent locking test. 

612 

613 This optionally locks the table (and maybe the whole database), 

614 does a select for its contents, inserts a new row, and then selects 

615 again, with some waiting in between to make sure the other side has 

616 a chance to _attempt_ to insert in between. If the locking is 

617 enabled and works, the difference between the selects should just 

618 be the insert done on this thread. 

619 """ 

620 # Give Side2 a chance to create a connection 

621 await asyncio.sleep(1.0) 

622 with db1.transaction(lock=lock): 

623 names1 = {row["name"] for row in db1.query(tables1.a.select()).fetchall()} 

624 # Give Side2 a chance to insert (which will be blocked if 

625 # we've acquired a lock). 

626 await asyncio.sleep(2.0) 

627 db1.insert(tables1.a, {"name": "a1"}) 

628 names2 = {row["name"] for row in db1.query(tables1.a.select()).fetchall()} 

629 return names1, names2 

630 

631 async def side2() -> None: 

632 """The other side of the concurrent locking test. 

633 

634 This side just waits a bit and then tries to insert a row into the 

635 table that the other side is trying to lock. Hopefully that 

636 waiting is enough to give the other side a chance to acquire the 

637 lock and thus make this side block until the lock is released. If 

638 this side manages to do the insert before side1 acquires the lock, 

639 we'll just warn about not succeeding at testing the locking, 

640 because we can only make that unlikely, not impossible. 

641 """ 

642 def toRunInThread(): 

643 """SQLite locking isn't asyncio-friendly unless we actually 

644 run it in another thread. And SQLite gets very unhappy if 

645 we try to use a connection from multiple threads, so we have 

646 to create the new connection here instead of out in the main 

647 body of the test function. 

648 """ 

649 db2 = self.getNewConnection(db1, writeable=True) 

650 with db2.declareStaticTables(create=False) as context: 

651 tables2 = context.addTableTuple(STATIC_TABLE_SPECS) 

652 with db2.transaction(): 

653 db2.insert(tables2.a, {"name": "a2"}) 

654 

655 await asyncio.sleep(2.0) 

656 loop = asyncio.get_running_loop() 

657 with ThreadPoolExecutor() as pool: 

658 await loop.run_in_executor(pool, toRunInThread) 

659 

660 async def testProblemsWithNoLocking() -> None: 

661 """Run side1 and side2 with no locking, attempting to demonstrate 

662 the problem that locking is supposed to solve. If we get unlucky 

663 with scheduling, side2 will just happen to insert after side1 is 

664 done, and we won't have anything definitive. We just warn in that 

665 case because we really don't want spurious test failures. 

666 """ 

667 task1 = asyncio.create_task(side1())  # default: no tables locked 

668 task2 = asyncio.create_task(side2()) 

669 

670 names1, names2 = await task1  # side1's two SELECT snapshots (before/after its sleep) 

671 await task2 

672 if "a2" in names1: 

673 warnings.warn("Unlucky scheduling in no-locking test: concurrent INSERT " 

674 "happened before first SELECT.") 

675 self.assertEqual(names1, {"a2"}) 

676 self.assertEqual(names2, {"a1", "a2"}) 

677 elif "a2" not in names2: 

678 warnings.warn("Unlucky scheduling in no-locking test: concurrent INSERT " 

679 "happened after second SELECT even without locking.") 

680 self.assertEqual(names1, set()) 

681 self.assertEqual(names2, {"a1"}) 

682 else: 

683 # This is the expected case: both INSERTS happen between the 

684 # two SELECTS. If we don't get this almost all of the time we 

685 # should adjust the sleep amounts. 

686 self.assertEqual(names1, set()) 

687 self.assertEqual(names2, {"a1", "a2"}) 

688 

689 asyncio.run(testProblemsWithNoLocking()) 

690 

691 # Clean up after first test. 

692 db1.delete(tables1.a, ["name"], {"name": "a1"}, {"name": "a2"})  # drop both rows so the locking test starts clean 

693 

694 async def testSolutionWithLocking() -> None: 

695 """Run side1 and side2 with locking, which should make side2 block 

696 its insert until side1 releases its lock. 

697 """ 

698 task1 = asyncio.create_task(side1(lock=[tables1.a]))  # side1 locks table 'a' this time 

699 task2 = asyncio.create_task(side2()) 

700 

701 names1, names2 = await task1 

702 await task2 

703 if "a2" in names1: 

704 warnings.warn("Unlucky scheduling in locking test: concurrent INSERT " 

705 "happened before first SELECT.") 

706 self.assertEqual(names1, {"a2"}) 

707 self.assertEqual(names2, {"a1", "a2"}) 

708 else: 

709 # This is the expected case: the side2 INSERT happens after the 

710 # last SELECT on side1. This can also happen due to unlucky 

711 # scheduling, and we have no way to detect that here, but the 

712 # similar "no-locking" test has at least some chance of being 

713 # affected by the same problem and warning about it. 

714 self.assertEqual(names1, set()) 

715 self.assertEqual(names2, {"a1"}) 

716 

717 asyncio.run(testSolutionWithLocking())  # drive the locking variant to completion 

718 

719 def testDatabaseTimespanRepresentation(self): 

720 """Tests for `DatabaseTimespanRepresentation` and the `Database` 

721 methods that interact with it. 

722 """ 

723 # Make some test timespans to play with, with the full suite of 

724 # topological relationships. 

725 start = astropy.time.Time('2020-01-01T00:00:00', format="isot", scale="tai") 

726 offset = astropy.time.TimeDelta(60, format="sec") 

727 timestamps = [start + offset*n for n in range(3)] 

728 aTimespans = [Timespan(begin=None, end=None)]  # unbounded on both sides 

729 aTimespans.extend(Timespan(begin=None, end=t) for t in timestamps) 

730 aTimespans.extend(Timespan(begin=t, end=None) for t in timestamps) 

731 aTimespans.extend(Timespan(begin=t1, end=t2) for t1, t2 in itertools.combinations(timestamps, 2)) 

732 # Make another list of timespans that span the full range but don't 

733 # overlap. This is a subset of the previous list. 

734 bTimespans = [Timespan(begin=None, end=timestamps[0])] 

735 bTimespans.extend(Timespan(begin=t1, end=t2) for t1, t2 in zip(timestamps[:-1], timestamps[1:])) 

736 bTimespans.append(Timespan(begin=timestamps[-1], end=None)) 

737 # Make a database and create a table with that database's timespan 

738 # representation. This one will have no exclusion constraint and 

739 # a nullable timespan. 

740 db = self.makeEmptyDatabase(origin=1) 

741 tsRepr = db.getTimespanRepresentation() 

742 aSpec = ddl.TableSpec( 

743 fields=[ 

744 ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True), 

745 ], 

746 ) 

747 for fieldSpec in tsRepr.makeFieldSpecs(nullable=True): 

748 aSpec.fields.add(fieldSpec) 

749 with db.declareStaticTables(create=True) as context: 

750 aTable = context.addTable("a", aSpec) 

751 

752 def convertRowForInsert(row: dict) -> dict: 

753 """Convert a row containing a Timespan instance into one suitable 

754 for insertion into the database. 

755 """ 

756 result = row.copy() 

757 ts = result.pop(tsRepr.NAME) 

758 return tsRepr.update(ts, result=result)  # update() folds the timespan's column values into 'result' and returns it 

759 

760 def convertRowFromSelect(row: dict) -> dict: 

761 """Convert a row from the database into one containing a Timespan. 

762 """ 

763 result = row.copy() 

764 timespan = tsRepr.extract(result) 

765 for name in tsRepr.getFieldNames(): 

766 del result[name]  # drop the raw column(s) now folded into 'timespan' 

767 result[tsRepr.NAME] = timespan 

768 return result 

769 

770 # Insert rows into table A, in chunks just to make things interesting. 

771 # Include one with a NULL timespan. 

772 aRows = [{"id": n, tsRepr.NAME: t} for n, t in enumerate(aTimespans)] 

773 aRows.append({"id": len(aRows), tsRepr.NAME: None}) 

774 db.insert(aTable, convertRowForInsert(aRows[0])) 

775 db.insert(aTable, *[convertRowForInsert(r) for r in aRows[1:3]]) 

776 db.insert(aTable, *[convertRowForInsert(r) for r in aRows[3:]]) 

777 # Add another one with a NULL timespan, but this time by invoking 

778 # the server-side default. 

779 aRows.append({"id": len(aRows)}) 

780 db.insert(aTable, aRows[-1])  # deliberately not converted: timespan columns omitted to trigger the default 

781 aRows[-1][tsRepr.NAME] = None  # update the in-memory expectation to match 

782 # Test basic round-trip through database. 

783 self.assertEqual( 

784 aRows, 

785 [convertRowFromSelect(dict(row)) 

786 for row in db.query(aTable.select().order_by(aTable.columns.id)).fetchall()] 

787 ) 

788 # Create another table B with a not-null timespan and (if the database 

789 # supports it), an exclusion constraint. Use ensureTableExists this 

790 # time to check that mode of table creation vs. timespans. 

791 bSpec = ddl.TableSpec( 

792 fields=[ 

793 ddl.FieldSpec(name="id", dtype=sqlalchemy.Integer, primaryKey=True), 

794 ddl.FieldSpec(name="key", dtype=sqlalchemy.Integer, nullable=False), 

795 ], 

796 ) 

797 for fieldSpec in tsRepr.makeFieldSpecs(nullable=False): 

798 bSpec.fields.add(fieldSpec) 

799 if tsRepr.hasExclusionConstraint(): 

800 bSpec.exclusion.add(("key", tsRepr)) 

801 bTable = db.ensureTableExists("b", bSpec) 

802 # Insert rows into table B, again in chunks. Each Timespan appears 

803 # twice, but with different values for the 'key' field (which should 

804 # still be okay for any exclusion constraint we may have defined). 

805 bRows = [{"id": n, "key": 1, tsRepr.NAME: t} for n, t in enumerate(bTimespans)] 

806 offset = len(bRows) 

807 bRows.extend({"id": n + offset, "key": 2, tsRepr.NAME: t} for n, t in enumerate(bTimespans)) 

808 db.insert(bTable, *[convertRowForInsert(r) for r in bRows[:2]]) 

809 db.insert(bTable, convertRowForInsert(bRows[2])) 

810 db.insert(bTable, *[convertRowForInsert(r) for r in bRows[3:]]) 

811 # Insert a row with no timespan into table B. This should invoke the 

812 # server-side default, which is a timespan over (-∞, ∞). We set 

813 # key=3 to avoid upsetting an exclusion constraint that might exist. 

814 bRows.append({"id": len(bRows), "key": 3}) 

815 db.insert(bTable, bRows[-1]) 

816 bRows[-1][tsRepr.NAME] = Timespan(None, None)  # the expected (-∞, ∞) default, per the comment above 

817 # Test basic round-trip through database. 

818 self.assertEqual( 

819 bRows, 

820 [convertRowFromSelect(dict(row)) 

821 for row in db.query(bTable.select().order_by(bTable.columns.id)).fetchall()] 

822 ) 

823 # Test that we can't insert timespan=None into this table. 

824 with self.assertRaises(sqlalchemy.exc.IntegrityError): 

825 db.insert(bTable, convertRowForInsert({"id": len(bRows), "key": 4, tsRepr.NAME: None})) 

826 # IFF this database supports exclusion constraints, test that they 

827 # also prevent inserts. 

828 if tsRepr.hasExclusionConstraint(): 

829 with self.assertRaises(sqlalchemy.exc.IntegrityError): 

830 db.insert(bTable, convertRowForInsert({"id": len(bRows), "key": 1, 

831 tsRepr.NAME: Timespan(None, timestamps[1])})) 

832 with self.assertRaises(sqlalchemy.exc.IntegrityError): 

833 db.insert(bTable, convertRowForInsert({"id": len(bRows), "key": 1, 

834 tsRepr.NAME: Timespan(timestamps[0], timestamps[2])})) 

835 with self.assertRaises(sqlalchemy.exc.IntegrityError): 

836 db.insert(bTable, convertRowForInsert({"id": len(bRows), "key": 1, 

837 tsRepr.NAME: Timespan(timestamps[2], None)})) 

838 # Test NULL checks in SELECT queries, on both tables. 

839 aRepr = tsRepr.fromSelectable(aTable) 

840 self.assertEqual( 

841 [row[tsRepr.NAME] is None for row in aRows], 

842 [ 

843 row["f"] for row in db.query( 

844 sqlalchemy.sql.select( 

845 [aRepr.isNull().label("f")] 

846 ).order_by( 

847 aTable.columns.id 

848 ) 

849 ).fetchall() 

850 ] 

851 ) 

852 bRepr = tsRepr.fromSelectable(bTable) 

853 self.assertEqual( 

854 [False for row in bRows],  # table B's timespan fields are NOT NULL, so no row can be null 

855 [ 

856 row["f"] for row in db.query( 

857 sqlalchemy.sql.select( 

858 [bRepr.isNull().label("f")] 

859 ).order_by( 

860 bTable.columns.id 

861 ) 

862 ).fetchall() 

863 ] 

864 ) 

865 # Test overlap expressions that relate in-database A timespans to 

866 # Python-literal B timespans; check that this is consistent with 

867 # Python-only overlap tests. 

868 for bRow in bRows: 

869 with self.subTest(bRow=bRow): 

870 expected = {} 

871 for aRow in aRows: 

872 if aRow[tsRepr.NAME] is None: 

873 expected[aRow["id"]] = None 

874 else: 

875 expected[aRow["id"]] = aRow[tsRepr.NAME].overlaps(bRow[tsRepr.NAME]) 

876 sql = sqlalchemy.sql.select( 

877 [aTable.columns.id.label("a"), aRepr.overlaps(bRow[tsRepr.NAME]).label("f")] 

878 ).select_from(aTable) 

879 queried = {row["a"]: row["f"] for row in db.query(sql).fetchall()} 

880 self.assertEqual(expected, queried) 

881 # Test overlap expressions that relate in-database A timespans to 

882 # in-database B timespans; check that this is consistent with 

883 # Python-only overlap tests. 

884 expected = { 

885 (aRow["id"], bRow["id"]): (aRow[tsRepr.NAME].overlaps(bRow[tsRepr.NAME]) 

886 if aRow[tsRepr.NAME] is not None else None) 

887 for aRow, bRow in itertools.product(aRows, bRows) 

888 } 

889 sql = sqlalchemy.sql.select( 

890 [ 

891 aTable.columns.id.label("a"), 

892 bTable.columns.id.label("b"), 

893 aRepr.overlaps(bRepr).label("f") 

894 ] 

895 ).select_from(aTable.join(bTable, onclause=sqlalchemy.sql.literal(True)))  # ON TRUE: a full cross join of A and B 

896 queried = {(row["a"], row["b"]): row["f"] for row in db.query(sql).fetchall()} 

897 self.assertEqual(expected, queried)