# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

__all__ = ["DatabaseTests"]

from abc import ABC, abstractmethod
from collections import namedtuple
from typing import ContextManager

import sqlalchemy

from lsst.sphgeom import ConvexPolygon, UnitVector3d
from ..interfaces import (
    Database,
    ReadOnlyDatabaseError,
    DatabaseConflictError,
    SchemaAlreadyDefinedError
)
from ...core import ddl

StaticTablesTuple = namedtuple("StaticTablesTuple", ["a", "b", "c"])

STATIC_TABLE_SPECS = StaticTablesTuple(
    a=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("name", dtype=sqlalchemy.String, length=16, primaryKey=True),
            ddl.FieldSpec("region", dtype=ddl.Base64Region, nbytes=128),
        ]
    ),
    b=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("id", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True),
            ddl.FieldSpec("name", dtype=sqlalchemy.String, length=16, nullable=False),
            ddl.FieldSpec("value", dtype=sqlalchemy.SmallInteger, nullable=True),
        ],
        unique=[("name",)],
    ),
    c=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("id", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True),
            ddl.FieldSpec("origin", dtype=sqlalchemy.BigInteger, primaryKey=True),
            ddl.FieldSpec("b_id", dtype=sqlalchemy.BigInteger, nullable=True),
        ],
        foreignKeys=[
            ddl.ForeignKeySpec("b", source=("b_id",), target=("id",), onDelete="SET NULL"),
        ]
    ),
)

DYNAMIC_TABLE_SPEC = ddl.TableSpec(
    fields=[
        ddl.FieldSpec("c_id", dtype=sqlalchemy.BigInteger, primaryKey=True),
        ddl.FieldSpec("c_origin", dtype=sqlalchemy.BigInteger, primaryKey=True),
        ddl.FieldSpec("a_name", dtype=sqlalchemy.String, length=16, nullable=False),
    ],
    foreignKeys=[
        ddl.ForeignKeySpec("c", source=("c_id", "c_origin"), target=("id", "origin"), onDelete="CASCADE"),
        ddl.ForeignKeySpec("a", source=("a_name",), target=("name",), onDelete="CASCADE"),
    ]
)

