Coverage for python/lsst/daf/butler/registry/collections/_base.py: 97%
238 statements
coverage.py v7.5.0, created at 2024-04-30 02:52 -0700
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
27from __future__ import annotations
29from ... import ddl
31__all__ = ()
33from abc import abstractmethod
34from collections.abc import Callable, Iterable, Iterator, Set
35from contextlib import contextmanager
36from typing import TYPE_CHECKING, Any, Generic, Literal, NamedTuple, TypeVar, cast
38import sqlalchemy
40from ..._exceptions import CollectionCycleError, CollectionTypeError, MissingCollectionError
41from ...timespan_database_representation import TimespanDatabaseRepresentation
42from .._collection_type import CollectionType
43from ..interfaces import ChainedCollectionRecord, CollectionManager, CollectionRecord, RunRecord, VersionTuple
44from ..wildcards import CollectionWildcard
46if TYPE_CHECKING:
47 from .._caching_context import CachingContext
48 from ..interfaces import Database
51def _makeCollectionForeignKey(
52 sourceColumnName: str, collectionIdName: str, **kwargs: Any
53) -> ddl.ForeignKeySpec:
54 """Define foreign key specification that refers to collections table.
56 Parameters
57 ----------
58 sourceColumnName : `str`
59 Name of the column in the referring table.
60 collectionIdName : `str`
61 Name of the column in collections table that identifies it (PK).
62 **kwargs
63 Additional keyword arguments passed directly to `ddl.ForeignKeySpec`.
65 Returns
66 -------
67 spec : `ddl.ForeignKeySpec`
68 Foreign key specification.
70 Notes
71 -----
72 This function assumes a fixed name ("collection") for the collections table.
73 There is also a general assumption that the collection primary key consists
74 of a single column.
75 """
76 return ddl.ForeignKeySpec("collection", source=(sourceColumnName,), target=(collectionIdName,), **kwargs)
79_T = TypeVar("_T")
82class CollectionTablesTuple(NamedTuple, Generic[_T]):
83 collection: _T
84 run: _T
85 collection_chain: _T
88def makeRunTableSpec(
89 collectionIdName: str, collectionIdType: type, TimespanReprClass: type[TimespanDatabaseRepresentation]
90) -> ddl.TableSpec:
91 """Define specification for "run" table.
93 Parameters
94 ----------
95 collectionIdName : `str`
96 Name of the column in collections table that identifies it (PK).
97 collectionIdType : `type`
98 Type of the PK column in the collections table, one of the
99 `sqlalchemy` types.
100 TimespanReprClass : `type` [ `TimespanDatabaseRepresentation` ]
101 Subclass of `TimespanDatabaseRepresentation` that encapsulates how
102 timespans are stored in this database.
104 Returns
105 -------
106 spec : `ddl.TableSpec`
107 Specification for run table.
109 Notes
110 -----
111 The assumption here and in the code below is that the name of the identifying
112 column is the same in both the collections and run tables. The names of
113 non-identifying columns containing run metadata are fixed.
114 """
115 result = ddl.TableSpec(
116 fields=[
117 ddl.FieldSpec(collectionIdName, dtype=collectionIdType, primaryKey=True),
118 ddl.FieldSpec("host", dtype=sqlalchemy.String, length=128),
119 ],
120 foreignKeys=[
121 _makeCollectionForeignKey(collectionIdName, collectionIdName, onDelete="CASCADE"),
122 ],
123 )
124 for fieldSpec in TimespanReprClass.makeFieldSpecs(nullable=True):
125 result.fields.add(fieldSpec)
126 return result
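# As an illustration of the helper above, a run table spec for a registry that
# keys collections by a 64-bit integer could be built roughly like this (the
# column name "collection_id" here is a hypothetical example):
#
#     timespan_repr = db.getTimespanRepresentation()
#     run_spec = makeRunTableSpec("collection_id", sqlalchemy.BigInteger, timespan_repr)
#
# The resulting spec contains the primary-key/foreign-key column, the "host"
# string column, and whatever timespan fields the representation class adds.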
129def makeCollectionChainTableSpec(collectionIdName: str, collectionIdType: type) -> ddl.TableSpec:
130 """Define specification for "collection_chain" table.
132 Parameters
133 ----------
134 collectionIdName : `str`
135 Name of the column in collections table that identifies it (PK).
136 collectionIdType : `type`
137 Type of the PK column in the collections table, one of the
138 `sqlalchemy` types.
140 Returns
141 -------
142 spec : `ddl.TableSpec`
143 Specification for collection chain table.
145 Notes
146 -----
147 A collection chain is simply an ordered one-to-many relation between
148 collections. The names of the columns in the table are fixed and
149 also hardcoded in the code below.
150 """
151 return ddl.TableSpec(
152 fields=[
153 ddl.FieldSpec("parent", dtype=collectionIdType, primaryKey=True),
154 ddl.FieldSpec("position", dtype=sqlalchemy.SmallInteger, primaryKey=True),
155 ddl.FieldSpec("child", dtype=collectionIdType, nullable=False),
156 ],
157 foreignKeys=[
158 _makeCollectionForeignKey("parent", collectionIdName, onDelete="CASCADE"),
159 _makeCollectionForeignKey("child", collectionIdName),
160 ],
161 )
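# For illustration, the ordered one-to-many relation described above means a
# chain "nightly" with children ["runA", "runB"] is stored as two rows in this
# table (the keys shown are hypothetical):
#
#     parent            position  child
#     <key of nightly>  0         <key of runA>
#     <key of nightly>  1         <key of runB>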
164K = TypeVar("K")
167class DefaultCollectionManager(CollectionManager[K]):
168 """Default `CollectionManager` implementation.
170 This implementation uses record classes defined in this module and is
171 based on the same assumptions about the schema outlined in those classes.
173 Parameters
174 ----------
175 db : `Database`
176 Interface to the underlying database engine and namespace.
177 tables : `CollectionTablesTuple`
178 Named tuple of SQLAlchemy table objects.
179 collectionIdName : `str`
180 Name of the column in collections table that identifies it (PK).
181 caching_context : `CachingContext`
182 Caching context to use.
183 registry_schema_version : `VersionTuple` or `None`, optional
184 The version of the registry schema.
186 Notes
187 -----
188 The implementation uses "aggressive" pre-fetching and caching of the records
189 in memory. The in-memory cache is synchronized from the database when the
190 `refresh` method is called.
191 """
193 def __init__(
194 self,
195 db: Database,
196 tables: CollectionTablesTuple[sqlalchemy.Table],
197 collectionIdName: str,
198 *,
199 caching_context: CachingContext,
200 registry_schema_version: VersionTuple | None = None,
201 ):
202 super().__init__(registry_schema_version=registry_schema_version)
203 self._db = db
204 self._tables = tables
205 self._collectionIdName = collectionIdName
206 self._caching_context = caching_context
208 def refresh(self) -> None:
209 # Docstring inherited from CollectionManager.
210 if self._caching_context.collection_records is not None:  [210 ↛ 211: condition was never true]
211 self._caching_context.collection_records.clear()
213 def _fetch_all(self) -> list[CollectionRecord[K]]:
214 """Retrieve all records into cache if not done so yet."""
215 if self._caching_context.collection_records is not None:
216 if self._caching_context.collection_records.full:
217 return list(self._caching_context.collection_records.records())
218 records = self._fetch_by_key(None)
219 if self._caching_context.collection_records is not None:
220 self._caching_context.collection_records.set(records, full=True)
221 return records
223 def register(
224 self, name: str, type: CollectionType, doc: str | None = None
225 ) -> tuple[CollectionRecord[K], bool]:
226 # Docstring inherited from CollectionManager.
227 registered = False
228 record = self._getByName(name)
229 if record is None:
230 row, inserted_or_updated = self._db.sync(
231 self._tables.collection,
232 keys={"name": name},
233 compared={"type": int(type)},
234 extra={"doc": doc},
235 returning=[self._collectionIdName],
236 )
237 assert isinstance(inserted_or_updated, bool)
238 registered = inserted_or_updated
239 assert row is not None
240 collection_id = cast(K, row[self._collectionIdName])
241 if type is CollectionType.RUN:
242 TimespanReprClass = self._db.getTimespanRepresentation()
243 row, _ = self._db.sync(
244 self._tables.run,
245 keys={self._collectionIdName: collection_id},
246 returning=("host",) + TimespanReprClass.getFieldNames(),
247 )
248 assert row is not None
249 record = RunRecord[K](
250 key=collection_id,
251 name=name,
252 host=row["host"],
253 timespan=TimespanReprClass.extract(row),
254 )
255 elif type is CollectionType.CHAINED:
256 record = ChainedCollectionRecord[K](
257 key=collection_id,
258 name=name,
259 children=[],
260 )
261 else:
262 record = CollectionRecord[K](key=collection_id, name=name, type=type)
263 self._addCachedRecord(record)
264 return record, registered
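# A hypothetical usage example of the method above: registering the same RUN
# collection twice is idempotent ("manager" stands in for a concrete
# DefaultCollectionManager instance):
#
#     record, registered = manager.register("u/someone/run1", CollectionType.RUN)
#     # registered is True on the first call and False afterwards; record is a
#     # RunRecord either way.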
266 def remove(self, name: str) -> None:
267 # Docstring inherited from CollectionManager.
268 record = self._getByName(name)
269 if record is None:  [269 ↛ 270: condition was never true]
270 raise MissingCollectionError(f"No collection with name '{name}' found.")
271 # This may raise, so do it before updating the cache.
272 self._db.delete(
273 self._tables.collection, [self._collectionIdName], {self._collectionIdName: record.key}
274 )
275 self._removeCachedRecord(record)
277 def find(self, name: str) -> CollectionRecord[K]:
278 # Docstring inherited from CollectionManager.
279 result = self._getByName(name)
280 if result is None:
281 raise MissingCollectionError(f"No collection with name '{name}' found.")
282 return result
284 def _find_many(self, names: Iterable[str]) -> list[CollectionRecord[K]]:
285 """Return multiple records given their names."""
286 names = list(names)
287 # To protect against potential races in cache updates.
288 records: dict[str, CollectionRecord | None] = {}
289 if self._caching_context.collection_records is not None:
290 for name in names:
291 records[name] = self._caching_context.collection_records.get_by_name(name)
292 fetch_names = [name for name, record in records.items() if record is None]
293 else:
294 fetch_names = list(names)
295 records = {name: None for name in fetch_names}
296 if fetch_names:
297 for record in self._fetch_by_name(fetch_names):
298 records[record.name] = record
299 self._addCachedRecord(record)
300 missing_names = [name for name, record in records.items() if record is None]
301 if len(missing_names) == 1:
302 raise MissingCollectionError(f"No collection with name '{missing_names[0]}' found.")
303 elif len(missing_names) > 1:  [303 ↛ 304: condition was never true]
304 raise MissingCollectionError(f"No collections with names '{' '.join(missing_names)}' found.")
305 return [cast(CollectionRecord[K], records[name]) for name in names]
307 def __getitem__(self, key: Any) -> CollectionRecord[K]:
308 # Docstring inherited from CollectionManager.
309 if self._caching_context.collection_records is not None:
310 if (record := self._caching_context.collection_records.get_by_key(key)) is not None:
311 return record
312 if records := self._fetch_by_key([key]):  [312 ↛ 318: condition was never false]
313 record = records[0]
314 if self._caching_context.collection_records is not None:
315 self._caching_context.collection_records.add(record)
316 return record
317 else:
318 raise MissingCollectionError(f"Collection with key '{key}' not found.")
320 def resolve_wildcard(
321 self,
322 wildcard: CollectionWildcard,
323 *,
324 collection_types: Set[CollectionType] = CollectionType.all(),
325 done: set[str] | None = None,
326 flatten_chains: bool = True,
327 include_chains: bool | None = None,
328 ) -> list[CollectionRecord[K]]:
329 # Docstring inherited
330 if done is None:  [330 ↛ 332: condition was never false]
331 done = set()
332 include_chains = include_chains if include_chains is not None else not flatten_chains
334 def resolve_nested(record: CollectionRecord, done: set[str]) -> Iterator[CollectionRecord[K]]:
335 if record.name in done:
336 return
337 if record.type in collection_types:
338 done.add(record.name)
339 if record.type is not CollectionType.CHAINED or include_chains:
340 yield record
341 if flatten_chains and record.type is CollectionType.CHAINED:
342 done.add(record.name)
343 for child in self._find_many(cast(ChainedCollectionRecord[K], record).children):
344 # flake8 can't tell that we only delete this closure when
345 # we're totally done with it.
346 yield from resolve_nested(child, done) # noqa: F821
348 result: list[CollectionRecord[K]] = []
350 if wildcard.patterns is ...:
351 for record in self._fetch_all():
352 result.extend(resolve_nested(record, done))
353 del resolve_nested
354 return result
355 if wildcard.strings:
356 for record in self._find_many(wildcard.strings):
357 result.extend(resolve_nested(record, done))
358 if wildcard.patterns:
359 for record in self._fetch_all():
360 if any(p.fullmatch(record.name) for p in wildcard.patterns):
361 result.extend(resolve_nested(record, done))
362 del resolve_nested
363 return result
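# A hypothetical example of the resolution logic above: if "nightly" is a
# CHAINED collection with children ["runA", "calib"], then resolving the
# wildcard "nightly" with the default flatten_chains=True yields the records
# for "runA" and "calib"; the chain record itself is only included when
# include_chains is True, and the ``done`` set prevents the same collection
# from being yielded twice when chains overlap.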
365 def getDocumentation(self, key: K) -> str | None:
366 # Docstring inherited from CollectionManager.
367 sql = (
368 sqlalchemy.sql.select(self._tables.collection.columns.doc)
369 .select_from(self._tables.collection)
370 .where(self._tables.collection.columns[self._collectionIdName] == key)
371 )
372 with self._db.query(sql) as sql_result:
373 return sql_result.scalar()
375 def setDocumentation(self, key: K, doc: str | None) -> None:
376 # Docstring inherited from CollectionManager.
377 self._db.update(self._tables.collection, {self._collectionIdName: "key"}, {"key": key, "doc": doc})
379 def _addCachedRecord(self, record: CollectionRecord[K]) -> None:
380 """Add single record to cache."""
381 if self._caching_context.collection_records is not None:
382 self._caching_context.collection_records.add(record)
384 def _removeCachedRecord(self, record: CollectionRecord[K]) -> None:
385 """Remove single record from cache."""
386 if self._caching_context.collection_records is not None:  [386 ↛ 387: condition was never true]
387 self._caching_context.collection_records.discard(record)
389 def _getByName(self, name: str) -> CollectionRecord[K] | None:
390 """Find collection record given collection name."""
391 if self._caching_context.collection_records is not None:
392 if (record := self._caching_context.collection_records.get_by_name(name)) is not None:
393 return record
394 records = self._fetch_by_name([name])
395 for record in records:
396 self._addCachedRecord(record)
397 return records[0] if records else None
399 @abstractmethod
400 def _fetch_by_name(self, names: Iterable[str]) -> list[CollectionRecord[K]]:
401 """Fetch collection record from database given its name."""
402 raise NotImplementedError()
404 @abstractmethod
405 def _fetch_by_key(self, collection_ids: Iterable[K] | None) -> list[CollectionRecord[K]]:
406 """Fetch collection record from database given its key, or fetch all
407 collctions if argument is None.
408 """
409 raise NotImplementedError()
411 def update_chain(
412 self,
413 parent_collection_name: str,
414 child_collection_names: list[str],
415 allow_use_in_caching_context: bool = False,
416 ) -> None:
417 with self._modify_collection_chain(
418 parent_collection_name,
419 child_collection_names,
420 # update_chain is currently used in setCollectionChain, which is
421 # called within caching contexts. (At least in Butler.import_ and
422 # possibly other places.) So, unlike the other collection chain
423 # modification methods, it has to update the collection cache.
424 skip_caching_check=allow_use_in_caching_context,
425 ) as c:
426 self._db.delete(self._tables.collection_chain, ["parent"], {"parent": c.parent_key})
427 self._block_for_concurrency_test()
428 self._insert_collection_chain_rows(c.parent_key, 0, c.child_keys)
430 names = [child.name for child in c.child_records]
431 record = ChainedCollectionRecord[K](c.parent_key, parent_collection_name, children=tuple(names))
432 self._addCachedRecord(record)
434 def prepend_collection_chain(
435 self, parent_collection_name: str, child_collection_names: list[str]
436 ) -> None:
437 self._add_to_collection_chain(
438 parent_collection_name, child_collection_names, self._find_prepend_position
439 )
441 def extend_collection_chain(self, parent_collection_name: str, child_collection_names: list[str]) -> None:
442 self._add_to_collection_chain(
443 parent_collection_name, child_collection_names, self._find_extend_position
444 )
446 def _add_to_collection_chain(
447 self,
448 parent_collection_name: str,
449 child_collection_names: list[str],
450 position_func: Callable[[_CollectionChainModificationContext], int],
451 ) -> None:
452 with self._modify_collection_chain(parent_collection_name, child_collection_names) as c:
453 # Remove any of the new children that are already in the
454 # collection, so they move to a new position instead of being
455 # duplicated.
456 self._remove_collection_chain_rows(c.parent_key, c.child_keys)
457 # Figure out where to insert the new children.
458 starting_position = position_func(c)
459 self._block_for_concurrency_test()
460 self._insert_collection_chain_rows(c.parent_key, starting_position, c.child_keys)
462 def remove_from_collection_chain(
463 self, parent_collection_name: str, child_collection_names: list[str]
464 ) -> None:
465 with self._modify_collection_chain(
466 parent_collection_name,
467 child_collection_names,
468 # Removing members from a chain can't create collection cycles
469 skip_cycle_check=True,
470 ) as c:
471 self._remove_collection_chain_rows(c.parent_key, c.child_keys)
473 @contextmanager
474 def _modify_collection_chain(
475 self,
476 parent_collection_name: str,
477 child_collection_names: list[str],
478 *,
479 skip_caching_check: bool = False,
480 skip_cycle_check: bool = False,
481 ) -> Iterator[_CollectionChainModificationContext[K]]:
482 if (not skip_caching_check) and self._caching_context.is_enabled:
483 # Avoid having cache-maintenance code around that is unlikely to
484 # ever be used.
485 raise RuntimeError("Chained collection modification not permitted with active caching context.")
487 if not skip_cycle_check:
488 self._sanity_check_collection_cycles(parent_collection_name, child_collection_names)
490 # Look up the collection primary keys corresponding to the
491 # user-provided list of child collection names. Because there is no
492 # locking for the child collections, it's possible for a concurrent
493 # deletion of one of the children to cause a foreign key constraint
494 # violation when we attempt to insert them in the collection chain
495 # table later.
496 child_records = self.resolve_wildcard(
497 CollectionWildcard.from_names(child_collection_names), flatten_chains=False
498 )
499 child_keys = [child.key for child in child_records]
501 with self._db.transaction():
502 # Lock the parent collection to prevent concurrent updates to the
503 # same collection chain.
504 parent_key = self._find_and_lock_collection_chain(parent_collection_name)
505 yield _CollectionChainModificationContext[K](
506 parent_key=parent_key, child_keys=child_keys, child_records=child_records
507 )
509 def _sanity_check_collection_cycles(
510 self, parent_collection_name: str, child_collection_names: list[str]
511 ) -> None:
512 """Raise an exception if any of the collections in the ``child_names``
513 list have ``parent_name`` as a child, creating a collection cycle.
515 This is only a sanity check, and does not guarantee that no collection
516 cycles are possible. Concurrent updates might allow collection cycles
517 to be inserted.
518 """
519 for record in self.resolve_wildcard(
520 CollectionWildcard.from_names(child_collection_names),
521 flatten_chains=True,
522 include_chains=True,
523 collection_types={CollectionType.CHAINED},
524 ):
525 if record.name == parent_collection_name:
526 raise CollectionCycleError(
527 f"Cycle in collection chaining when defining '{parent_collection_name}'."
528 )
530 def _insert_collection_chain_rows(
531 self,
532 parent_key: K,
533 starting_position: int,
534 child_keys: list[K],
535 ) -> None:
536 rows = [
537 {
538 "parent": parent_key,
539 "child": child,
540 "position": position,
541 }
542 for position, child in enumerate(child_keys, starting_position)
543 ]
545 # It's possible for the DB to raise an exception for the integers being
546 # out of range here. The position column is only a 16-bit number.
547 # Even if there aren't an unreasonably large number of children in the
548 # collection, a series of many deletes and insertions could cause the
549 # space to become fragmented.
550 #
551 # If this ever actually happens, we should consider doing a migration
552 # to increase the position column to a 32-bit number.
553 # To fix it in the short term, you can re-write the collection chain to
554 # defragment it by doing something like:
555 # registry.setCollectionChain(
556 # parent,
557 # registry.getCollectionChain(parent)
558 # )
559 self._db.insert(self._tables.collection_chain, *rows)
561 def _remove_collection_chain_rows(
562 self,
563 parent_key: K,
564 child_keys: list[K],
565 ) -> None:
566 table = self._tables.collection_chain
567 where = sqlalchemy.and_(table.c.parent == parent_key, table.c.child.in_(child_keys))
568 self._db.deleteWhere(table, where)
570 def _find_prepend_position(self, c: _CollectionChainModificationContext) -> int:
571 """Return the position where children can be inserted to
572 prepend them to a collection chain.
573 """
574 return self._find_position_in_collection_chain(c.parent_key, "begin") - len(c.child_keys)
576 def _find_extend_position(self, c: _CollectionChainModificationContext) -> int:
577 """Return the position where children can be inserted to append them to
578 a collection chain.
579 """
580 return self._find_position_in_collection_chain(c.parent_key, "end") + 1
582 def _find_position_in_collection_chain(self, chain_key: K, begin_or_end: Literal["begin", "end"]) -> int:
583 """Return the lowest or highest numbered position in a collection
584 chain, or 0 if the chain is empty.
585 """
586 table = self._tables.collection_chain
588 func: sqlalchemy.Function
589 match (begin_or_end):
590 case "begin":
591 func = sqlalchemy.func.min(table.c.position)
592 case "end": 592 ↛ 595line 592 didn't jump to line 595, because the pattern on line 592 always matched
593 func = sqlalchemy.func.max(table.c.position)
595 query = sqlalchemy.select(func).where(table.c.parent == chain_key)
596 with self._db.query(query) as cursor:
597 position = cursor.scalar()
599 if position is None:
600 return 0
602 return position
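# A small worked example of the position arithmetic above, assuming an
# existing chain whose rows occupy positions [5, 6, 7]:
#
#   - prepending two children starts at min(5, 6, 7) - 2 = 3, so the new rows
#     get positions 3 and 4;
#   - extending with two children starts at max(5, 6, 7) + 1 = 8, so the new
#     rows get positions 8 and 9;
#   - for an empty chain the helper returns 0, so a two-child prepend starts
#     at -2 and an extend starts at 1.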
604 def _find_and_lock_collection_chain(self, collection_name: str) -> K:
605 """
606 Take a row lock on the specified collection's row in the collections
607 table, and return the collection's primary key.
609 This lock is used to synchronize updates to collection chains.
611 The locking strategy requires cooperation from everything modifying the
612 collection chain table -- all operations that modify collection chains
613 must obtain this lock first. The database will NOT automatically
614 prevent modification of tables based on this lock. The only guarantee
615 is that only one caller will be allowed to hold this lock for a given
616 collection at a time. Concurrent calls will block until the caller
617 holding the lock has completed its transaction.
619 Parameters
620 ----------
621 collection_name : `str`
622 Name of the collection whose chain is being modified.
624 Returns
625 -------
626 id : ``K``
627 The primary key for the given collection.
629 Raises
630 ------
631 MissingCollectionError
632 If the specified collection is not in the database table.
633 CollectionTypeError
634 If the specified collection is not a chained collection.
635 """
636 assert self._db.isInTransaction(), (
637 "Row locks are only held until the end of the current transaction,"
638 " so it makes no sense to take a lock outside a transaction."
639 )
640 assert self._db.isWriteable(), "Collection row locks are only useful for write operations."
642 query = self._select_pkey_by_name(collection_name).with_for_update()
643 with self._db.query(query) as cursor:
644 rows = cursor.all()
646 if len(rows) == 0:
647 raise MissingCollectionError(
648 f"Parent collection {collection_name} not found when updating collection chain."
649 )
650 assert len(rows) == 1, "There should only be one entry for each collection in collection table."
651 r = rows[0]._mapping
652 if r["type"] != CollectionType.CHAINED:
653 raise CollectionTypeError(f"Parent collection {collection_name} is not a chained collection.")
654 return r["key"]
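# The cooperative locking pattern described in the docstring above looks
# roughly like this in the callers (a sketch, not a verbatim excerpt; see
# _modify_collection_chain for the actual code):
#
#     with self._db.transaction():
#         parent_key = self._find_and_lock_collection_chain(parent_name)
#         # ... insert or delete collection_chain rows for parent_key ...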
656 @abstractmethod
657 def _select_pkey_by_name(self, collection_name: str) -> sqlalchemy.Select:
658 """Return a SQLAlchemy select statement that will return columns from
659 the one row in the ``collection`` table matching the given name. The
660 select statement includes two columns:
662 - ``key`` : the primary key for the collection
663 - ``type`` : the collection type
664 """
665 raise NotImplementedError()
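# A minimal sketch of what a concrete subclass might return from
# _select_pkey_by_name, assuming the "collection" table layout used elsewhere
# in this module (real implementations may label or join differently):
#
#     def _select_pkey_by_name(self, collection_name: str) -> sqlalchemy.Select:
#         table = self._tables.collection
#         return sqlalchemy.select(
#             table.columns[self._collectionIdName].label("key"),
#             table.columns["type"].label("type"),
#         ).where(table.columns["name"] == collection_name)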
668class _CollectionChainModificationContext(NamedTuple, Generic[K]):
669 parent_key: K
670 child_keys: list[K]
671 child_records: list[CollectionRecord[K]]