Coverage for python/lsst/daf/butler/registry/collections/_base.py: 96%
176 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-16 10:43 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-16 10:43 +0000
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
27from __future__ import annotations
29from ... import ddl
31__all__ = ()
33import itertools
34from abc import abstractmethod
35from collections import namedtuple
36from collections.abc import Iterable, Iterator, Set
37from typing import TYPE_CHECKING, Any, TypeVar, cast
39import sqlalchemy
41from ...timespan_database_representation import TimespanDatabaseRepresentation
42from .._collection_type import CollectionType
43from .._exceptions import MissingCollectionError
44from ..interfaces import ChainedCollectionRecord, CollectionManager, CollectionRecord, RunRecord, VersionTuple
45from ..wildcards import CollectionWildcard
47if TYPE_CHECKING:
48 from .._caching_context import CachingContext
49 from ..interfaces import Database, DimensionRecordStorageManager
def _makeCollectionForeignKey(
    sourceColumnName: str, collectionIdName: str, **kwargs: Any
) -> ddl.ForeignKeySpec:
    """Build a foreign key specification pointing at the collections table.

    Parameters
    ----------
    sourceColumnName : `str`
        Column of the referring table that holds the reference.
    collectionIdName : `str`
        Primary key column of the collections table being referenced.
    **kwargs
        Extra keyword arguments forwarded unchanged to `ddl.ForeignKeySpec`.

    Returns
    -------
    spec : `ddl.ForeignKeySpec`
        Foreign key specification.

    Notes
    -----
    The referenced table name is hard-coded as ``"collection"``, and the
    key is always a single column on both sides of the relation.
    """
    # Both sides are single-column keys, expressed as one-element tuples.
    referring_columns = (sourceColumnName,)
    referenced_columns = (collectionIdName,)
    return ddl.ForeignKeySpec(
        "collection",
        source=referring_columns,
        target=referenced_columns,
        **kwargs,
    )
# Groups the three tables used by the collection manager, in the fixed order
# (collection, run, collection_chain).
CollectionTablesTuple = namedtuple(
    "CollectionTablesTuple",
    ("collection", "run", "collection_chain"),
)
def makeRunTableSpec(
    collectionIdName: str, collectionIdType: type, TimespanReprClass: type[TimespanDatabaseRepresentation]
) -> ddl.TableSpec:
    """Build the table specification for the "run" table.

    Parameters
    ----------
    collectionIdName : `str`
        Name of the primary key column of the collections table; the same
        name is used for the primary key column of the run table.
    collectionIdType : `type`
        `sqlalchemy` type of that primary key column.
    TimespanReprClass : `type` [ `TimespanDatabaseRepresentation` ]
        Subclass of `TimespanDatabaseRepresentation` that supplies the
        field specifications used to store the run timespan.

    Returns
    -------
    spec : `ddl.TableSpec`
        Specification for run table.

    Notes
    -----
    The identifying column carries the same name in the collections and run
    tables. The non-identifying ``host`` column name is fixed, and the
    timespan column names come from ``TimespanReprClass``.
    """
    id_field = ddl.FieldSpec(collectionIdName, dtype=collectionIdType, primaryKey=True)
    host_field = ddl.FieldSpec("host", dtype=sqlalchemy.String, length=128)
    # Deleting a collection row cascades to its run row.
    collection_fk = _makeCollectionForeignKey(collectionIdName, collectionIdName, onDelete="CASCADE")
    spec = ddl.TableSpec(fields=[id_field, host_field], foreignKeys=[collection_fk])
    # The timespan columns are nullable and are appended after the fixed
    # columns.
    for timespan_field in TimespanReprClass.makeFieldSpecs(nullable=True):
        spec.fields.add(timespan_field)
    return spec
