Coverage for python/lsst/daf/butler/registry/collections/_base.py: 0%
293 statements
« prev ^ index » next — coverage.py v7.13.5, created at 2026-04-24 08:17 +0000
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
27from __future__ import annotations
29from ... import ddl
31__all__ = ()
33from abc import abstractmethod
34from collections.abc import Callable, Iterable, Iterator, Mapping, Set
35from contextlib import contextmanager
36from typing import TYPE_CHECKING, Any, Generic, Literal, NamedTuple, TypeVar, cast
38import sqlalchemy
40from lsst.utils.iteration import chunk_iterable
42from ..._collection_type import CollectionType
43from ..._exceptions import CollectionCycleError, CollectionTypeError, MissingCollectionError
44from ...timespan_database_representation import TimespanDatabaseRepresentation
45from .._collection_record_cache import CollectionRecordCache
46from ..interfaces import ChainedCollectionRecord, CollectionManager, CollectionRecord, RunRecord, VersionTuple
47from ..wildcards import CollectionWildcard
49if TYPE_CHECKING:
50 from .._caching_context import CachingContext
51 from ..interfaces import Database
def _makeCollectionForeignKey(
    sourceColumnName: str, collectionIdName: str, **kwargs: Any
) -> ddl.ForeignKeySpec:
    """Build a foreign key specification pointing at the collections table.

    Parameters
    ----------
    sourceColumnName : `str`
        Name of the column in the referring table.
    collectionIdName : `str`
        Name of the column in collections table that identifies it (PK).
    **kwargs
        Additional keyword arguments passed directly to `ddl.ForeignKeySpec`.

    Returns
    -------
    spec : `ddl.ForeignKeySpec`
        Foreign key specification.

    Notes
    -----
    The collections table is assumed to have the fixed name "collection"
    and a single-column primary key.
    """
    source_columns = (sourceColumnName,)
    target_columns = (collectionIdName,)
    return ddl.ForeignKeySpec("collection", source=source_columns, target=target_columns, **kwargs)
_T = TypeVar("_T")


class CollectionTablesTuple(NamedTuple, Generic[_T]):
    """Bundle of the three tables used by the collection manager.

    The type parameter is `sqlalchemy.Table` once the tables exist (see
    `DefaultCollectionManager.__init__`); presumably `ddl.TableSpec` while
    the schema is being defined — confirm against callers.
    """

    collection: _T  # main "collection" table (name, type, doc)
    run: _T  # per-RUN metadata table (host, timespan)
    collection_chain: _T  # parent/child membership rows for CHAINED collections
def makeRunTableSpec(
    collectionIdName: str, collectionIdType: type, TimespanReprClass: type[TimespanDatabaseRepresentation]
) -> ddl.TableSpec:
    """Build the table specification for the "run" table.

    Parameters
    ----------
    collectionIdName : `str`
        Name of the column in collections table that identifies it (PK).
    collectionIdType : `type`
        Type of the PK column in the collections table, one of the
        `sqlalchemy` types.
    TimespanReprClass : `type` [ `TimespanDatabaseRepresentation` ]
        Subclass of `TimespanDatabaseRepresentation` that encapsulates how
        timespans are stored in this database.

    Returns
    -------
    spec : `ddl.TableSpec`
        Specification for run table.

    Notes
    -----
    The identifying column is assumed to have the same name in both the
    collections and run tables; the metadata column names are fixed.
    """
    spec = ddl.TableSpec(
        fields=[
            ddl.FieldSpec(collectionIdName, dtype=collectionIdType, primaryKey=True),
            ddl.FieldSpec("host", dtype=sqlalchemy.String, length=128),
        ],
        foreignKeys=[
            _makeCollectionForeignKey(collectionIdName, collectionIdName, onDelete="CASCADE"),
        ],
    )
    # A timespan may need several columns; the representation class decides.
    for timespan_field in TimespanReprClass.makeFieldSpecs(nullable=True):
        spec.fields.add(timespan_field)
    return spec
