Coverage for python/lsst/daf/butler/registry/tests/_database.py : 10%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ["DatabaseTests"]
25from abc import ABC, abstractmethod
26import asyncio
27from collections import namedtuple
28from concurrent.futures import ThreadPoolExecutor
29from typing import ContextManager, Iterable, Set, Tuple
30import warnings
32import sqlalchemy
34from lsst.sphgeom import ConvexPolygon, UnitVector3d
35from ..interfaces import (
36 Database,
37 ReadOnlyDatabaseError,
38 DatabaseConflictError,
39 SchemaAlreadyDefinedError
40)
41from ...core import ddl
# Named container for the three static tables used throughout these tests.
StaticTablesTuple = namedtuple("StaticTablesTuple", ["a", "b", "c"])

# Static schema shared by most tests:
#  - "a": string primary key plus a nullable Base64-encoded region.
#  - "b": autoincrement primary key, unique non-null "name", nullable "value".
#  - "c": compound (autoincrement "id", "origin") primary key and a nullable
#    foreign key into "b" whose value is nulled out when the "b" row is
#    deleted (onDelete="SET NULL").
STATIC_TABLE_SPECS = StaticTablesTuple(
    a=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("name", dtype=sqlalchemy.String, length=16, primaryKey=True),
            ddl.FieldSpec("region", dtype=ddl.Base64Region, nbytes=128),
        ]
    ),
    b=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("id", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True),
            ddl.FieldSpec("name", dtype=sqlalchemy.String, length=16, nullable=False),
            ddl.FieldSpec("value", dtype=sqlalchemy.SmallInteger, nullable=True),
        ],
        unique=[("name",)],
    ),
    c=ddl.TableSpec(
        fields=[
            ddl.FieldSpec("id", dtype=sqlalchemy.BigInteger, autoincrement=True, primaryKey=True),
            ddl.FieldSpec("origin", dtype=sqlalchemy.BigInteger, primaryKey=True),
            ddl.FieldSpec("b_id", dtype=sqlalchemy.BigInteger, nullable=True),
        ],
        foreignKeys=[
            ddl.ForeignKeySpec("b", source=("b_id",), target=("id",), onDelete="SET NULL"),
        ]
    ),
)

# Spec for a table created after the static schema, with cascading foreign
# keys into both "c" (compound key) and "a"; deleting parent rows removes
# matching rows here (onDelete="CASCADE").
DYNAMIC_TABLE_SPEC = ddl.TableSpec(
    fields=[
        ddl.FieldSpec("c_id", dtype=sqlalchemy.BigInteger, primaryKey=True),
        ddl.FieldSpec("c_origin", dtype=sqlalchemy.BigInteger, primaryKey=True),
        ddl.FieldSpec("a_name", dtype=sqlalchemy.String, length=16, nullable=False),
    ],
    foreignKeys=[
        ddl.ForeignKeySpec("c", source=("c_id", "c_origin"), target=("id", "origin"), onDelete="CASCADE"),
        ddl.ForeignKeySpec("a", source=("c_id", "c_origin"), target=("id", "origin"), onDelete="CASCADE"),
    ]
) if False else ddl.TableSpec(
    fields=[
        ddl.FieldSpec("c_id", dtype=sqlalchemy.BigInteger, primaryKey=True),
        ddl.FieldSpec("c_origin", dtype=sqlalchemy.BigInteger, primaryKey=True),
        ddl.FieldSpec("a_name", dtype=sqlalchemy.String, length=16, nullable=False),
    ],
    foreignKeys=[
        ddl.ForeignKeySpec("c", source=("c_id", "c_origin"), target=("id", "origin"), onDelete="CASCADE"),
        ddl.ForeignKeySpec("a", source=("a_name",), target=("name",), onDelete="CASCADE"),
    ]
)