def makeCollectionChainTableSpec(collectionIdName: str, collectionIdType: type) -> ddl.TableSpec:
    """Build the table specification for the "collection_chain" table.

    Parameters
    ----------
    collectionIdName : `str`
        Name of the primary key column of the collections table.
    collectionIdType : `type`
        `sqlalchemy` type of that primary key column.

    Returns
    -------
    spec : `ddl.TableSpec`
        Specification for collection chain table.

    Notes
    -----
    A collection chain is an ordered one-to-many relation between
    collections. The column names ``parent``, ``position`` and ``child``
    are fixed, and the code that reads and writes chains relies on them.
    """
    # The (parent, position) pair is the primary key of a chain link.
    parent_field = ddl.FieldSpec("parent", dtype=collectionIdType, primaryKey=True)
    position_field = ddl.FieldSpec("position", dtype=sqlalchemy.SmallInteger, primaryKey=True)
    child_field = ddl.FieldSpec("child", dtype=collectionIdType, nullable=False)
    # Only the parent reference cascades on delete; the child reference is
    # declared without an ``onDelete`` action.
    parent_fk = _makeCollectionForeignKey("parent", collectionIdName, onDelete="CASCADE")
    child_fk = _makeCollectionForeignKey("child", collectionIdName)
    return ddl.TableSpec(
        fields=[parent_field, position_field, child_field],
        foreignKeys=[parent_fk, child_fk],
    )
# Type of the collection primary key; fixed by concrete subclasses.
K = TypeVar("K")


