Coverage for python/lsst/daf/butler/registry/collections/_base.py: 86%
176 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-06 10:52 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-06 10:52 +0000
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
27from __future__ import annotations
29from ... import ddl
31__all__ = ()
33import itertools
34from abc import abstractmethod
35from collections import namedtuple
36from collections.abc import Iterable, Iterator, Set
37from typing import TYPE_CHECKING, Any, TypeVar, cast
39import sqlalchemy
41from ..._timespan import TimespanDatabaseRepresentation
42from .._collection_type import CollectionType
43from .._exceptions import MissingCollectionError
44from ..interfaces import ChainedCollectionRecord, CollectionManager, CollectionRecord, RunRecord, VersionTuple
45from ..wildcards import CollectionWildcard
47if TYPE_CHECKING:
48 from .._caching_context import CachingContext
49 from ..interfaces import Database, DimensionRecordStorageManager
def _makeCollectionForeignKey(
    sourceColumnName: str, collectionIdName: str, **kwargs: Any
) -> ddl.ForeignKeySpec:
    """Build the foreign key specification pointing at the collections table.

    Parameters
    ----------
    sourceColumnName : `str`
        Column of the referring table that holds the collection reference.
    collectionIdName : `str`
        Primary key column of the collections table.
    **kwargs
        Extra keyword arguments forwarded unchanged to `ddl.ForeignKeySpec`.

    Returns
    -------
    spec : `ddl.ForeignKeySpec`
        The resulting foreign key specification.

    Notes
    -----
    The name of the collections table is hard-coded as "collection", and the
    collection primary key is taken to be a single column.
    """
    # Both sides of the key are single-column tuples.
    referring_columns = (sourceColumnName,)
    referenced_columns = (collectionIdName,)
    return ddl.ForeignKeySpec(
        "collection", source=referring_columns, target=referenced_columns, **kwargs
    )
# Named tuple bundling the table objects used by the collection manager; the
# field order matters to any caller that builds or unpacks it positionally.
CollectionTablesTuple = namedtuple(
    "CollectionTablesTuple",
    ("collection", "run", "collection_chain"),
)
def makeRunTableSpec(
    collectionIdName: str, collectionIdType: type, TimespanReprClass: type[TimespanDatabaseRepresentation]
) -> ddl.TableSpec:
    """Build the table specification for the "run" table.

    Parameters
    ----------
    collectionIdName : `str`
        Name of the primary key column of the collections table.
    collectionIdType : `type`
        Type of that primary key column, one of the `sqlalchemy` types.
    TimespanReprClass : `type` [ `TimespanDatabaseRepresentation` ]
        Subclass of `TimespanDatabaseRepresentation` that encapsulates how
        timespans are stored in this database.

    Returns
    -------
    spec : `ddl.TableSpec`
        Specification for the run table.

    Notes
    -----
    The identifying column carries the same name in the collections and run
    tables. The names of the remaining run metadata columns are fixed.
    """
    id_field = ddl.FieldSpec(collectionIdName, dtype=collectionIdType, primaryKey=True)
    host_field = ddl.FieldSpec("host", dtype=sqlalchemy.String, length=128)
    # The run row references the collection row of the same key and is
    # declared with ON DELETE CASCADE.
    collection_fk = _makeCollectionForeignKey(collectionIdName, collectionIdName, onDelete="CASCADE")
    spec = ddl.TableSpec(fields=[id_field, host_field], foreignKeys=[collection_fk])
    # Timespan columns depend on the database representation; all nullable.
    for timespan_field in TimespanReprClass.makeFieldSpecs(nullable=True):
        spec.fields.add(timespan_field)
    return spec
def makeCollectionChainTableSpec(collectionIdName: str, collectionIdType: type) -> ddl.TableSpec:
    """Build the table specification for the "collection_chain" table.

    Parameters
    ----------
    collectionIdName : `str`
        Name of the primary key column of the collections table.
    collectionIdType : `type`
        Type of that primary key column, one of the `sqlalchemy` types.

    Returns
    -------
    spec : `ddl.TableSpec`
        Specification for the collection chain table.

    Notes
    -----
    A collection chain is an ordered one-to-many relation between
    collections. The column names ("parent", "position", "child") are fixed
    here and are also hard-coded in the code that uses the table.
    """
    # (parent, position) is the primary key; child must always be set.
    parent_field = ddl.FieldSpec("parent", dtype=collectionIdType, primaryKey=True)
    position_field = ddl.FieldSpec("position", dtype=sqlalchemy.SmallInteger, primaryKey=True)
    child_field = ddl.FieldSpec("child", dtype=collectionIdType, nullable=False)
    # Only the parent reference is declared with ON DELETE CASCADE.
    parent_fk = _makeCollectionForeignKey("parent", collectionIdName, onDelete="CASCADE")
    child_fk = _makeCollectionForeignKey("child", collectionIdName)
    return ddl.TableSpec(
        fields=[parent_field, position_field, child_field],
        foreignKeys=[parent_fk, child_fk],
    )