class DatabaseTests(ABC):
    """Generic tests for the `Database` interface that can be subclassed to
    generate tests for concrete implementations.
    """

    @abstractmethod
    def makeEmptyDatabase(self, origin: int = 0) -> Database:
        """Return an empty `Database` with the given origin, or an
        automatically-generated one if ``origin`` is `None`.
        """
        raise NotImplementedError()

    @abstractmethod
    def asReadOnly(self, database: Database) -> ContextManager[Database]:
        """Return a context manager for a read-only connection into the given
        database.

        The original database should be considered unusable within the context
        but safe to use again afterwards (this allows the context manager to
        block write access by temporarily changing user permissions to really
        guarantee that write operations are not performed).
        """
        raise NotImplementedError()

    @abstractmethod
    def getNewConnection(self, database: Database, *, writeable: bool) -> Database:
        """Return a new `Database` instance that points to the same underlying
        storage as the given one.
        """
        raise NotImplementedError()

    def checkTable(self, spec: ddl.TableSpec, table: sqlalchemy.schema.Table):
        self.assertCountEqual(spec.fields.names, table.columns.keys())
        # Checking more than this currently seems fragile, as it might
        # restrict what Database implementations do; we don't care if the
        # spec is actually preserved in terms of types and constraints as
        # long as we can use the returned table as if it were.

    def checkStaticSchema(self, tables: StaticTablesTuple):
        self.checkTable(STATIC_TABLE_SPECS.a, tables.a)
        self.checkTable(STATIC_TABLE_SPECS.b, tables.b)
        self.checkTable(STATIC_TABLE_SPECS.c, tables.c)

    def testDeclareStaticTables(self):
        """Tests for `Database.declareStaticTables` and the methods it
        delegates to.
        """
        # Create the static schema in a new, empty database.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        self.checkStaticSchema(tables)
        # Check that we can load that schema even from a read-only connection.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                tables = context.addTableTuple(STATIC_TABLE_SPECS)
            self.checkStaticSchema(tables)

    def testDeclareStaticTablesTwice(self):
        """Tests for `Database.declareStaticTables` being called twice.
        """
        # Create the static schema in a new, empty database.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        self.checkStaticSchema(tables)
        # Calling it a second time with create=True should raise.
        with self.assertRaises(SchemaAlreadyDefinedError):
            with newDatabase.declareStaticTables(create=True) as context:
                tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Check the schema: it should still contain all the declared tables,
        # and possibly some extras.
        with newDatabase.declareStaticTables(create=False) as context:
            self.assertLessEqual(frozenset(STATIC_TABLE_SPECS._fields), context._tableNames)

    def testRepr(self):
        """Test that repr returns something more informative than the generic
        default.
        """
        newDatabase = self.makeEmptyDatabase()
        rep = repr(newDatabase)
        # Check that stringification works and gives us something different.
        self.assertNotEqual(rep, str(newDatabase))
        self.assertNotIn("object at 0x", rep, "Check default repr was not used")
        self.assertIn("://", rep)

    def testDynamicTables(self):
        """Tests for `Database.ensureTableExists` and
        `Database.getExistingTable`.
        """
        # Need to start with the static schema.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            context.addTableTuple(STATIC_TABLE_SPECS)
        # Try to ensure the dynamic table exists in a read-only version of
        # that database, which should fail because we can't create it.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                context.addTableTuple(STATIC_TABLE_SPECS)
            with self.assertRaises(ReadOnlyDatabaseError):
                existingReadOnlyDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        # Just getting the dynamic table before it exists should return None.
        self.assertIsNone(newDatabase.getExistingTable("d", DYNAMIC_TABLE_SPEC))
        # Ensure the new table exists back in the original database, which
        # should create it.
        table = newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        self.checkTable(DYNAMIC_TABLE_SPEC, table)
        # Ensuring that it exists should just return the exact same table
        # instance again.
        self.assertIs(newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC), table)
        # Try again from the read-only database.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                context.addTableTuple(STATIC_TABLE_SPECS)
            # Just getting the dynamic table should now work...
            self.assertIsNotNone(existingReadOnlyDatabase.getExistingTable("d", DYNAMIC_TABLE_SPEC))
            # ...as should ensuring that it exists, since it now does.
            existingReadOnlyDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
            self.checkTable(DYNAMIC_TABLE_SPEC, table)
        # Trying to get the table with a different specification (at least
        # in terms of what columns are present) should raise.
        with self.assertRaises(DatabaseConflictError):
            newDatabase.ensureTableExists(
                "d",
                ddl.TableSpec(
                    fields=[ddl.FieldSpec("name", dtype=sqlalchemy.String, length=4, primaryKey=True)]
                )
            )
        # Calling ensureTableExists inside a transaction block is an error,
        # even if it would do nothing.
        with newDatabase.transaction():
            with self.assertRaises(AssertionError):
                newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)

    def testSchemaSeparation(self):
        """Test that creating two different `Database` instances allows us
        to create different tables with the same name in each.
        """
        db1 = self.makeEmptyDatabase(origin=1)
        with db1.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        self.checkStaticSchema(tables)

        db2 = self.makeEmptyDatabase(origin=2)
        # Make the DDL here intentionally different so we'll definitely
        # notice if db1 and db2 are pointing at the same schema.
        spec = ddl.TableSpec(fields=[ddl.FieldSpec("id", dtype=sqlalchemy.Integer, primaryKey=True)])
        with db2.declareStaticTables(create=True) as context:
            table = context.addTable("a", spec)
        self.checkTable(spec, table)

    def testInsertQueryDelete(self):
        """Test the `Database.insert`, `Database.query`, and `Database.delete`
        methods, as well as the `Base64Region` type and the ``onDelete``
        argument to `ddl.ForeignKeySpec`.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert a single, non-autoincrement row that contains a region and
        # query to get it back.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row = {"name": "a1", "region": region}
        db.insert(tables.a, row)
        self.assertEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row])
        # Insert multiple autoincrement rows but do not try to get the IDs
        # back immediately.
        db.insert(tables.b, {"name": "b1", "value": 10}, {"name": "b2", "value": 20})
        results = [dict(r) for r in db.query(tables.b.select().order_by("id")).fetchall()]
        self.assertEqual(len(results), 2)
        for row in results:
            self.assertIn(row["name"], ("b1", "b2"))
            self.assertIsInstance(row["id"], int)
        self.assertGreater(results[1]["id"], results[0]["id"])
        # Insert multiple autoincrement rows and get the IDs back from insert.
        rows = [{"name": "b3", "value": 30}, {"name": "b4", "value": 40}]
        ids = db.insert(tables.b, *rows, returnIds=True)
        results = [
            dict(r) for r in db.query(
                tables.b.select().where(tables.b.columns.id > results[1]["id"])
            ).fetchall()
        ]
        expected = [dict(row, id=id) for row, id in zip(rows, ids)]
        self.assertCountEqual(results, expected)
        self.assertTrue(all(result["id"] is not None for result in results))
        # Insert multiple rows into a table with an autoincrement+origin
        # primary key, then use the returned IDs to insert into a dynamic
        # table.
        rows = [{"origin": db.origin, "b_id": results[0]["id"]},
                {"origin": db.origin, "b_id": None}]
        ids = db.insert(tables.c, *rows, returnIds=True)
        results = [dict(r) for r in db.query(tables.c.select()).fetchall()]
        expected = [dict(row, id=id) for row, id in zip(rows, ids)]
        self.assertCountEqual(results, expected)
        self.assertTrue(all(result["id"] is not None for result in results))
        # Add the dynamic table.
        d = db.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        # Insert into it.
        rows = [{"c_origin": db.origin, "c_id": id, "a_name": "a1"} for id in ids]
        db.insert(d, *rows)
        results = [dict(r) for r in db.query(d.select()).fetchall()]
        self.assertCountEqual(rows, results)
        # Insert multiple rows into a table with an autoincrement+origin
        # primary key (this is especially tricky for SQLite, but good to test
        # for all DBs), but pass in a value for the autoincrement key.
        # For extra complexity, we re-use the autoincrement value with a
        # different value for origin.
        rows2 = [{"id": 700, "origin": db.origin, "b_id": None},
                 {"id": 700, "origin": 60, "b_id": None},
                 {"id": 1, "origin": 60, "b_id": None}]
        db.insert(tables.c, *rows2)
        results = [dict(r) for r in db.query(tables.c.select()).fetchall()]
        self.assertCountEqual(results, expected + rows2)
        self.assertTrue(all(result["id"] is not None for result in results))

        # Define 'SELECT COUNT(*)' query for later use.
        count = sqlalchemy.sql.select([sqlalchemy.sql.func.count()])
        # Get the values we inserted into table b.
        bValues = [dict(r) for r in db.query(tables.b.select()).fetchall()]
        # Remove two rows from table b by ID.
        n = db.delete(tables.b, ["id"], {"id": bValues[0]["id"]}, {"id": bValues[1]["id"]})
        self.assertEqual(n, 2)
        # Remove the other two rows from table b by name.
        n = db.delete(tables.b, ["name"], {"name": bValues[2]["name"]}, {"name": bValues[3]["name"]})
        self.assertEqual(n, 2)
        # There should now be no rows in table b.
        self.assertEqual(
            db.query(count.select_from(tables.b)).scalar(),
            0
        )
        # All b_id values in table c should now be NULL, because there's an
        # onDelete='SET NULL' foreign key.
        self.assertEqual(
            db.query(count.select_from(tables.c).where(tables.c.columns.b_id != None)).scalar(),  # noqa:E711
            0
        )
        # Remove all rows in table a (there's only one); this should remove
        # all rows in d due to onDelete='CASCADE'.
        n = db.delete(tables.a, [])
        self.assertEqual(n, 1)
        self.assertEqual(db.query(count.select_from(tables.a)).scalar(), 0)
        self.assertEqual(db.query(count.select_from(d)).scalar(), 0)

    def testUpdate(self):
        """Tests for `Database.update`.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert two rows into table a, both without regions.
        db.insert(tables.a, {"name": "a1"}, {"name": "a2"})
        # Update one of the rows with a region. The {"name": "k"} argument
        # maps the column used to find the row ("name") to the key ("k") that
        # holds its value in the row dict that follows, so this sets a region
        # on the row whose name is "a2".
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        n = db.update(tables.a, {"name": "k"}, {"k": "a2", "region": region})
        self.assertEqual(n, 1)
        sql = sqlalchemy.sql.select([tables.a.columns.name, tables.a.columns.region]).select_from(tables.a)
        self.assertCountEqual(
            [dict(r) for r in db.query(sql).fetchall()],
            [{"name": "a1", "region": None}, {"name": "a2", "region": region}]
        )

    def testSync(self):
        """Tests for `Database.sync`.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert a row with sync, because it doesn't exist yet.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10}, returning=["id"])
        self.assertTrue(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [dict(r) for r in db.query(tables.b.select()).fetchall()])
        # Repeat that operation, which should do nothing but return the
        # requested values.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10}, returning=["id"])
        self.assertFalse(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [dict(r) for r in db.query(tables.b.select()).fetchall()])
        # Repeat the operation without the 'extra' arg, which should also just
        # return the existing row.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, returning=["id"])
        self.assertFalse(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [dict(r) for r in db.query(tables.b.select()).fetchall()])
        # Repeat the operation with a different value in 'extra'. That still
        # shouldn't be an error, because 'extra' is only used if we really do
        # insert. Also drop the 'returning' argument.
        _, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 20})
        self.assertFalse(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [dict(r) for r in db.query(tables.b.select()).fetchall()])
        # Repeat the operation with the correct value in 'compared' instead of
        # 'extra'.
        _, inserted = db.sync(tables.b, keys={"name": "b1"}, compared={"value": 10})
        self.assertFalse(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [dict(r) for r in db.query(tables.b.select()).fetchall()])
        # Repeat the operation with an incorrect value in 'compared'; this
        # should raise.
        with self.assertRaises(DatabaseConflictError):
            db.sync(tables.b, keys={"name": "b1"}, compared={"value": 20})
        # Try to sync inside a transaction. That's always an error, regardless
        # of whether there would be an insertion or not.
        with self.assertRaises(AssertionError):
            with db.transaction():
                db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10})
        with self.assertRaises(AssertionError):
            with db.transaction():
                db.sync(tables.b, keys={"name": "b2"}, extra={"value": 20})
        # Try to sync in a read-only database. This should work if and only
        # if the matching row already exists.
        with self.asReadOnly(db) as rodb:
            with rodb.declareStaticTables(create=False) as context:
                tables = context.addTableTuple(STATIC_TABLE_SPECS)
            _, inserted = rodb.sync(tables.b, keys={"name": "b1"})
            self.assertFalse(inserted)
            self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                             [dict(r) for r in rodb.query(tables.b.select()).fetchall()])
            with self.assertRaises(ReadOnlyDatabaseError):
                rodb.sync(tables.b, keys={"name": "b2"}, extra={"value": 20})

    def testReplace(self):
        """Tests for `Database.replace`.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Use 'replace' to insert a single row that contains a region and
        # query to get it back.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row1 = {"name": "a1", "region": region}
        db.replace(tables.a, row1)
        self.assertEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1])
        # Insert another row without a region.
        row2 = {"name": "a2", "region": None}
        db.replace(tables.a, row2)
        self.assertCountEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1, row2])
        # Use replace to re-insert both of those rows again, which should do
        # nothing.
        db.replace(tables.a, row1, row2)
        self.assertCountEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1, row2])
        # Replace row1 with a row with no region, while reinserting row2.
        row1a = {"name": "a1", "region": None}
        db.replace(tables.a, row1a, row2)
        self.assertCountEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1a, row2])
        # Replace both rows, returning row1 to its original state, while
        # adding a new one. Pass them in a different order.
        row2a = {"name": "a2", "region": region}
        row3 = {"name": "a3", "region": None}
        db.replace(tables.a, row3, row2a, row1)
        self.assertCountEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1, row2a, row3])