def makeCollectionChainTableSpec(collectionIdName: str, collectionIdType: type) -> ddl.TableSpec:
    """Build the table specification for the "collection_chain" table.

    Parameters
    ----------
    collectionIdName : `str`
        Name of the column in collections table that identifies it (PK).
    collectionIdType : `type`
        Type of the PK column in the collections table, one of the
        `sqlalchemy` types.

    Returns
    -------
    spec : `ddl.TableSpec`
        Specification for collection chain table.

    Notes
    -----
    A collection chain is an ordered one-to-many relation between
    collections; the column names here are fixed and hardcoded elsewhere
    in this module.
    """
    # (parent, position) is the primary key; a child may appear under many
    # parents. Deleting a parent cascades, deleting a child is restricted.
    fields = [
        ddl.FieldSpec("parent", dtype=collectionIdType, primaryKey=True),
        ddl.FieldSpec("position", dtype=sqlalchemy.SmallInteger, primaryKey=True),
        ddl.FieldSpec("child", dtype=collectionIdType, nullable=False),
    ]
    foreign_keys = [
        _makeCollectionForeignKey("parent", collectionIdName, onDelete="CASCADE"),
        _makeCollectionForeignKey("child", collectionIdName),
    ]
    return ddl.TableSpec(fields=fields, foreignKeys=foreign_keys)
167K = TypeVar("K")
170class DefaultCollectionManager(CollectionManager[K]):
171 """Default `CollectionManager` implementation.
173 This implementation uses record classes defined in this module and is
174 based on the same assumptions about schema outlined in the record classes.
176 Parameters
177 ----------
178 db : `Database`
179 Interface to the underlying database engine and namespace.
180 tables : `CollectionTablesTuple`
181 Named tuple of SQLAlchemy table objects.
182 collectionIdName : `str`
183 Name of the column in collections table that identifies it (PK).
184 caching_context : `CachingContext`
185 Caching context to use.
186 registry_schema_version : `VersionTuple` or `None`, optional
187 The version of the registry schema.
189 Notes
190 -----
191 Implementation uses "aggressive" pre-fetching and caching of the records
192 in memory. Memory cache is synchronized from database when `refresh`
193 method is called.
194 """
    def __init__(
        self,
        db: Database,
        tables: CollectionTablesTuple[sqlalchemy.Table],
        collectionIdName: str,
        *,
        caching_context: CachingContext,
        registry_schema_version: VersionTuple | None = None,
    ):
        super().__init__(registry_schema_version=registry_schema_version)
        # Database interface used for all queries and transactions.
        self._db = db
        # Named tuple of SQLAlchemy tables (collection, run, collection_chain).
        self._tables = tables
        # Name of the PK column in the collection table (shared with run table).
        self._collectionIdName = collectionIdName
        # Caching context whose ``collection_records`` attribute, when not
        # `None`, holds the in-memory record cache.
        self._caching_context = caching_context
211 def refresh(self) -> None:
212 # Docstring inherited from CollectionManager.
213 if self._caching_context.collection_records is not None:
214 self._caching_context.collection_records.clear()
216 def _fetch_all(self, collection_cache: CollectionRecordCache | None = None) -> list[CollectionRecord[K]]:
217 """Retrieve all records into cache if not done so yet."""
218 if collection_cache is None:
219 collection_cache = self._caching_context.collection_records
220 if collection_cache is not None:
221 if collection_cache.full:
222 return list(collection_cache.records())
223 records = self._fetch_by_key(None)
224 if collection_cache is not None:
225 collection_cache.set(records, full=True)
226 return records
    def register(
        self, name: str, type: CollectionType, doc: str | None = None
    ) -> tuple[CollectionRecord[K], bool]:
        # Docstring inherited from CollectionManager.
        registered = False
        record = self._getByName(name)
        if record is None:
            # Insert-or-compare the row in the main collection table; `sync`
            # reports whether a new row was actually inserted (or updated).
            row, inserted_or_updated = self._db.sync(
                self._tables.collection,
                keys={"name": name},
                compared={"type": int(type)},
                extra={"doc": doc},
                returning=[self._collectionIdName],
            )
            assert isinstance(inserted_or_updated, bool)
            registered = inserted_or_updated
            assert row is not None
            collection_id = cast(K, row[self._collectionIdName])
            if type is CollectionType.RUN:
                # RUN collections carry an extra row in the run table with
                # host and timespan metadata; sync it and build a RunRecord.
                TimespanReprClass = self._db.getTimespanRepresentation()
                row, _ = self._db.sync(
                    self._tables.run,
                    keys={self._collectionIdName: collection_id},
                    returning=("host",) + TimespanReprClass.getFieldNames(),
                )
                assert row is not None
                record = RunRecord[K](
                    key=collection_id,
                    name=name,
                    host=row["host"],
                    timespan=TimespanReprClass.extract(row),
                )
            elif type is CollectionType.CHAINED:
                # A freshly registered chain starts with no children.
                record = ChainedCollectionRecord[K](
                    key=collection_id,
                    name=name,
                    children=[],
                )
            else:
                record = CollectionRecord[K](key=collection_id, name=name, type=type)
            self._addCachedRecord(record)
        return record, registered