# Generic type of collection keys, i.e. of the values stored in the
# collections table primary key column.
K = TypeVar("K")


class DefaultCollectionManager(CollectionManager[K]):
    """Default `CollectionManager` implementation.

    This implementation uses the record classes from the registry interfaces
    (`CollectionRecord`, `RunRecord`, `ChainedCollectionRecord`) and relies on
    the schema produced by the table specification helpers in this module.

    Parameters
    ----------
    db : `Database`
        Interface to the underlying database engine and namespace.
    tables : `CollectionTablesTuple`
        Named tuple of SQLAlchemy table objects.
    collectionIdName : `str`
        Name of the column in collections table that identifies it (PK).
    dimensions : `DimensionRecordStorageManager`
        Manager object for the dimensions in this `Registry`.
    caching_context : `CachingContext`
        Caching context; its ``collection_records`` attribute, when not
        `None`, is used as the in-memory cache of collection records.
    registry_schema_version : `VersionTuple` or `None`, optional
        Schema version, forwarded to the `CollectionManager` constructor.

    Notes
    -----
    Records are cached only while ``caching_context.collection_records`` is
    not `None`; otherwise every lookup goes to the database through the
    abstract `_fetch_by_name` and `_fetch_by_key` methods. `refresh` only
    clears that cache, it does not reload anything.
    """

    def __init__(
        self,
        db: Database,
        tables: CollectionTablesTuple,
        collectionIdName: str,
        *,
        dimensions: DimensionRecordStorageManager,
        caching_context: CachingContext,
        registry_schema_version: VersionTuple | None = None,
    ):
        super().__init__(registry_schema_version=registry_schema_version)
        self._db = db
        self._tables = tables
        self._collectionIdName = collectionIdName
        self._dimensions = dimensions
        self._caching_context = caching_context

    def refresh(self) -> None:
        # Docstring inherited from CollectionManager.
        # Nothing is re-read here; dropping cached records is enough because
        # lookups fall back to the database on a cache miss.
        if self._caching_context.collection_records is not None:
            self._caching_context.collection_records.clear()

    def _fetch_all(self) -> list[CollectionRecord[K]]:
        """Return all collection records, filling the cache if it is enabled.

        Returns
        -------
        records : `list` [ `CollectionRecord` ]
            Cached records if the cache is enabled and marked as full,
            otherwise the records fetched from the database.
        """
        if self._caching_context.collection_records is not None:
            if self._caching_context.collection_records.full:
                return list(self._caching_context.collection_records.records())
        # ``None`` asks the backend for every collection.
        records = self._fetch_by_key(None)
        if self._caching_context.collection_records is not None:
            self._caching_context.collection_records.set(records, full=True)
        return records

    def register(
        self, name: str, type: CollectionType, doc: str | None = None
    ) -> tuple[CollectionRecord[K], bool]:
        # Docstring inherited from CollectionManager.
        registered = False
        record = self._getByName(name)
        # NOTE(review): a record that already exists is returned as-is; its
        # type and documentation are not compared with the arguments here.
        if record is None:
            row, inserted_or_updated = self._db.sync(
                self._tables.collection,
                keys={"name": name},
                compared={"type": int(type)},
                extra={"doc": doc},
                returning=[self._collectionIdName],
            )
            assert isinstance(inserted_or_updated, bool)
            registered = inserted_or_updated
            assert row is not None
            collection_id = cast(K, row[self._collectionIdName])
            if type is CollectionType.RUN:
                # RUN collections have a companion row in the run table.
                TimespanReprClass = self._db.getTimespanRepresentation()
                row, _ = self._db.sync(
                    self._tables.run,
                    keys={self._collectionIdName: collection_id},
                    returning=("host",) + TimespanReprClass.getFieldNames(),
                )
                assert row is not None
                record = RunRecord[K](
                    key=collection_id,
                    name=name,
                    host=row["host"],
                    timespan=TimespanReprClass.extract(row),
                )
            elif type is CollectionType.CHAINED:
                # A newly registered chain starts without children.
                record = ChainedCollectionRecord[K](
                    key=collection_id,
                    name=name,
                    children=[],
                )
            else:
                record = CollectionRecord[K](key=collection_id, name=name, type=type)
            self._addCachedRecord(record)
        return record, registered

    def remove(self, name: str) -> None:
        # Docstring inherited from CollectionManager.
        record = self._getByName(name)
        if record is None:
            raise MissingCollectionError(f"No collection with name '{name}' found.")
        # This may raise; the cache is only touched after a successful delete.
        self._db.delete(
            self._tables.collection, [self._collectionIdName], {self._collectionIdName: record.key}
        )
        self._removeCachedRecord(record)

    def find(self, name: str) -> CollectionRecord[K]:
        # Docstring inherited from CollectionManager.
        result = self._getByName(name)
        if result is None:
            raise MissingCollectionError(f"No collection with name '{name}' found.")
        return result

    def _find_many(self, names: Iterable[str]) -> list[CollectionRecord[K]]:
        """Return multiple records given their names.

        Parameters
        ----------
        names : `~collections.abc.Iterable` [ `str` ]
            Names of the collections to look up.

        Returns
        -------
        records : `list` [ `CollectionRecord` ]
            One record per requested name, in the order the names were given
            (repeated names yield repeated records).

        Raises
        ------
        MissingCollectionError
            Raised if any of the names does not match a collection.
        """
        names = list(names)
        # To protect against potential races in cache updates, results are
        # collected in a local mapping. Starting from the de-duplicated names
        # means each name is looked up (and fetched) at most once whether or
        # not the cache is enabled.
        records: dict[str, CollectionRecord | None] = dict.fromkeys(names)
        if self._caching_context.collection_records is not None:
            for name in records:
                records[name] = self._caching_context.collection_records.get_by_name(name)
        fetch_names = [name for name, record in records.items() if record is None]
        if fetch_names:
            for record in self._fetch_by_name(fetch_names):
                records[record.name] = record
                self._addCachedRecord(record)
        missing_names = [name for name, record in records.items() if record is None]
        if len(missing_names) == 1:
            raise MissingCollectionError(f"No collection with name '{missing_names[0]}' found.")
        elif len(missing_names) > 1:
            raise MissingCollectionError(f"No collections with names '{' '.join(missing_names)}' found.")
        return [cast(CollectionRecord[K], records[name]) for name in names]

    def __getitem__(self, key: Any) -> CollectionRecord[K]:
        # Docstring inherited from CollectionManager.
        if self._caching_context.collection_records is not None:
            if (record := self._caching_context.collection_records.get_by_key(key)) is not None:
                return record
        if records := self._fetch_by_key([key]):
            record = records[0]
            if self._caching_context.collection_records is not None:
                self._caching_context.collection_records.add(record)
            return record
        else:
            raise MissingCollectionError(f"Collection with key '{key}' not found.")

    def resolve_wildcard(
        self,
        wildcard: CollectionWildcard,
        *,
        collection_types: Set[CollectionType] = CollectionType.all(),
        done: set[str] | None = None,
        flatten_chains: bool = True,
        include_chains: bool | None = None,
    ) -> list[CollectionRecord[K]]:
        # Docstring inherited
        if done is None:
            done = set()
        # By default chains themselves are reported only when they are not
        # being expanded into their children.
        include_chains = include_chains if include_chains is not None else not flatten_chains

        def resolve_nested(record: CollectionRecord, done: set[str]) -> Iterator[CollectionRecord[K]]:
            # ``done`` holds names already visited so each collection is
            # yielded at most once and chains are expanded at most once.
            if record.name in done:
                return
            if record.type in collection_types:
                done.add(record.name)
                if record.type is not CollectionType.CHAINED or include_chains:
                    yield record
            if flatten_chains and record.type is CollectionType.CHAINED:
                done.add(record.name)
                for child in self._find_many(cast(ChainedCollectionRecord[K], record).children):
                    # flake8 can't tell that we only delete this closure when
                    # we're totally done with it.
                    yield from resolve_nested(child, done)  # noqa: F821

        result: list[CollectionRecord[K]] = []

        if wildcard.patterns is ...:
            # Ellipsis means "everything".
            for record in self._fetch_all():
                result.extend(resolve_nested(record, done))
            # Drop the self-referencing closure before returning.
            del resolve_nested
            return result
        if wildcard.strings:
            for record in self._find_many(wildcard.strings):
                result.extend(resolve_nested(record, done))
        if wildcard.patterns:
            for record in self._fetch_all():
                if any(p.fullmatch(record.name) for p in wildcard.patterns):
                    result.extend(resolve_nested(record, done))
        # Drop the self-referencing closure before returning.
        del resolve_nested
        return result

    def getDocumentation(self, key: K) -> str | None:
        # Docstring inherited from CollectionManager.
        sql = (
            sqlalchemy.sql.select(self._tables.collection.columns.doc)
            .select_from(self._tables.collection)
            .where(self._tables.collection.columns[self._collectionIdName] == key)
        )
        with self._db.query(sql) as sql_result:
            return sql_result.scalar()

    def setDocumentation(self, key: K, doc: str | None) -> None:
        # Docstring inherited from CollectionManager.
        # The second argument maps the id column to the "key" entry of the
        # row dictionary passed as the third argument.
        self._db.update(self._tables.collection, {self._collectionIdName: "key"}, {"key": key, "doc": doc})

    def _addCachedRecord(self, record: CollectionRecord[K]) -> None:
        """Add single record to cache (no-op when the cache is disabled)."""
        if self._caching_context.collection_records is not None:
            self._caching_context.collection_records.add(record)

    def _removeCachedRecord(self, record: CollectionRecord[K]) -> None:
        """Remove single record from cache (no-op when the cache is
        disabled).
        """
        if self._caching_context.collection_records is not None:
            self._caching_context.collection_records.discard(record)

    def _getByName(self, name: str) -> CollectionRecord[K] | None:
        """Find collection record given collection name.

        Returns `None` when neither the cache (if enabled) nor the database
        has a collection with this name.
        """
        if self._caching_context.collection_records is not None:
            if (record := self._caching_context.collection_records.get_by_name(name)) is not None:
                return record
        records = self._fetch_by_name([name])
        for record in records:
            self._addCachedRecord(record)
        return records[0] if records else None

    @abstractmethod
    def _fetch_by_name(self, names: Iterable[str]) -> list[CollectionRecord[K]]:
        """Fetch collection records from database given their names."""
        raise NotImplementedError()

    @abstractmethod
    def _fetch_by_key(self, collection_ids: Iterable[K] | None) -> list[CollectionRecord[K]]:
        """Fetch collection records from database given their keys, or fetch
        all collections if argument is None.
        """
        raise NotImplementedError()

    def update_chain(
        self, chain: ChainedCollectionRecord[K], children: Iterable[str], flatten: bool = False
    ) -> ChainedCollectionRecord[K]:
        # Docstring inherited from CollectionManager.
        # ``children`` is iterated more than once below, so materialize it
        # first; otherwise a one-shot iterator would be exhausted by the
        # cycle check and the chain would silently be written as empty.
        children = tuple(children)
        children_as_wildcard = CollectionWildcard.from_names(children)
        # Refuse to create a chain that (directly or through nested chains)
        # contains itself.
        for record in self.resolve_wildcard(
            children_as_wildcard,
            flatten_chains=True,
            include_chains=True,
            collection_types={CollectionType.CHAINED},
        ):
            if record == chain:
                raise ValueError(f"Cycle in collection chaining when defining '{chain.name}'.")
        if flatten:
            children = tuple(
                record.name for record in self.resolve_wildcard(children_as_wildcard, flatten_chains=True)
            )

        rows = []
        position = itertools.count()
        names = []
        for child in self.resolve_wildcard(CollectionWildcard.from_names(children), flatten_chains=False):
            rows.append(
                {
                    "parent": chain.key,
                    "child": child.key,
                    "position": next(position),
                }
            )
            names.append(child.name)
        # Replace the old chain definition with the new one in one
        # transaction.
        with self._db.transaction():
            self._db.delete(self._tables.collection_chain, ["parent"], {"parent": chain.key})
            self._db.insert(self._tables.collection_chain, *rows)

        record = ChainedCollectionRecord[K](chain.key, chain.name, children=tuple(names))
        self._addCachedRecord(record)
        return record