Coverage for python/lsst/daf/butler/registry/collections/_base.py: 97%
219 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-05 02:52 -0700
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-05 02:52 -0700
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
27from __future__ import annotations
29from ... import ddl
31__all__ = ()
33from abc import abstractmethod
34from collections.abc import Iterable, Iterator, Set
35from typing import TYPE_CHECKING, Any, Generic, NamedTuple, TypeVar, cast
37import sqlalchemy
39from ..._exceptions import CollectionCycleError, CollectionTypeError, MissingCollectionError
40from ...timespan_database_representation import TimespanDatabaseRepresentation
41from .._collection_type import CollectionType
42from ..interfaces import ChainedCollectionRecord, CollectionManager, CollectionRecord, RunRecord, VersionTuple
43from ..wildcards import CollectionWildcard
45if TYPE_CHECKING:
46 from .._caching_context import CachingContext
47 from ..interfaces import Database
def _makeCollectionForeignKey(
    sourceColumnName: str, collectionIdName: str, **kwargs: Any
) -> ddl.ForeignKeySpec:
    """Build a foreign key specification pointing at the collections table.

    Parameters
    ----------
    sourceColumnName : `str`
        Column of the referring table that holds the collection key.
    collectionIdName : `str`
        Primary-key column of the collections table being referenced.
    **kwargs
        Extra keyword arguments forwarded unchanged to `ddl.ForeignKeySpec`
        (for example ``onDelete``).

    Returns
    -------
    spec : `ddl.ForeignKeySpec`
        The foreign key specification.

    Notes
    -----
    The referenced table name is hard-coded as ``"collection"``, and the
    collections table is assumed to have a single-column primary key.
    """
    target_table = "collection"
    source_columns = (sourceColumnName,)
    target_columns = (collectionIdName,)
    return ddl.ForeignKeySpec(target_table, source=source_columns, target=target_columns, **kwargs)
78_T = TypeVar("_T")
81class CollectionTablesTuple(NamedTuple, Generic[_T]):
82 collection: _T
83 run: _T
84 collection_chain: _T
def makeRunTableSpec(
    collectionIdName: str, collectionIdType: type, TimespanReprClass: type[TimespanDatabaseRepresentation]
) -> ddl.TableSpec:
    """Build the table specification for the "run" table.

    Parameters
    ----------
    collectionIdName : `str`
        Name of the primary-key column of the collections table; the run
        table uses the same column name as its own primary key.
    collectionIdType : `type`
        SQLAlchemy type of that primary-key column.
    TimespanReprClass : `type` [ `TimespanDatabaseRepresentation` ]
        Subclass of `TimespanDatabaseRepresentation` that encapsulates how
        timespans are stored in this database.

    Returns
    -------
    spec : `ddl.TableSpec`
        Specification for the run table.

    Notes
    -----
    The identifying column has the same name in the collections and run
    tables. The run table also has a fixed ``host`` column (string, length
    128) and the nullable timespan columns supplied by ``TimespanReprClass``.
    Deleting a collection row cascades to its run row.
    """
    id_field = ddl.FieldSpec(collectionIdName, dtype=collectionIdType, primaryKey=True)
    host_field = ddl.FieldSpec("host", dtype=sqlalchemy.String, length=128)
    collection_fk = _makeCollectionForeignKey(collectionIdName, collectionIdName, onDelete="CASCADE")
    spec = ddl.TableSpec(fields=[id_field, host_field], foreignKeys=[collection_fk])
    # Timespan storage is database-specific, so its columns are appended
    # after construction.
    for timespan_field in TimespanReprClass.makeFieldSpecs(nullable=True):
        spec.fields.add(timespan_field)
    return spec
def makeCollectionChainTableSpec(collectionIdName: str, collectionIdType: type) -> ddl.TableSpec:
    """Build the table specification for the "collection_chain" table.

    Parameters
    ----------
    collectionIdName : `str`
        Name of the primary-key column of the collections table.
    collectionIdType : `type`
        SQLAlchemy type of that primary-key column.

    Returns
    -------
    spec : `ddl.TableSpec`
        Specification for the collection chain table.

    Notes
    -----
    Each row links a ``parent`` collection to a ``child`` collection at a
    given ``position``. ``(parent, position)`` is the primary key, so a chain
    is an ordered one-to-many relation. Both ``parent`` and ``child``
    reference the collections table. Deleting the parent cascades to its
    chain rows, while no ``onDelete`` action is set for ``child``. These
    column names are fixed and relied upon elsewhere.
    """
    parent_field = ddl.FieldSpec("parent", dtype=collectionIdType, primaryKey=True)
    position_field = ddl.FieldSpec("position", dtype=sqlalchemy.SmallInteger, primaryKey=True)
    child_field = ddl.FieldSpec("child", dtype=collectionIdType, nullable=False)
    parent_fk = _makeCollectionForeignKey("parent", collectionIdName, onDelete="CASCADE")
    child_fk = _makeCollectionForeignKey("child", collectionIdName)
    return ddl.TableSpec(
        fields=[parent_field, position_field, child_field],
        foreignKeys=[parent_fk, child_fk],
    )