271 def remove(self, name: str) -> None:
272 # Docstring inherited from CollectionManager.
273 record = self._getByName(name)
274 if record is None:
275 raise MissingCollectionError(f"No collection with name '{name}' found.")
276 # This may raise
277 self._db.delete(
278 self._tables.collection, [self._collectionIdName], {self._collectionIdName: record.key}
279 )
280 self._removeCachedRecord(record)
282 def find(self, name: str) -> CollectionRecord[K]:
283 # Docstring inherited from CollectionManager.
284 result = self._getByName(name)
285 if result is None:
286 raise MissingCollectionError(f"No collection with name '{name}' found.")
287 return result
289 def _find_many(
290 self, names: Iterable[str], flatten_chains: bool, collection_cache: CollectionRecordCache | None
291 ) -> list[CollectionRecord[K]]:
292 """Return multiple records given their names.
294 Parameters
295 ----------
296 names : `~collections.abc.Iterable` [`str`]
297 Collection names to search for.
298 flatten_chains : `bool`
299 If `True` then also retrieve recursively collection records for all
300 chained collections in the input list. Child collections are not
301 returned but are stored in collection cache.
302 collection_cache : `CollectionRecordCache`
303 If `None` then the cache from the caching context will be used if
304 that is not `None`. Collections are searched in the cache first,
305 collections that are missing from the cache are fetched from
306 database. All fetched collections are added to the cache.
308 Returns
309 -------
310 records : `list` [`CollectionRecord`]
311 Collection records. Records are ordered according to the input list
312 and expanded depth-first if ``flatten_chains`` is True.
313 """
315 def check_cache(
316 name: str, cache: CollectionRecordCache, flatten_chains: bool
317 ) -> list[CollectionRecord[K]]:
318 """Check that cache contains a record for a given name and all its
319 child records if ``flatten_chain`` is True.
321 Parameters
322 ----------
323 name : `str`
324 Collection name.
325 cache : `CollectionRecordCache`
326 Record cache.
327 flatten_chains : `bool`
328 If `True` then return all children records recursively.
330 Returns
331 -------
332 records : `list` [`CollectionRecord`]
333 Records from cache, including all child records if
334 ``flatten_chains`` is True.
336 Raises
337 ------
338 LookupError
339 Raised if any record is missing from cache. If LookupError is
340 raised then no records are generated.
341 """
342 record = cache.get_by_name(name)
343 if record is not None:
344 records = [record]
345 if flatten_chains and record.type is CollectionType.CHAINED:
346 # Check all children recursively.
347 for child_name in cast(ChainedCollectionRecord, record).children:
348 records += check_cache(child_name, cache, flatten_chains)
349 return records
350 else:
351 raise LookupError(name)
353 names = list(names)
354 if collection_cache is None:
355 collection_cache = self._caching_context.collection_records
357 # To protect against potential races in cache updates.
358 records: dict[str, CollectionRecord[K]] = {}
359 fetch_names = []
360 if collection_cache is not None:
361 for name in names:
362 try:
363 for record in check_cache(name, collection_cache, flatten_chains):
364 records[record.name] = record
365 except LookupError:
366 fetch_names.append(name)
367 else:
368 fetch_names = names
370 if fetch_names:
371 # Fetch all missing collections and optionally their children.
372 for record in self._fetch_by_name(fetch_names, flatten_chains):
373 records[record.name] = record
374 self._addCachedRecord(record, collection_cache)
376 missing_names = [name for name in names if name not in records]
377 if len(missing_names) == 1:
378 raise MissingCollectionError(f"No collection with name '{missing_names[0]}' found.")
379 elif len(missing_names) > 1:
380 raise MissingCollectionError(f"No collections with names '{' '.join(missing_names)}' found.")
382 def order(names: Iterable[str]) -> Iterator[CollectionRecord[K]]:
383 for name in names:
384 record = records[name]
385 yield record
386 if flatten_chains and record.type is CollectionType.CHAINED:
387 # Also return all children recursively.
388 yield from order(cast(ChainedCollectionRecord, record).children)
390 return list(order(names))
392 def __getitem__(self, key: Any) -> CollectionRecord[K]:
393 # Docstring inherited from CollectionManager.
394 if self._caching_context.collection_records is not None:
395 if (record := self._caching_context.collection_records.get_by_key(key)) is not None:
396 return record
397 if records := self._fetch_by_key([key]):
398 record = records[0]
399 if self._caching_context.collection_records is not None:
400 self._caching_context.collection_records.add(record)
401 return record
402 else:
403 raise MissingCollectionError(f"Collection with key '{key}' not found.")
    def resolve_wildcard(
        self,
        wildcard: CollectionWildcard,
        *,
        collection_types: Set[CollectionType] = CollectionType.all(),
        flatten_chains: bool = True,
        include_chains: bool | None = None,
    ) -> list[CollectionRecord[K]]:
        # Docstring inherited
        # By default chains themselves are reported only when not flattening.
        include_chains = include_chains if include_chains is not None else not flatten_chains

        def filter_types(records: Iterable[CollectionRecord[K]]) -> Iterator[CollectionRecord[K]]:
            # Keep only records of the requested types, dropping chained
            # collections unless they were explicitly asked for.
            for record in records:
                if record.type in collection_types:
                    if record.type is not CollectionType.CHAINED or include_chains:
                        yield record

        if wildcard.patterns is ...:
            # As _fetch_all() returns all records without duplicates, we just
            # have to filter types.
            return list(filter_types(self._fetch_all()))

        cache: CollectionRecordCache | None = None
        result: list[CollectionRecord[K]] = []
        # Keys already added to the result; guards against duplicates.
        done_keys: set[K] = set()
        explicit_names = list(wildcard.strings)
        if wildcard.patterns:
            if explicit_names or flatten_chains:
                # To be efficient in case both patterns and strings are
                # specified we want to have caching enabled for at least the
                # duration of this call. Chains flattening can also produce
                # additional names to look for.
                cache = self._caching_context.collection_records or CollectionRecordCache()
            all_records = self._fetch_all(cache)
            for record in filter_types(all_records):
                if record.key not in done_keys:
                    if any(p.fullmatch(record.name) for p in wildcard.patterns):
                        result.append(record)
                        done_keys.add(record.key)
            if flatten_chains:
                # If flattening include children names of all matching chains.
                for record in all_records:
                    if isinstance(record, ChainedCollectionRecord):
                        if any(p.fullmatch(record.name) for p in wildcard.patterns):
                            explicit_names.extend(record.children)

        if explicit_names:
            # _find_many() returns correctly ordered records, but there may be
            # duplicates.
            for record in filter_types(self._find_many(explicit_names, flatten_chains, cache)):
                if record.key not in done_keys:
                    result.append(record)
                    done_keys.add(record.key)

        return result