class DefaultCollectionManager(CollectionManager[K]):
    """Default `CollectionManager` implementation.

    This implementation works with the record classes imported from
    ``..interfaces`` (`CollectionRecord`, `RunRecord`,
    `ChainedCollectionRecord`) and leaves the actual database reads to the
    abstract `_fetch_by_name` and `_fetch_by_key` methods.

    Parameters
    ----------
    db : `Database`
        Interface to the underlying database engine and namespace.
    tables : `CollectionTablesTuple`
        Named tuple of SQLAlchemy table objects.
    collectionIdName : `str`
        Name of the column in collections table that identifies it (PK).
    dimensions : `DimensionRecordStorageManager`
        Manager object for the dimensions in this `Registry`.
    caching_context : `CachingContext`
        Caching context to use.
    registry_schema_version : `VersionTuple` or `None`, optional
        The version of the registry schema.

    Notes
    -----
    Records are cached only when ``caching_context.collection_records`` is
    not `None`; every cache access in this class is guarded by that check.
    Calling `refresh` clears that cache; records are then fetched from the
    database again on demand.
    """

    def __init__(
        self,
        db: Database,
        tables: CollectionTablesTuple,
        collectionIdName: str,
        *,
        dimensions: DimensionRecordStorageManager,
        caching_context: CachingContext,
        registry_schema_version: VersionTuple | None = None,
    ):
        super().__init__(registry_schema_version=registry_schema_version)
        self._db = db
        self._tables = tables
        self._collectionIdName = collectionIdName
        self._dimensions = dimensions
        self._caching_context = caching_context

    def refresh(self) -> None:
        # Docstring inherited from CollectionManager.
        if self._caching_context.collection_records is not None:
            self._caching_context.collection_records.clear()

    def _fetch_all(self) -> list[CollectionRecord[K]]:
        """Retrieve all records into cache if not done so yet.

        Returns
        -------
        records : `list` [ `CollectionRecord` ]
            All collection records, taken from the cache when it is marked
            as full and from the database otherwise.
        """
        if self._caching_context.collection_records is not None:
            if self._caching_context.collection_records.full:
                return list(self._caching_context.collection_records.records())
        # Passing `None` asks `_fetch_by_key` for every collection.
        records = self._fetch_by_key(None)
        if self._caching_context.collection_records is not None:
            self._caching_context.collection_records.set(records, full=True)
        return records

    def register(
        self, name: str, type: CollectionType, doc: str | None = None
    ) -> tuple[CollectionRecord[K], bool]:
        # Docstring inherited from CollectionManager.
        registered = False
        record = self._getByName(name)
        # NOTE(review): when a record with this name already exists it is
        # returned as-is, without comparing its type to the requested
        # ``type`` - confirm that callers handle a mismatch.
        if record is None:
            row, inserted_or_updated = self._db.sync(
                self._tables.collection,
                keys={"name": name},
                compared={"type": int(type)},
                extra={"doc": doc},
                returning=[self._collectionIdName],
            )
            assert isinstance(inserted_or_updated, bool)
            registered = inserted_or_updated
            assert row is not None
            collection_id = cast(K, row[self._collectionIdName])
            if type is CollectionType.RUN:
                # RUN collections have a companion row in the run table
                # holding the host and timespan.
                TimespanReprClass = self._db.getTimespanRepresentation()
                row, _ = self._db.sync(
                    self._tables.run,
                    keys={self._collectionIdName: collection_id},
                    returning=("host",) + TimespanReprClass.getFieldNames(),
                )
                assert row is not None
                record = RunRecord[K](
                    key=collection_id,
                    name=name,
                    host=row["host"],
                    timespan=TimespanReprClass.extract(row),
                )
            elif type is CollectionType.CHAINED:
                # A newly registered chain starts with no children.
                record = ChainedCollectionRecord[K](
                    key=collection_id,
                    name=name,
                    children=[],
                )
            else:
                record = CollectionRecord[K](key=collection_id, name=name, type=type)
            self._addCachedRecord(record)
        return record, registered

    def remove(self, name: str) -> None:
        # Docstring inherited from CollectionManager.
        record = self._getByName(name)
        if record is None:
            raise MissingCollectionError(f"No collection with name '{name}' found.")
        # This may raise
        self._db.delete(
            self._tables.collection, [self._collectionIdName], {self._collectionIdName: record.key}
        )
        # The cache is only touched once the database delete has returned.
        self._removeCachedRecord(record)

    def find(self, name: str) -> CollectionRecord[K]:
        # Docstring inherited from CollectionManager.
        result = self._getByName(name)
        if result is None:
            raise MissingCollectionError(f"No collection with name '{name}' found.")
        return result

    def _find_many(self, names: Iterable[str]) -> list[CollectionRecord[K]]:
        """Return multiple records given their names.

        Parameters
        ----------
        names : `~collections.abc.Iterable` [ `str` ]
            Names of the collections to look up; may contain duplicates.

        Returns
        -------
        records : `list` [ `CollectionRecord` ]
            One record per element of ``names``, in the same order.

        Raises
        ------
        MissingCollectionError
            Raised if at least one name is found neither in the cache nor
            in the database.
        """
        names = list(names)
        # To protect against potential races in cache updates, results are
        # collected in a local mapping rather than re-read from the cache.
        records: dict[str, CollectionRecord[K] | None]
        if self._caching_context.collection_records is not None:
            records = {
                name: self._caching_context.collection_records.get_by_name(name) for name in names
            }
        else:
            records = dict.fromkeys(names)
        # Derived from the mapping keys, so each name is fetched at most
        # once regardless of whether the cache is enabled.
        fetch_names = [name for name, record in records.items() if record is None]
        if fetch_names:
            for record in self._fetch_by_name(fetch_names):
                records[record.name] = record
                self._addCachedRecord(record)
        missing_names = [name for name, record in records.items() if record is None]
        if len(missing_names) == 1:
            raise MissingCollectionError(f"No collection with name '{missing_names[0]}' found.")
        elif len(missing_names) > 1:
            raise MissingCollectionError(f"No collections with names '{' '.join(missing_names)}' found.")
        return [cast(CollectionRecord[K], records[name]) for name in names]

    def __getitem__(self, key: Any) -> CollectionRecord[K]:
        # Docstring inherited from CollectionManager.
        if self._caching_context.collection_records is not None:
            if (record := self._caching_context.collection_records.get_by_key(key)) is not None:
                return record
        records = self._fetch_by_key([key])
        if not records:
            raise MissingCollectionError(f"Collection with key '{key}' not found.")
        record = records[0]
        # Use the shared helper so every cache insert goes through one place.
        self._addCachedRecord(record)
        return record

    def resolve_wildcard(
        self,
        wildcard: CollectionWildcard,
        *,
        collection_types: Set[CollectionType] = CollectionType.all(),
        done: set[str] | None = None,
        flatten_chains: bool = True,
        include_chains: bool | None = None,
    ) -> list[CollectionRecord[K]]:
        # Docstring inherited
        if done is None:
            done = set()
        # Unless set explicitly, chains themselves are reported only when
        # they are not being flattened.
        include_chains = include_chains if include_chains is not None else not flatten_chains

        def resolve_nested(record: CollectionRecord, done: set[str]) -> Iterator[CollectionRecord[K]]:
            # ``done`` holds names already visited, so each collection is
            # processed at most once.
            if record.name in done:
                return
            if record.type in collection_types:
                done.add(record.name)
                if record.type is not CollectionType.CHAINED or include_chains:
                    yield record
            if flatten_chains and record.type is CollectionType.CHAINED:
                done.add(record.name)
                for child in self._find_many(cast(ChainedCollectionRecord[K], record).children):
                    # flake8 can't tell that we only delete this closure when
                    # we're totally done with it.
                    yield from resolve_nested(child, done)  # noqa: F821

        result: list[CollectionRecord[K]] = []

        if wildcard.patterns is ...:
            # Ellipsis matches every collection.
            for record in self._fetch_all():
                result.extend(resolve_nested(record, done))
            del resolve_nested
            return result
        # Explicitly named collections are resolved before pattern matches.
        if wildcard.strings:
            for record in self._find_many(wildcard.strings):
                result.extend(resolve_nested(record, done))
        if wildcard.patterns:
            for record in self._fetch_all():
                if any(p.fullmatch(record.name) for p in wildcard.patterns):
                    result.extend(resolve_nested(record, done))
        del resolve_nested
        return result

    def getDocumentation(self, key: K) -> str | None:
        # Docstring inherited from CollectionManager.
        sql = (
            sqlalchemy.sql.select(self._tables.collection.columns.doc)
            .select_from(self._tables.collection)
            .where(self._tables.collection.columns[self._collectionIdName] == key)
        )
        with self._db.query(sql) as sql_result:
            return sql_result.scalar()

    def setDocumentation(self, key: K, doc: str | None) -> None:
        # Docstring inherited from CollectionManager.
        self._db.update(self._tables.collection, {self._collectionIdName: "key"}, {"key": key, "doc": doc})

    def _addCachedRecord(self, record: CollectionRecord[K]) -> None:
        """Add single record to cache.

        Does nothing when the caching context has no record cache.
        """
        if self._caching_context.collection_records is not None:
            self._caching_context.collection_records.add(record)

    def _removeCachedRecord(self, record: CollectionRecord[K]) -> None:
        """Remove single record from cache.

        Does nothing when the caching context has no record cache.
        """
        if self._caching_context.collection_records is not None:
            self._caching_context.collection_records.discard(record)

    def _getByName(self, name: str) -> CollectionRecord[K] | None:
        """Find collection record given collection name.

        Returns
        -------
        record : `CollectionRecord` or `None`
            The matching record, looked up in the cache first and in the
            database second; `None` if neither has it.
        """
        if self._caching_context.collection_records is not None:
            if (record := self._caching_context.collection_records.get_by_name(name)) is not None:
                return record
        records = self._fetch_by_name([name])
        for record in records:
            self._addCachedRecord(record)
        return records[0] if records else None

    @abstractmethod
    def _fetch_by_name(self, names: Iterable[str]) -> list[CollectionRecord[K]]:
        """Fetch collection record from database given its name."""
        raise NotImplementedError()

    @abstractmethod
    def _fetch_by_key(self, collection_ids: Iterable[K] | None) -> list[CollectionRecord[K]]:
        """Fetch collection record from database given its key, or fetch all
        collections if argument is None.
        """
        raise NotImplementedError()

    def update_chain(
        self, chain: ChainedCollectionRecord[K], children: Iterable[str], flatten: bool = False
    ) -> ChainedCollectionRecord[K]:
        # Docstring inherited from CollectionManager.
        children_as_wildcard = CollectionWildcard.from_names(children)
        # Walk every chain reachable from the new children; finding ``chain``
        # among them means it would contain itself.
        for record in self.resolve_wildcard(
            children_as_wildcard,
            flatten_chains=True,
            include_chains=True,
            collection_types={CollectionType.CHAINED},
        ):
            if record == chain:
                raise ValueError(f"Cycle in collection chaining when defining '{chain.name}'.")
        if flatten:
            children = tuple(
                record.name for record in self.resolve_wildcard(children_as_wildcard, flatten_chains=True)
            )

        rows = []
        position = itertools.count()
        names = []
        for child in self.resolve_wildcard(CollectionWildcard.from_names(children), flatten_chains=False):
            rows.append(
                {
                    "parent": chain.key,
                    "child": child.key,
                    "position": next(position),
                }
            )
            names.append(child.name)
        # Replace the old links with the new ones inside a single
        # transaction.
        with self._db.transaction():
            self._db.delete(self._tables.collection_chain, ["parent"], {"parent": chain.key})
            self._db.insert(self._tables.collection_chain, *rows)

        record = ChainedCollectionRecord[K](chain.key, chain.name, children=tuple(names))
        self._addCachedRecord(record)
        return record