# Spec for a temporary (session-scoped) table; intentionally has no foreign
# keys so it can be created even on read-only connections.
TEMPORARY_TABLE_SPEC = ddl.TableSpec(
    fields=[
        ddl.FieldSpec("a_name", dtype=sqlalchemy.String, length=16, primaryKey=True),
        ddl.FieldSpec("b_id", dtype=sqlalchemy.BigInteger, primaryKey=True),
    ],
)
92class DatabaseTests(ABC):
93 """Generic tests for the `Database` interface that can be subclassed to
94 generate tests for concrete implementations.
95 """
    @abstractmethod
    def makeEmptyDatabase(self, origin: int = 0) -> Database:
        """Return an empty `Database` with the given origin, or an
        automatically-generated one if ``origin`` is `None`.

        Parameters
        ----------
        origin : `int`, optional
            Origin identifier for the new database.

        Returns
        -------
        database : `Database`
            A newly-created, empty database.

        NOTE(review): the signature declares ``origin: int = 0`` while the
        summary mentions `None`; confirm whether implementations must also
        accept ``origin=None``.
        """
        raise NotImplementedError()
    @abstractmethod
    def asReadOnly(self, database: Database) -> ContextManager[Database]:
        """Return a context manager for a read-only connection into the given
        database.

        The original database should be considered unusable within the context
        but safe to use again afterwards (this allows the context manager to
        block write access by temporarily changing user permissions to really
        guarantee that write operations are not performed).

        Parameters
        ----------
        database : `Database`
            The (writeable) database to provide read-only access to.

        Returns
        -------
        context : `ContextManager` [ `Database` ]
            Context manager yielding a read-only `Database` view.
        """
        raise NotImplementedError()
    @abstractmethod
    def getNewConnection(self, database: Database, *, writeable: bool) -> Database:
        """Return a new `Database` instance that points to the same underlying
        storage as the given one.

        Parameters
        ----------
        database : `Database`
            Existing database whose storage should be shared.
        writeable : `bool`
            Whether the new connection should permit writes.

        Returns
        -------
        database : `Database`
            New connection to the same underlying storage.
        """
        raise NotImplementedError()