461 def getDocumentation(self, key: K) -> str | None:
462 # Docstring inherited from CollectionManager.
463 docs = self.get_docs([key])
464 return docs.get(key)
466 def get_docs(self, keys: Iterable[K]) -> Mapping[K, str]:
467 # Docstring inherited from CollectionManager.
468 docs: dict[K, str] = {}
469 id_column = self._tables.collection.columns[self._collectionIdName]
470 doc_column = self._tables.collection.columns.doc
471 for chunk in chunk_iterable(keys):
472 sql = (
473 sqlalchemy.sql.select(id_column, doc_column)
474 .select_from(self._tables.collection)
475 .where(sqlalchemy.sql.and_(id_column.in_(chunk), doc_column != sqlalchemy.literal("")))
476 )
477 with self._db.query(sql) as sql_result:
478 for row in sql_result:
479 docs[row[0]] = row[1]
480 return docs
    def setDocumentation(self, key: K, doc: str | None) -> None:
        # Docstring inherited from CollectionManager.
        # NOTE(review): the second argument presumably maps the PK column
        # name to the row-dict field ("key") used in the WHERE clause of
        # Database.update -- confirm against the Database API.
        self._db.update(self._tables.collection, {self._collectionIdName: "key"}, {"key": key, "doc": doc})