# Type variable for the type of a collection's primary key; it parameterizes
# DefaultCollectionManager and the collection records it returns.
K = TypeVar("K")
class DefaultCollectionManager(CollectionManager[K]):
    """Default `CollectionManager` implementation.

    This implementation works with the `CollectionRecord`, `RunRecord` and
    `ChainedCollectionRecord` classes and assumes the table layouts produced
    by `makeRunTableSpec` and `makeCollectionChainTableSpec`. Concrete
    subclasses provide the database lookups (`_fetch_by_name`,
    `_fetch_by_key`, `_select_pkey_by_name`).

    Parameters
    ----------
    db : `Database`
        Interface to the underlying database engine and namespace.
    tables : `CollectionTablesTuple`
        Named tuple of SQLAlchemy table objects.
    collectionIdName : `str`
        Name of the column in collections table that identifies it (PK).
    caching_context : `CachingContext`
        Caching context to use.
    registry_schema_version : `VersionTuple` or `None`, optional
        The version of the registry schema.

    Notes
    -----
    When the caching context provides a collection-record cache
    (``caching_context.collection_records`` is not `None`), records are
    stored in it as they are fetched or created. `refresh` clears that cache,
    so later lookups read the database again.
    """

    def __init__(
        self,
        db: Database,
        tables: CollectionTablesTuple[sqlalchemy.Table],
        collectionIdName: str,
        *,
        caching_context: CachingContext,
        registry_schema_version: VersionTuple | None = None,
    ) -> None:
        super().__init__(registry_schema_version=registry_schema_version)
        self._db = db
        self._tables = tables
        self._collectionIdName = collectionIdName
        self._caching_context = caching_context

    def refresh(self) -> None:
        # Docstring inherited from CollectionManager.
        # Only the in-memory cache (if any) is dropped; records are re-read
        # from the database lazily on the next lookup.
        if self._caching_context.collection_records is not None:
            self._caching_context.collection_records.clear()

    def _fetch_all(self) -> list[CollectionRecord[K]]:
        """Return records for all collections.

        A complete cache is used directly when available. Otherwise every
        record is fetched from the database and, when caching is active,
        stored as a complete cache.
        """
        if self._caching_context.collection_records is not None:
            if self._caching_context.collection_records.full:
                return list(self._caching_context.collection_records.records())
        records = self._fetch_by_key(None)
        if self._caching_context.collection_records is not None:
            self._caching_context.collection_records.set(records, full=True)
        return records

    def register(
        self, name: str, type: CollectionType, doc: str | None = None
    ) -> tuple[CollectionRecord[K], bool]:
        # Docstring inherited from CollectionManager.
        registered = False
        record = self._getByName(name)
        if record is None:
            # NOTE(review): presumably ``sync`` inserts the row or verifies
            # ``compared`` against an existing one -- confirm against
            # `Database.sync`.
            row, inserted_or_updated = self._db.sync(
                self._tables.collection,
                keys={"name": name},
                compared={"type": int(type)},
                extra={"doc": doc},
                returning=[self._collectionIdName],
            )
            assert isinstance(inserted_or_updated, bool)
            registered = inserted_or_updated
            assert row is not None
            collection_id = cast(K, row[self._collectionIdName])
            if type is CollectionType.RUN:
                # RUN collections also get a row in the run table.
                TimespanReprClass = self._db.getTimespanRepresentation()
                row, _ = self._db.sync(
                    self._tables.run,
                    keys={self._collectionIdName: collection_id},
                    returning=("host",) + TimespanReprClass.getFieldNames(),
                )
                assert row is not None
                record = RunRecord[K](
                    key=collection_id,
                    name=name,
                    host=row["host"],
                    timespan=TimespanReprClass.extract(row),
                )
            elif type is CollectionType.CHAINED:
                # A newly registered chain starts with no children.
                record = ChainedCollectionRecord[K](
                    key=collection_id,
                    name=name,
                    children=[],
                )
            else:
                record = CollectionRecord[K](key=collection_id, name=name, type=type)
            self._addCachedRecord(record)
        return record, registered

    def remove(self, name: str) -> None:
        # Docstring inherited from CollectionManager.
        record = self._getByName(name)
        if record is None:
            raise MissingCollectionError(f"No collection with name '{name}' found.")
        # This may raise
        self._db.delete(
            self._tables.collection, [self._collectionIdName], {self._collectionIdName: record.key}
        )
        self._removeCachedRecord(record)

    def find(self, name: str) -> CollectionRecord[K]:
        # Docstring inherited from CollectionManager.
        result = self._getByName(name)
        if result is None:
            raise MissingCollectionError(f"No collection with name '{name}' found.")
        return result

    def _find_many(self, names: Iterable[str]) -> list[CollectionRecord[K]]:
        """Return records for ``names``, in the same order.

        Cached records are used where available. The remaining names are
        fetched from the database in a single call and added to the cache.

        Raises
        ------
        MissingCollectionError
            Raised if any name does not correspond to a collection.
        """
        names = list(names)
        # Snapshot results locally to protect against potential races in
        # cache updates between lookup and return.
        records: dict[str, CollectionRecord[K] | None] = {}
        if self._caching_context.collection_records is not None:
            for name in names:
                records[name] = self._caching_context.collection_records.get_by_name(name)
            fetch_names = [name for name, record in records.items() if record is None]
        else:
            fetch_names = list(names)
            records = {name: None for name in fetch_names}
        if fetch_names:
            for record in self._fetch_by_name(fetch_names):
                records[record.name] = record
                self._addCachedRecord(record)
        missing_names = [name for name, record in records.items() if record is None]
        if len(missing_names) == 1:
            raise MissingCollectionError(f"No collection with name '{missing_names[0]}' found.")
        elif len(missing_names) > 1:
            raise MissingCollectionError(f"No collections with names '{' '.join(missing_names)}' found.")
        return [cast(CollectionRecord[K], records[name]) for name in names]

    def __getitem__(self, key: Any) -> CollectionRecord[K]:
        # Docstring inherited from CollectionManager.
        if self._caching_context.collection_records is not None:
            if (record := self._caching_context.collection_records.get_by_key(key)) is not None:
                return record
        if records := self._fetch_by_key([key]):
            record = records[0]
            self._addCachedRecord(record)
            return record
        else:
            raise MissingCollectionError(f"Collection with key '{key}' not found.")

    def resolve_wildcard(
        self,
        wildcard: CollectionWildcard,
        *,
        collection_types: Set[CollectionType] = CollectionType.all(),
        done: set[str] | None = None,
        flatten_chains: bool = True,
        include_chains: bool | None = None,
    ) -> list[CollectionRecord[K]]:
        # Docstring inherited
        if done is None:
            done = set()
        # By default chains are reported only when they are not flattened.
        include_chains = include_chains if include_chains is not None else not flatten_chains

        def resolve_nested(record: CollectionRecord[K], done: set[str]) -> Iterator[CollectionRecord[K]]:
            # ``done`` holds names already handled, so each collection is
            # yielded at most once.
            if record.name in done:
                return
            if record.type in collection_types:
                done.add(record.name)
                if record.type is not CollectionType.CHAINED or include_chains:
                    yield record
            if flatten_chains and record.type is CollectionType.CHAINED:
                done.add(record.name)
                for child in self._find_many(cast(ChainedCollectionRecord[K], record).children):
                    # flake8 can't tell that we only delete this closure when
                    # we're totally done with it.
                    yield from resolve_nested(child, done)  # noqa: F821

        result: list[CollectionRecord[K]] = []

        if wildcard.patterns is ...:
            for record in self._fetch_all():
                result.extend(resolve_nested(record, done))
            del resolve_nested
            return result
        if wildcard.strings:
            for record in self._find_many(wildcard.strings):
                result.extend(resolve_nested(record, done))
        if wildcard.patterns:
            for record in self._fetch_all():
                if any(p.fullmatch(record.name) for p in wildcard.patterns):
                    result.extend(resolve_nested(record, done))
        del resolve_nested
        return result

    def getDocumentation(self, key: K) -> str | None:
        # Docstring inherited from CollectionManager.
        sql = (
            sqlalchemy.sql.select(self._tables.collection.columns.doc)
            .select_from(self._tables.collection)
            .where(self._tables.collection.columns[self._collectionIdName] == key)
        )
        with self._db.query(sql) as sql_result:
            return sql_result.scalar()

    def setDocumentation(self, key: K, doc: str | None) -> None:
        # Docstring inherited from CollectionManager.
        # NOTE(review): presumably the ID column is matched against the row's
        # "key" entry and "doc" is the value written -- confirm against
        # `Database.update`.
        self._db.update(self._tables.collection, {self._collectionIdName: "key"}, {"key": key, "doc": doc})

    def _addCachedRecord(self, record: CollectionRecord[K]) -> None:
        """Add a single record to the cache, if caching is active."""
        if self._caching_context.collection_records is not None:
            self._caching_context.collection_records.add(record)

    def _removeCachedRecord(self, record: CollectionRecord[K]) -> None:
        """Remove a single record from the cache, if caching is active."""
        if self._caching_context.collection_records is not None:
            self._caching_context.collection_records.discard(record)

    def _getByName(self, name: str) -> CollectionRecord[K] | None:
        """Return the record for collection ``name``, or `None` if no such
        collection exists.

        The cache is consulted first; a database hit is added to the cache.
        """
        if self._caching_context.collection_records is not None:
            if (record := self._caching_context.collection_records.get_by_name(name)) is not None:
                return record
        records = self._fetch_by_name([name])
        for record in records:
            self._addCachedRecord(record)
        return records[0] if records else None

    @abstractmethod
    def _fetch_by_name(self, names: Iterable[str]) -> list[CollectionRecord[K]]:
        """Fetch collection records from the database given their names.

        Callers in this class expect names with no matching collection to be
        omitted from the result rather than raising.
        """
        raise NotImplementedError()

    @abstractmethod
    def _fetch_by_key(self, collection_ids: Iterable[K] | None) -> list[CollectionRecord[K]]:
        """Fetch collection records from the database given their keys, or
        fetch all collections if the argument is `None`.
        """
        raise NotImplementedError()

    def update_chain(
        self, chain: ChainedCollectionRecord[K], children: Iterable[str], flatten: bool = False
    ) -> ChainedCollectionRecord[K]:
        # Docstring inherited from CollectionManager.
        children = list(children)
        self._sanity_check_collection_cycles(chain.name, children)

        if flatten:
            # Replace nested chains by the non-chain collections they contain.
            children = [
                record.name
                for record in self.resolve_wildcard(
                    CollectionWildcard.from_names(children), flatten_chains=True
                )
            ]

        child_records = self.resolve_wildcard(CollectionWildcard.from_names(children), flatten_chains=False)
        names = [child.name for child in child_records]
        with self._db.transaction():
            # Lock the parent row first, then rewrite the whole chain from
            # position 0.
            self._find_and_lock_collection_chain(chain.name)
            self._db.delete(self._tables.collection_chain, ["parent"], {"parent": chain.key})
            self._block_for_concurrency_test()
            self._insert_collection_chain_rows(chain.key, 0, [child.key for child in child_records])

        record = ChainedCollectionRecord[K](chain.key, chain.name, children=tuple(names))
        self._addCachedRecord(record)
        return record

    def _sanity_check_collection_cycles(
        self, parent_collection_name: str, child_collection_names: list[str]
    ) -> None:
        """Raise if adding ``child_collection_names`` to the chain
        ``parent_collection_name`` would create a collection cycle.

        A cycle is detected when ``parent_collection_name`` is one of the
        children or is reachable from them through nested chained
        collections.

        This is only a sanity check, and does not guarantee that no collection
        cycles are possible. Concurrent updates might allow collection cycles
        to be inserted.

        Raises
        ------
        CollectionCycleError
            Raised if a cycle would be created.
        MissingCollectionError
            Raised if any child collection does not exist.
        """
        for record in self.resolve_wildcard(
            CollectionWildcard.from_names(child_collection_names),
            flatten_chains=True,
            include_chains=True,
            collection_types={CollectionType.CHAINED},
        ):
            if record.name == parent_collection_name:
                raise CollectionCycleError(
                    f"Cycle in collection chaining when defining '{parent_collection_name}'."
                )

    def _insert_collection_chain_rows(
        self,
        parent_key: K,
        starting_position: int,
        child_keys: list[K],
    ) -> None:
        """Insert chain rows linking ``parent_key`` to each of ``child_keys``.

        Positions are numbered consecutively from ``starting_position`` in
        the order of ``child_keys``.
        """
        rows = [
            {
                "parent": parent_key,
                "child": child,
                "position": position,
            }
            for position, child in enumerate(child_keys, starting_position)
        ]
        self._db.insert(self._tables.collection_chain, *rows)

    def _remove_collection_chain_rows(
        self,
        parent_key: K,
        child_keys: list[K],
    ) -> None:
        """Delete the rows of ``parent_key``'s chain whose child is in
        ``child_keys``.
        """
        table = self._tables.collection_chain
        where = sqlalchemy.and_(table.c.parent == parent_key, table.c.child.in_(child_keys))
        self._db.deleteWhere(table, where)

    def prepend_collection_chain(
        self, parent_collection_name: str, child_collection_names: list[str]
    ) -> None:
        # Moves (or adds) the given children to the front of the chain, in the
        # given order: they are first removed, then re-inserted at positions
        # just below the current lowest position.
        if self._caching_context.is_enabled:
            # Avoid having cache-maintenance code around that is unlikely to
            # ever be used.
            raise RuntimeError("Chained collection modification not permitted with active caching context.")

        self._sanity_check_collection_cycles(parent_collection_name, child_collection_names)

        child_records = self.resolve_wildcard(
            CollectionWildcard.from_names(child_collection_names), flatten_chains=False
        )
        child_keys = [child.key for child in child_records]

        with self._db.transaction():
            parent_key = self._find_and_lock_collection_chain(parent_collection_name)
            self._remove_collection_chain_rows(parent_key, child_keys)
            starting_position = self._find_lowest_position_in_collection_chain(parent_key) - len(child_keys)
            self._block_for_concurrency_test()
            self._insert_collection_chain_rows(parent_key, starting_position, child_keys)

    def _find_lowest_position_in_collection_chain(self, chain_key: K) -> int:
        """Return the lowest-numbered position in a collection chain, or 0 if
        the chain is empty.
        """
        table = self._tables.collection_chain
        query = sqlalchemy.select(sqlalchemy.func.min(table.c.position)).where(table.c.parent == chain_key)
        with self._db.query(query) as cursor:
            lowest_existing_position = cursor.scalar()

        # MIN over zero rows yields NULL.
        if lowest_existing_position is None:
            return 0

        return lowest_existing_position

    def _find_and_lock_collection_chain(self, collection_name: str) -> K:
        """
        Take a row lock on the specified collection's row in the collections
        table, and return the collection's primary key.

        This lock is used to synchronize updates to collection chains.

        The locking strategy requires cooperation from everything modifying the
        collection chain table -- all operations that modify collection chains
        must obtain this lock first. The database will NOT automatically
        prevent modification of tables based on this lock. The only guarantee
        is that only one caller will be allowed to hold this lock for a given
        collection at a time. Concurrent calls will block until the caller
        holding the lock has completed its transaction.

        Parameters
        ----------
        collection_name : `str`
            Name of the collection whose chain is being modified.

        Returns
        -------
        id : ``K``
            The primary key for the given collection.

        Raises
        ------
        MissingCollectionError
            If the specified collection is not in the database table.
        CollectionTypeError
            If the specified collection is not a chained collection.
        """
        assert self._db.isInTransaction(), (
            "Row locks are only held until the end of the current transaction,"
            " so it makes no sense to take a lock outside a transaction."
        )
        assert self._db.isWriteable(), "Collection row locks are only useful for write operations."

        query = self._select_pkey_by_name(collection_name).with_for_update()
        with self._db.query(query) as cursor:
            rows = cursor.all()

        if len(rows) == 0:
            raise MissingCollectionError(
                f"Parent collection {collection_name} not found when updating collection chain."
            )
        assert len(rows) == 1, "There should only be one entry for each collection in collection table."
        r = rows[0]._mapping
        if r["type"] != CollectionType.CHAINED:
            raise CollectionTypeError(f"Parent collection {collection_name} is not a chained collection.")
        return r["key"]

    @abstractmethod
    def _select_pkey_by_name(self, collection_name: str) -> sqlalchemy.Select:
        """Return a SQLAlchemy select statement that will return columns from
        the one row in the ``collection`` table matching the given name. The
        select statement includes two columns:

        - ``key`` : the primary key for the collection
        - ``type`` : the collection type
        """
        raise NotImplementedError()