123 def checkTable(self, spec: ddl.TableSpec, table: sqlalchemy.schema.Table):
124 self.assertCountEqual(spec.fields.names, table.columns.keys())
125 # Checking more than this currently seems fragile, as it might restrict
126 # what Database implementations do; we don't care if the spec is
127 # actually preserved in terms of types and constraints as long as we
128 # can use the returned table as if it was.
130 def checkStaticSchema(self, tables: StaticTablesTuple):
131 self.checkTable(STATIC_TABLE_SPECS.a, tables.a)
132 self.checkTable(STATIC_TABLE_SPECS.b, tables.b)
133 self.checkTable(STATIC_TABLE_SPECS.c, tables.c)
135 def testDeclareStaticTables(self):
136 """Tests for `Database.declareStaticSchema` and the methods it
137 delegates to.
138 """
139 # Create the static schema in a new, empty database.
140 newDatabase = self.makeEmptyDatabase()
141 with newDatabase.declareStaticTables(create=True) as context:
142 tables = context.addTableTuple(STATIC_TABLE_SPECS)
143 self.checkStaticSchema(tables)
144 # Check that we can load that schema even from a read-only connection.
145 with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
146 with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
147 tables = context.addTableTuple(STATIC_TABLE_SPECS)
148 self.checkStaticSchema(tables)
150 def testDeclareStaticTablesTwice(self):
151 """Tests for `Database.declareStaticSchema` being called twice.
152 """
153 # Create the static schema in a new, empty database.
154 newDatabase = self.makeEmptyDatabase()
155 with newDatabase.declareStaticTables(create=True) as context:
156 tables = context.addTableTuple(STATIC_TABLE_SPECS)
157 self.checkStaticSchema(tables)
158 # Second time it should raise
159 with self.assertRaises(SchemaAlreadyDefinedError):
160 with newDatabase.declareStaticTables(create=True) as context:
161 tables = context.addTableTuple(STATIC_TABLE_SPECS)
162 # Check schema, it should still contain all tables, and maybe some
163 # extra.
164 with newDatabase.declareStaticTables(create=False) as context:
165 self.assertLessEqual(frozenset(STATIC_TABLE_SPECS._fields), context._tableNames)
167 def testRepr(self):
168 """Test that repr does not return a generic thing."""
169 newDatabase = self.makeEmptyDatabase()
170 rep = repr(newDatabase)
171 # Check that stringification works and gives us something different
172 self.assertNotEqual(rep, str(newDatabase))
173 self.assertNotIn("object at 0x", rep, "Check default repr was not used")
174 self.assertIn("://", rep)
    def testDynamicTables(self):
        """Tests for `Database.ensureTableExists` and
        `Database.getExistingTable`.
        """
        # Need to start with the static schema.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            context.addTableTuple(STATIC_TABLE_SPECS)
        # Try to ensure the dynamic table exists in a read-only version of
        # that database, which should fail because we can't create it.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                context.addTableTuple(STATIC_TABLE_SPECS)
            with self.assertRaises(ReadOnlyDatabaseError):
                existingReadOnlyDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        # Just getting the dynamic table before it exists should return None.
        self.assertIsNone(newDatabase.getExistingTable("d", DYNAMIC_TABLE_SPEC))
        # Ensure the new table exists back in the original database, which
        # should create it.
        table = newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        self.checkTable(DYNAMIC_TABLE_SPEC, table)
        # Ensuring that it exists should just return the exact same table
        # instance again.
        self.assertIs(newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC), table)
        # Try again from the read-only database.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                context.addTableTuple(STATIC_TABLE_SPECS)
            # Just getting the dynamic table should now work...
            self.assertIsNotNone(existingReadOnlyDatabase.getExistingTable("d", DYNAMIC_TABLE_SPEC))
            # ...as should ensuring that it exists, since it now does.
            existingReadOnlyDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
            self.checkTable(DYNAMIC_TABLE_SPEC, table)
        # Trying to get the table with a different specification (at least
        # in terms of what columns are present) should raise.
        with self.assertRaises(DatabaseConflictError):
            newDatabase.ensureTableExists(
                "d",
                ddl.TableSpec(
                    fields=[ddl.FieldSpec("name", dtype=sqlalchemy.String, length=4, primaryKey=True)]
                )
            )
        # Calling ensureTableExists inside a transaction block is an error,
        # even if it would do nothing.
        with newDatabase.transaction():
            with self.assertRaises(AssertionError):
                newDatabase.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
    def testTemporaryTables(self):
        """Tests for `Database.makeTemporaryTable`,
        `Database.dropTemporaryTable`, and `Database.insert` with
        the ``select`` argument.
        """
        # Need to start with the static schema; also insert some test data.
        newDatabase = self.makeEmptyDatabase()
        with newDatabase.declareStaticTables(create=True) as context:
            static = context.addTableTuple(STATIC_TABLE_SPECS)
        newDatabase.insert(static.a,
                           {"name": "a1", "region": None},
                           {"name": "a2", "region": None})
        bIds = newDatabase.insert(static.b,
                                  {"name": "b1", "value": 11},
                                  {"name": "b2", "value": 12},
                                  {"name": "b3", "value": 13},
                                  returnIds=True)
        # Create the table.
        table1 = newDatabase.makeTemporaryTable(TEMPORARY_TABLE_SPEC, "e1")
        self.checkTable(TEMPORARY_TABLE_SPEC, table1)
        # Insert via an INSERT INTO ... SELECT query.  The join with a
        # literal-True onclause is a cross join of "a" and "b".
        newDatabase.insert(
            table1,
            select=sqlalchemy.sql.select(
                [static.a.columns.name.label("a_name"), static.b.columns.id.label("b_id")]
            ).select_from(
                static.a.join(static.b, onclause=sqlalchemy.sql.literal(True))
            ).where(
                sqlalchemy.sql.and_(
                    static.a.columns.name == "a1",
                    static.b.columns.value <= 12,
                )
            )
        )
        # Check that the inserted rows are present.
        self.assertCountEqual(
            [{"a_name": "a1", "b_id": bId} for bId in bIds[:2]],
            [dict(row) for row in newDatabase.query(table1.select())]
        )
        # Create another one via a read-only connection to the database.
        # We _do_ allow temporary table modifications in read-only databases.
        with self.asReadOnly(newDatabase) as existingReadOnlyDatabase:
            with existingReadOnlyDatabase.declareStaticTables(create=False) as context:
                context.addTableTuple(STATIC_TABLE_SPECS)
            table2 = existingReadOnlyDatabase.makeTemporaryTable(TEMPORARY_TABLE_SPEC)
            self.checkTable(TEMPORARY_TABLE_SPEC, table2)
            # Those tables should not be the same, despite having the same
            # ddl.
            self.assertIsNot(table1, table2)
            # Do a slightly different insert into this table, to check that
            # it works in a read-only database.  This time we pass column
            # names as a kwarg to insert instead of by labeling the columns in
            # the select.
            existingReadOnlyDatabase.insert(
                table2,
                select=sqlalchemy.sql.select(
                    [static.a.columns.name, static.b.columns.id]
                ).select_from(
                    static.a.join(static.b, onclause=sqlalchemy.sql.literal(True))
                ).where(
                    sqlalchemy.sql.and_(
                        static.a.columns.name == "a2",
                        static.b.columns.value >= 12,
                    )
                ),
                names=["a_name", "b_id"],
            )
            # Check that the inserted rows are present.
            self.assertCountEqual(
                [{"a_name": "a2", "b_id": bId} for bId in bIds[1:]],
                [dict(row) for row in existingReadOnlyDatabase.query(table2.select())]
            )
            # Drop the temporary table from the read-only DB.  It's
            # unspecified whether attempting to use it after this point is an
            # error or just never returns any results, so we can't test what
            # it does, only that it's not an error.
            existingReadOnlyDatabase.dropTemporaryTable(table2)
        # Drop the original temporary table.
        newDatabase.dropTemporaryTable(table1)