486 def _addCachedRecord(
487 self, record: CollectionRecord[K], collection_cache: CollectionRecordCache | None = None
488 ) -> None:
489 """Add single record to cache."""
490 if collection_cache is None:
491 collection_cache = self._caching_context.collection_records
492 if collection_cache is not None:
493 collection_cache.add(record)
495 def _removeCachedRecord(self, record: CollectionRecord[K]) -> None:
496 """Remove single record from cache."""
497 if self._caching_context.collection_records is not None:
498 self._caching_context.collection_records.discard(record)
500 def _getByName(self, name: str) -> CollectionRecord[K] | None:
501 """Find collection record given collection name."""
502 if self._caching_context.collection_records is not None:
503 if (record := self._caching_context.collection_records.get_by_name(name)) is not None:
504 return record
505 records = self._fetch_by_name([name], False)
506 for record in records:
507 self._addCachedRecord(record)
508 return records[0] if records else None
    @abstractmethod
    def _fetch_by_name(self, names: Iterable[str], flatten_chains: bool) -> list[CollectionRecord[K]]:
        """Fetch collection records from database given their names.

        If ``flatten_chains`` is `True`, records for children of chained
        collections are fetched recursively as well.
        """
        raise NotImplementedError()
    @abstractmethod
    def _fetch_by_key(self, collection_ids: Iterable[K] | None) -> list[CollectionRecord[K]]:
        """Fetch collection records from database given their keys, or fetch
        all collections if the argument is `None`.
        """
        raise NotImplementedError()
    def update_chain(
        self,
        parent_collection_name: str,
        child_collection_names: list[str],
        allow_use_in_caching_context: bool = False,
    ) -> None:
        # Docstring inherited from CollectionManager.
        with self._modify_collection_chain(
            parent_collection_name,
            child_collection_names,
            # update_chain is currently used in setCollectionChain, which is
            # called within caching contexts. (At least in Butler.import_ and
            # possibly other places.) So, unlike the other collection chain
            # modification methods, it has to update the collection cache.
            skip_caching_check=allow_use_in_caching_context,
        ) as c:
            # Replace the chain wholesale: delete all existing membership
            # rows, then insert the new children starting at position 0.
            self._db.delete(self._tables.collection_chain, ["parent"], {"parent": c.parent_key})
            self._block_for_concurrency_test()
            self._insert_collection_chain_rows(c.parent_key, 0, c.child_keys)

            # Refresh the cached record so it reflects the new membership.
            names = [child.name for child in c.child_records]
            record = ChainedCollectionRecord[K](c.parent_key, parent_collection_name, children=tuple(names))
            self._addCachedRecord(record)
    def prepend_collection_chain(
        self, parent_collection_name: str, child_collection_names: list[str]
    ) -> None:
        # Docstring inherited from CollectionManager.
        # Insert the children before the chain's current first member.
        self._add_to_collection_chain(
            parent_collection_name, child_collection_names, self._find_prepend_position
        )
    def extend_collection_chain(self, parent_collection_name: str, child_collection_names: list[str]) -> None:
        # Docstring inherited from CollectionManager.
        # Insert the children after the chain's current last member.
        self._add_to_collection_chain(
            parent_collection_name, child_collection_names, self._find_extend_position
        )
    def _add_to_collection_chain(
        self,
        parent_collection_name: str,
        child_collection_names: list[str],
        position_func: Callable[[_CollectionChainModificationContext], int],
    ) -> None:
        """Insert children into a collection chain, at a position determined
        by ``position_func``.

        Parameters
        ----------
        parent_collection_name : `str`
            Name of the chained collection to modify.
        child_collection_names : `list` [`str`]
            Names of the collections to insert.
        position_func : `~collections.abc.Callable`
            Function taking the modification context and returning the
            position number for the first inserted child.
        """
        with self._modify_collection_chain(parent_collection_name, child_collection_names) as c:
            # Remove any of the new children that are already in the
            # collection, so they move to a new position instead of being
            # duplicated.
            self._remove_collection_chain_rows(c.parent_key, c.child_keys)
            # Figure out where to insert the new children.
            starting_position = position_func(c)
            self._block_for_concurrency_test()
            self._insert_collection_chain_rows(c.parent_key, starting_position, c.child_keys)
    def remove_from_collection_chain(
        self, parent_collection_name: str, child_collection_names: list[str]
    ) -> None:
        # Docstring inherited from CollectionManager.
        with self._modify_collection_chain(
            parent_collection_name,
            child_collection_names,
            # Removing members from a chain can't create collection cycles
            skip_cycle_check=True,
            # It is OK for multiple instances of `remove_from_collection_chain`
            # to run concurrently on the same collection, because it doesn't
            # read/modify the position numbers of the children -- it only
            # deletes existing rows.
            #
            # However, other chain modification operations must still be
            # blocked to avoid consistency issues.
            exclusive_lock=False,
        ) as c:
            self._block_for_concurrency_test()
            self._remove_collection_chain_rows(c.parent_key, c.child_keys)
    @contextmanager
    def _modify_collection_chain(
        self,
        parent_collection_name: str,
        child_collection_names: list[str],
        *,
        skip_caching_check: bool = False,
        skip_cycle_check: bool = False,
        exclusive_lock: bool = True,
    ) -> Iterator[_CollectionChainModificationContext[K]]:
        """Context manager wrapping a modification of a collection chain.

        Resolves the child collections, opens a transaction, locks the
        parent collection's row for the duration of the ``with`` block, and
        yields a `_CollectionChainModificationContext` with the resolved
        parent key, child keys, and child records.

        Parameters
        ----------
        parent_collection_name : `str`
            Name of the chained collection to modify.
        child_collection_names : `list` [`str`]
            Names of the child collections involved in the modification.
        skip_caching_check : `bool`, optional
            If `True`, permit use inside an active caching context; the
            caller is then responsible for keeping the cache consistent.
        skip_cycle_check : `bool`, optional
            If `True`, skip the collection-cycle sanity check.
        exclusive_lock : `bool`, optional
            If `True`, take an exclusive row lock on the parent; otherwise a
            shared lock.
        """
        if (not skip_caching_check) and self._caching_context.collection_records is not None:
            # Avoid having cache-maintenance code around that is unlikely to
            # ever be used.
            raise RuntimeError("Chained collection modification not permitted with active caching context.")

        if not skip_cycle_check:
            self._sanity_check_collection_cycles(parent_collection_name, child_collection_names)

        # Look up the collection primary keys corresponding to the
        # user-provided list of child collection names. Because there is no
        # locking for the child collections, it's possible for a concurrent
        # deletion of one of the children to cause a foreign key constraint
        # violation when we attempt to insert them in the collection chain
        # table later.
        child_records = self.resolve_wildcard(
            CollectionWildcard.from_names(child_collection_names), flatten_chains=False
        )
        child_keys = [child.key for child in child_records]

        with self._db.transaction():
            # Lock the parent collection to prevent concurrent updates to the
            # same collection chain.
            parent_key = self._find_and_lock_collection_chain(
                parent_collection_name, exclusive_lock=exclusive_lock
            )
            yield _CollectionChainModificationContext[K](
                parent_key=parent_key, child_keys=child_keys, child_records=child_records
            )