303 def testSchemaSeparation(self):
304 """Test that creating two different `Database` instances allows us
305 to create different tables with the same name in each.
306 """
307 db1 = self.makeEmptyDatabase(origin=1)
308 with db1.declareStaticTables(create=True) as context:
309 tables = context.addTableTuple(STATIC_TABLE_SPECS)
310 self.checkStaticSchema(tables)
312 db2 = self.makeEmptyDatabase(origin=2)
313 # Make the DDL here intentionally different so we'll definitely
314 # notice if db1 and db2 are pointing at the same schema.
315 spec = ddl.TableSpec(fields=[ddl.FieldSpec("id", dtype=sqlalchemy.Integer, primaryKey=True)])
316 with db2.declareStaticTables(create=True) as context:
317 # Make the DDL here intentionally different so we'll definitely
318 # notice if db1 and db2 are pointing at the same schema.
319 table = context.addTable("a", spec)
320 self.checkTable(spec, table)
    def testInsertQueryDelete(self):
        """Test the `Database.insert`, `Database.query`, and `Database.delete`
        methods, as well as the `Base64Region` type and the ``onDelete``
        argument to `ddl.ForeignKeySpec`.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert a single, non-autoincrement row that contains a region and
        # query to get it back.
        region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
        row = {"name": "a1", "region": region}
        db.insert(tables.a, row)
        self.assertEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row])
        # Insert multiple autoincrement rows but do not try to get the IDs
        # back immediately.
        db.insert(tables.b, {"name": "b1", "value": 10}, {"name": "b2", "value": 20})
        results = [dict(r) for r in db.query(tables.b.select().order_by("id")).fetchall()]
        self.assertEqual(len(results), 2)
        for row in results:
            self.assertIn(row["name"], ("b1", "b2"))
            self.assertIsInstance(row["id"], int)
        self.assertGreater(results[1]["id"], results[0]["id"])
        # Insert multiple autoincrement rows and get the IDs back from insert.
        rows = [{"name": "b3", "value": 30}, {"name": "b4", "value": 40}]
        ids = db.insert(tables.b, *rows, returnIds=True)
        results = [
            dict(r) for r in db.query(
                tables.b.select().where(tables.b.columns.id > results[1]["id"])
            ).fetchall()
        ]
        expected = [dict(row, id=id) for row, id in zip(rows, ids)]
        self.assertCountEqual(results, expected)
        self.assertTrue(all(result["id"] is not None for result in results))
        # Insert multiple rows into a table with an autoincrement+origin
        # primary key, then use the returned IDs to insert into a dynamic
        # table.
        rows = [{"origin": db.origin, "b_id": results[0]["id"]},
                {"origin": db.origin, "b_id": None}]
        ids = db.insert(tables.c, *rows, returnIds=True)
        results = [dict(r) for r in db.query(tables.c.select()).fetchall()]
        expected = [dict(row, id=id) for row, id in zip(rows, ids)]
        self.assertCountEqual(results, expected)
        self.assertTrue(all(result["id"] is not None for result in results))
        # Add the dynamic table.
        d = db.ensureTableExists("d", DYNAMIC_TABLE_SPEC)
        # Insert into it.
        rows = [{"c_origin": db.origin, "c_id": id, "a_name": "a1"} for id in ids]
        db.insert(d, *rows)
        results = [dict(r) for r in db.query(d.select()).fetchall()]
        self.assertCountEqual(rows, results)
        # Insert multiple rows into a table with an autoincrement+origin
        # primary key (this is especially tricky for SQLite, but good to test
        # for all DBs), but pass in a value for the autoincrement key.
        # For extra complexity, we re-use the autoincrement value with a
        # different value for origin.
        rows2 = [{"id": 700, "origin": db.origin, "b_id": None},
                 {"id": 700, "origin": 60, "b_id": None},
                 {"id": 1, "origin": 60, "b_id": None}]
        db.insert(tables.c, *rows2)
        results = [dict(r) for r in db.query(tables.c.select()).fetchall()]
        self.assertCountEqual(results, expected + rows2)
        self.assertTrue(all(result["id"] is not None for result in results))

        # Define 'SELECT COUNT(*)' query for later use.
        count = sqlalchemy.sql.select([sqlalchemy.sql.func.count()])
        # Get the values we inserted into table b.
        bValues = [dict(r) for r in db.query(tables.b.select()).fetchall()]
        # Remove two rows from table b by ID.
        n = db.delete(tables.b, ["id"], {"id": bValues[0]["id"]}, {"id": bValues[1]["id"]})
        self.assertEqual(n, 2)
        # Remove the other two rows from table b by name.
        n = db.delete(tables.b, ["name"], {"name": bValues[2]["name"]}, {"name": bValues[3]["name"]})
        self.assertEqual(n, 2)
        # There should now be no rows in table b.
        self.assertEqual(
            db.query(count.select_from(tables.b)).scalar(),
            0
        )
        # All b_id values in table c should now be NULL, because there's an
        # onDelete='SET NULL' foreign key.
        self.assertEqual(
            db.query(count.select_from(tables.c).where(tables.c.columns.b_id != None)).scalar(),  # noqa:E711
            0
        )
        # Remove all rows in table a (there's only one); this should remove
        # all rows in d due to onDelete='CASCADE'.
        n = db.delete(tables.a, [])
        self.assertEqual(n, 1)
        self.assertEqual(db.query(count.select_from(tables.a)).scalar(), 0)
        self.assertEqual(db.query(count.select_from(d)).scalar(), 0)
414 def testUpdate(self):
415 """Tests for `Database.update`.
416 """
417 db = self.makeEmptyDatabase(origin=1)
418 with db.declareStaticTables(create=True) as context:
419 tables = context.addTableTuple(STATIC_TABLE_SPECS)
420 # Insert two rows into table a, both without regions.
421 db.insert(tables.a, {"name": "a1"}, {"name": "a2"})
422 # Update one of the rows with a region.
423 region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
424 n = db.update(tables.a, {"name": "k"}, {"k": "a2", "region": region})
425 self.assertEqual(n, 1)
426 sql = sqlalchemy.sql.select([tables.a.columns.name, tables.a.columns.region]).select_from(tables.a)
427 self.assertCountEqual(
428 [dict(r) for r in db.query(sql).fetchall()],
429 [{"name": "a1", "region": None}, {"name": "a2", "region": region}]
430 )
    def testSync(self):
        """Tests for `Database.sync` (get-or-insert with optional
        compare-and-raise semantics).
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert a row with sync, because it doesn't exist yet.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10}, returning=["id"])
        self.assertTrue(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [dict(r) for r in db.query(tables.b.select()).fetchall()])
        # Repeat that operation, which should do nothing but return the
        # requested values.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10}, returning=["id"])
        self.assertFalse(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [dict(r) for r in db.query(tables.b.select()).fetchall()])
        # Repeat the operation without the 'extra' arg, which should also just
        # return the existing row.
        values, inserted = db.sync(tables.b, keys={"name": "b1"}, returning=["id"])
        self.assertFalse(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [dict(r) for r in db.query(tables.b.select()).fetchall()])
        # Repeat the operation with a different value in 'extra'.  That still
        # shouldn't be an error, because 'extra' is only used if we really do
        # insert.  Also drop the 'returning' argument.
        _, inserted = db.sync(tables.b, keys={"name": "b1"}, extra={"value": 20})
        self.assertFalse(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [dict(r) for r in db.query(tables.b.select()).fetchall()])
        # Repeat the operation with the correct value in 'compared' instead of
        # 'extra'.
        _, inserted = db.sync(tables.b, keys={"name": "b1"}, compared={"value": 10})
        self.assertFalse(inserted)
        self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                         [dict(r) for r in db.query(tables.b.select()).fetchall()])
        # Repeat the operation with an incorrect value in 'compared'; this
        # should raise.
        with self.assertRaises(DatabaseConflictError):
            db.sync(tables.b, keys={"name": "b1"}, compared={"value": 20})
        # Try to sync inside a transaction.  That's always an error,
        # regardless of whether there would be an insertion or not.
        with self.assertRaises(AssertionError):
            with db.transaction():
                db.sync(tables.b, keys={"name": "b1"}, extra={"value": 10})
        with self.assertRaises(AssertionError):
            with db.transaction():
                db.sync(tables.b, keys={"name": "b2"}, extra={"value": 20})
        # Try to sync in a read-only database.  This should work if and only
        # if the matching row already exists.
        with self.asReadOnly(db) as rodb:
            with rodb.declareStaticTables(create=False) as context:
                tables = context.addTableTuple(STATIC_TABLE_SPECS)
            _, inserted = rodb.sync(tables.b, keys={"name": "b1"})
            self.assertFalse(inserted)
            self.assertEqual([{"id": values["id"], "name": "b1", "value": 10}],
                             [dict(r) for r in rodb.query(tables.b.select()).fetchall()])
            with self.assertRaises(ReadOnlyDatabaseError):
                rodb.sync(tables.b, keys={"name": "b2"}, extra={"value": 20})