632 def _sanity_check_collection_cycles(
633 self, parent_collection_name: str, child_collection_names: list[str]
634 ) -> None:
635 """Raise an exception if any of the collections in the ``child_names``
636 list have ``parent_name`` as a child, creating a collection cycle.
638 This is only a sanity check, and does not guarantee that no collection
639 cycles are possible. Concurrent updates might allow collection cycles
640 to be inserted.
641 """
642 for record in self.resolve_wildcard(
643 CollectionWildcard.from_names(child_collection_names),
644 flatten_chains=True,
645 include_chains=True,
646 collection_types={CollectionType.CHAINED},
647 ):
648 if record.name == parent_collection_name:
649 raise CollectionCycleError(
650 f"Cycle in collection chaining when defining '{parent_collection_name}'."
651 )
653 def _insert_collection_chain_rows(
654 self,
655 parent_key: K,
656 starting_position: int,
657 child_keys: list[K],
658 ) -> None:
659 rows = [
660 {
661 "parent": parent_key,
662 "child": child,
663 "position": position,
664 }
665 for position, child in enumerate(child_keys, starting_position)
666 ]
668 # It's possible for the DB to raise an exception for the integers being
669 # out of range here. The position column is only a 16-bit number.
670 # Even if there aren't an unreasonably large number of children in the
671 # collection, a series of many deletes and insertions could cause the
672 # space to become fragmented.
673 #
674 # If this ever actually happens, we should consider doing a migration
675 # to increase the position column to a 32-bit number.
676 # To fix it in the short term, you can re-write the collection chain to
677 # defragment it by doing something like:
678 # registry.setCollectionChain(
679 # parent,
680 # registry.getCollectionChain(parent)
681 # )
682 self._db.insert(self._tables.collection_chain, *rows)
684 def _remove_collection_chain_rows(
685 self,
686 parent_key: K,
687 child_keys: list[K],
688 ) -> None:
689 table = self._tables.collection_chain
690 where = sqlalchemy.and_(table.c.parent == parent_key, table.c.child.in_(child_keys))
691 self._db.deleteWhere(table, where)
693 def _find_prepend_position(self, c: _CollectionChainModificationContext) -> int:
694 """Return the position where children can be inserted to
695 prepend them to a collection chain.
696 """
697 return self._find_position_in_collection_chain(c.parent_key, "begin") - len(c.child_keys)
699 def _find_extend_position(self, c: _CollectionChainModificationContext) -> int:
700 """Return the position where children can be inserted to append them to
701 a collection chain.
702 """
703 return self._find_position_in_collection_chain(c.parent_key, "end") + 1
705 def _find_position_in_collection_chain(self, chain_key: K, begin_or_end: Literal["begin", "end"]) -> int:
706 """Return the lowest or highest numbered position in a collection
707 chain, or 0 if the chain is empty.
708 """
709 table = self._tables.collection_chain
711 func: sqlalchemy.Function
712 match begin_or_end:
713 case "begin":
714 func = sqlalchemy.func.min(table.c.position)
715 case "end":
716 func = sqlalchemy.func.max(table.c.position)
718 query = sqlalchemy.select(func).where(table.c.parent == chain_key)
719 with self._db.query(query) as cursor:
720 position = cursor.scalar()
722 if position is None:
723 return 0
725 return position
    def _find_and_lock_collection_chain(self, collection_name: str, *, exclusive_lock: bool) -> K:
        """
        Take a row lock on the specified collection's row in the collections
        table, and return the collection's primary key.

        This lock is used to synchronize updates to collection chains.

        The locking strategy requires cooperation from everything modifying
        the collection chain table -- all operations that modify collection
        chains must obtain this lock first. The database will NOT
        automatically prevent modification of tables based on this lock. The
        only guarantee is that only one caller will be allowed to hold the
        exclusive lock for a given collection at a time. Concurrent calls
        will block until the caller holding the lock has completed its
        transaction.

        Parameters
        ----------
        collection_name : `str`
            Name of the collection whose chain is being modified.
        exclusive_lock : `bool`
            If `True`, an exclusive lock will be taken to block all concurrent
            modifications to the same collection. If `False`, a shared lock
            will be taken which will only block operations that request an
            exclusive lock.

        Returns
        -------
        id : ``K``
            The primary key for the given collection.

        Raises
        ------
        MissingCollectionError
            If the specified collection is not in the database table.
        CollectionTypeError
            If the specified collection is not a chained collection.
        """
        assert self._db.isInTransaction(), (
            "Row locks are only held until the end of the current transaction,"
            " so it makes no sense to take a lock outside a transaction."
        )
        assert self._db.isWriteable(), "Collection row locks are only useful for write operations."

        # SELECT ... FOR UPDATE (or a shared lock when read=True); the lock
        # is released when the enclosing transaction ends.
        query = self._select_pkey_by_name(collection_name).with_for_update(read=not exclusive_lock)
        with self._db.query(query) as cursor:
            rows = cursor.all()

        if len(rows) == 0:
            raise MissingCollectionError(
                f"Parent collection {collection_name} not found when updating collection chain."
            )
        assert len(rows) == 1, "There should only be one entry for each collection in collection table."
        r = rows[0]._mapping
        if r["type"] != CollectionType.CHAINED:
            raise CollectionTypeError(f"Parent collection {collection_name} is not a chained collection.")
        return r["key"]
    @abstractmethod
    def _select_pkey_by_name(self, collection_name: str) -> sqlalchemy.Select:
        """Return a SQLAlchemy select statement that will return columns from
        the one row in the ``collection`` table matching the given name. The
        select statement includes two columns:

        - ``key`` : the primary key for the collection
        - ``type`` : the collection type
        """
        raise NotImplementedError()
    def _query_recursive(
        self,
        collections: Iterable[str],
        key_type: type,
    ) -> list[Mapping]:
        """Run the query that recursively finds collections and all their
        child collections.

        Parameters
        ----------
        collections : `~collections.abc.Iterable` [`str`]
            List of collection names to retrieve.
        key_type : `type`
            Type of the key column, e.g. `sqlalchemy.BigInteger`.

        Returns
        -------
        rows : `list` [`~collections.abc.Mapping`]
            Database rows resulting from the query. Each row contains a
            combination of columns from ``collections`` table and
            ``collection_chain`` table joined on ``child`` column,
            ``child`` column is not included into returned mappings.
            For top-level collections both ``parent`` and ``position`` will
            be `None`. Same collection can appear multiple times if it is a
            child of multiple collections.
        """
        # Make recursive CTE to fetch everything in one query. There may be
        # duplicate collection names in the result, but it should not affect
        # performance too much for the limited number of input collections.
        #
        # The query will look like
        #
        # WITH RECURSIVE chains AS (
        #     SELECT
        #         coll_1.*,
        #         cast(NULL as KEY_TYPE) parent,
        #         cast(NULL as SMALLINT) position
        #     FROM
        #         collection coll_1
        #     WHERE
        #         coll_1.name IN (:collections)
        #     UNION ALL
        #     SELECT
        #         coll_2.*,
        #         chain_2.parent,
        #         chain_2.position
        #     FROM
        #         collection coll_2
        #         JOIN collection_chain chain_2
        #             ON coll_2.key_column = chain_2.child
        #         JOIN chains ON chain_2.parent = chains.key_column
        # )
        # SELECT
        #     ch.*,
        #     run.host,
        #     run.timespan
        # FROM
        #     chains ch
        #     LEFT OUTER JOIN run
        #         ON ch.key_column = run.key_column;
        #
        chain_table = self._tables.collection_chain
        collection_table = self._tables.collection
        run_table = self._tables.run
        key_column = self._collectionIdName

        # First CTE select: the anchor term matching the named collections.
        # The NULL parent/position columns mark them as top-level entries.
        coll_1 = collection_table.alias("coll_1")
        chains_cte = (
            sqlalchemy.select(
                *coll_1.columns,
                sqlalchemy.cast(None, type_=key_type).label("parent"),
                sqlalchemy.cast(None, type_=sqlalchemy.SmallInteger).label("position"),
            )
            .where(coll_1.columns["name"].in_(collections))
            .cte("chains", recursive=True)
        )

        # Second CTE select: the recursive term joining children of anything
        # already collected into the CTE.
        cte_alias = chains_cte.alias()
        coll_2 = collection_table.alias("coll_2")
        chain_2 = chain_table.alias("chain_2")
        chains_cte = chains_cte.union_all(
            sqlalchemy.select(
                *coll_2.columns, chain_2.columns["parent"], chain_2.columns["position"]
            ).select_from(
                coll_2.join(chain_2, onclause=(coll_2.columns[key_column] == chain_2.columns["child"])).join(
                    cte_alias, onclause=(chain_2.columns["parent"] == cte_alias.columns[key_column])
                )
            )
        )

        # Outer select joining chains CTE with run table using LEFT OUTER JOIN.
        TimespanReprClass = self._db.getTimespanRepresentation()
        query = sqlalchemy.select(
            *chains_cte.columns,
            run_table.columns["host"],
            *[run_table.columns[column] for column in TimespanReprClass.getFieldNames()],
        ).select_from(
            chains_cte.join(
                run_table,
                isouter=True,
                onclause=(chains_cte.columns[key_column] == run_table.columns[key_column]),
            )
        )

        with self._db.transaction():
            with self._db.query(query) as sql_result:
                return list(sql_result.mappings().fetchall())
class _CollectionChainModificationContext(NamedTuple, Generic[K]):
    """Resolved state yielded by `_modify_collection_chain` to its callers."""

    parent_key: K  # primary key of the chained parent collection (row is locked)
    child_keys: list[K]  # primary keys of the child collections
    child_records: list[CollectionRecord[K]]  # resolved records for the children