492 def testReplace(self):
493 """Tests for `Database.replace`.
494 """
495 db = self.makeEmptyDatabase(origin=1)
496 with db.declareStaticTables(create=True) as context:
497 tables = context.addTableTuple(STATIC_TABLE_SPECS)
498 # Use 'replace' to insert a single row that contains a region and
499 # query to get it back.
500 region = ConvexPolygon((UnitVector3d(1, 0, 0), UnitVector3d(0, 1, 0), UnitVector3d(0, 0, 1)))
501 row1 = {"name": "a1", "region": region}
502 db.replace(tables.a, row1)
503 self.assertEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1])
504 # Insert another row without a region.
505 row2 = {"name": "a2", "region": None}
506 db.replace(tables.a, row2)
507 self.assertCountEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1, row2])
508 # Use replace to re-insert both of those rows again, which should do
509 # nothing.
510 db.replace(tables.a, row1, row2)
511 self.assertCountEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1, row2])
512 # Replace row1 with a row with no region, while reinserting row2.
513 row1a = {"name": "a1", "region": None}
514 db.replace(tables.a, row1a, row2)
515 self.assertCountEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1a, row2])
516 # Replace both rows, returning row1 to its original state, while adding
517 # a new one. Pass them in in a different order.
518 row2a = {"name": "a2", "region": region}
519 row3 = {"name": "a3", "region": None}
520 db.replace(tables.a, row3, row2a, row1)
521 self.assertCountEqual([dict(r) for r in db.query(tables.a.select()).fetchall()], [row1, row2a, row3])
    def testTransactionNesting(self):
        """Test that transactions can be nested with the behavior in the
        presence of exceptions working as documented.
        """
        db = self.makeEmptyDatabase(origin=1)
        with db.declareStaticTables(create=True) as context:
            tables = context.addTableTuple(STATIC_TABLE_SPECS)
        # Insert one row so we can trigger integrity errors by trying to
        # insert a duplicate of it below.
        db.insert(tables.a, {"name": "a1"})
        # First test: error recovery via explicit savepoint=True in the inner
        # transaction.
        with db.transaction():
            # This insert should succeed, and should not be rolled back
            # because the assertRaises context should catch any exception
            # before it propagates up to the outer transaction.
            db.insert(tables.a, {"name": "a2"})
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                with db.transaction(savepoint=True):
                    # This insert should succeed, but should be rolled back.
                    db.insert(tables.a, {"name": "a4"})
                    # This insert should fail (duplicate primary key), raising
                    # an exception.
                    db.insert(tables.a, {"name": "a1"})
        self.assertCountEqual(
            [dict(r) for r in db.query(tables.a.select()).fetchall()],
            [{"name": "a1", "region": None}, {"name": "a2", "region": None}],
        )
        # Second test: error recovery via implicit savepoint=True, when the
        # innermost transaction is inside a savepoint=True transaction.
        with db.transaction():
            # This insert should succeed, and should not be rolled back
            # because the assertRaises context should catch any
            # exception before it propagates up to the outer
            # transaction.
            db.insert(tables.a, {"name": "a3"})
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                with db.transaction(savepoint=True):
                    # This insert should succeed, but should be rolled back.
                    db.insert(tables.a, {"name": "a4"})
                    with db.transaction():
                        # This insert should succeed, but should be rolled
                        # back.
                        db.insert(tables.a, {"name": "a5"})
                        # This insert should fail (duplicate primary key),
                        # raising an exception.
                        db.insert(tables.a, {"name": "a1"})
        self.assertCountEqual(
            [dict(r) for r in db.query(tables.a.select()).fetchall()],
            [{"name": "a1", "region": None}, {"name": "a2", "region": None}, {"name": "a3", "region": None}],
        )
    def testTransactionLocking(self):
        """Test that `Database.transaction` can be used to acquire a lock
        that prohibits concurrent writes.

        The test is inherently timing-based: the sleeps below only make the
        interesting interleavings likely, not guaranteed, so inconclusive
        schedules produce warnings rather than failures.
        """
        db1 = self.makeEmptyDatabase(origin=1)
        with db1.declareStaticTables(create=True) as context:
            tables1 = context.addTableTuple(STATIC_TABLE_SPECS)

        async def side1(lock: Iterable[str] = ()) -> Tuple[Set[str], Set[str]]:
            """One side of the concurrent locking test.

            This optionally locks the table (and maybe the whole database),
            does a select for its contents, inserts a new row, and then
            selects again, with some waiting in between to make sure the
            other side has a chance to _attempt_ to insert in between.  If
            the locking is enabled and works, the difference between the
            selects should just be the insert done on this thread.
            """
            # Give Side2 a chance to create a connection
            await asyncio.sleep(1.0)
            with db1.transaction(lock=lock):
                names1 = {row["name"] for row in db1.query(tables1.a.select()).fetchall()}
                # Give Side2 a chance to insert (which will be blocked if
                # we've acquired a lock).
                await asyncio.sleep(2.0)
                db1.insert(tables1.a, {"name": "a1"})
                names2 = {row["name"] for row in db1.query(tables1.a.select()).fetchall()}
                return names1, names2

        async def side2() -> None:
            """The other side of the concurrent locking test.

            This side just waits a bit and then tries to insert a row into
            the table that the other side is trying to lock.  Hopefully that
            waiting is enough to give the other side a chance to acquire the
            lock and thus make this side block until the lock is released.
            If this side manages to do the insert before side1 acquires the
            lock, we'll just warn about not succeeding at testing the
            locking, because we can only make that unlikely, not impossible.
            """
            def toRunInThread():
                """SQLite locking isn't asyncio-friendly unless we actually
                run it in another thread.  And SQLite gets very unhappy if
                we try to use a connection from multiple threads, so we have
                to create the new connection here instead of out in the main
                body of the test function.
                """
                db2 = self.getNewConnection(db1, writeable=True)
                with db2.declareStaticTables(create=False) as context:
                    tables2 = context.addTableTuple(STATIC_TABLE_SPECS)
                with db2.transaction():
                    db2.insert(tables2.a, {"name": "a2"})

            await asyncio.sleep(2.0)
            loop = asyncio.get_running_loop()
            with ThreadPoolExecutor() as pool:
                await loop.run_in_executor(pool, toRunInThread)

        async def testProblemsWithNoLocking() -> None:
            """Run side1 and side2 with no locking, attempting to demonstrate
            the problem that locking is supposed to solve.  If we get unlucky
            with scheduling, side2 will just happen to insert after side1 is
            done, and we won't have anything definitive.  We just warn in
            that case because we really don't want spurious test failures.
            """
            task1 = asyncio.create_task(side1())
            task2 = asyncio.create_task(side2())
            names1, names2 = await task1
            await task2
            if "a2" in names1:
                warnings.warn("Unlucky scheduling in no-locking test: concurrent INSERT "
                              "happened before first SELECT.")
                self.assertEqual(names1, {"a2"})
                self.assertEqual(names2, {"a1", "a2"})
            elif "a2" not in names2:
                warnings.warn("Unlucky scheduling in no-locking test: concurrent INSERT "
                              "happened after second SELECT even without locking.")
                self.assertEqual(names1, set())
                self.assertEqual(names2, {"a1"})
            else:
                # This is the expected case: both INSERTS happen between the
                # two SELECTS.  If we don't get this almost all of the time
                # we should adjust the sleep amounts.
                self.assertEqual(names1, set())
                self.assertEqual(names2, {"a1", "a2"})

        asyncio.run(testProblemsWithNoLocking())
        # Clean up after first test.
        db1.delete(tables1.a, ["name"], {"name": "a1"}, {"name": "a2"})

        async def testSolutionWithLocking() -> None:
            """Run side1 and side2 with locking, which should make side2
            block its insert until side1 releases its lock.
            """
            task1 = asyncio.create_task(side1(lock=[tables1.a]))
            task2 = asyncio.create_task(side2())
            names1, names2 = await task1
            await task2
            if "a2" in names1:
                warnings.warn("Unlucky scheduling in locking test: concurrent INSERT "
                              "happened before first SELECT.")
                self.assertEqual(names1, {"a2"})
                self.assertEqual(names2, {"a1", "a2"})
            else:
                # This is the expected case: the side2 INSERT happens after
                # the last SELECT on side1.  This can also happen due to
                # unlucky scheduling, and we have no way to detect that here,
                # but the similar "no-locking" test has at least some chance
                # of being affected by the same problem and warning about it.
                self.assertEqual(names1, set())
                self.assertEqual(names2, {"a1"})

        asyncio.run(testSolutionWithLocking())