Coverage for python/lsst/daf/butler/registry/collections/_base.py: 96%
175 statements
« prev ^ index » next coverage.py v7.4.3, created at 2024-03-12 10:05 +0000
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
27from __future__ import annotations
29from ... import ddl
31__all__ = ()
33import itertools
34from abc import abstractmethod
35from collections import namedtuple
36from collections.abc import Iterable, Iterator, Set
37from typing import TYPE_CHECKING, Any, TypeVar, cast
39import sqlalchemy
41from ...timespan_database_representation import TimespanDatabaseRepresentation
42from .._collection_type import CollectionType
43from .._exceptions import MissingCollectionError
44from ..interfaces import ChainedCollectionRecord, CollectionManager, CollectionRecord, RunRecord, VersionTuple
45from ..wildcards import CollectionWildcard
47if TYPE_CHECKING:
48 from .._caching_context import CachingContext
49 from ..interfaces import Database
def _makeCollectionForeignKey(
    sourceColumnName: str, collectionIdName: str, **kwargs: Any
) -> ddl.ForeignKeySpec:
    """Define foreign key specification that refers to collections table.

    Parameters
    ----------
    sourceColumnName : `str`
        Name of the column in the referring table.
    collectionIdName : `str`
        Name of the column in collections table that identifies it (PK).
    **kwargs
        Additional keyword arguments passed directly to `ddl.ForeignKeySpec`.

    Returns
    -------
    spec : `ddl.ForeignKeySpec`
        Foreign key specification.

    Notes
    -----
    This method assumes fixed name ("collection") of a collections table.
    There is also a general assumption that collection primary key consists
    of a single column.
    """
    # Source and target are one-element tuples because the collection primary
    # key is assumed to be a single column.
    return ddl.ForeignKeySpec(
        "collection",
        source=(sourceColumnName,),
        target=(collectionIdName,),
        **kwargs,
    )
# Named tuple bundling the three table objects used for collection storage,
# in fixed order: "collection", "run" and "collection_chain".
CollectionTablesTuple = namedtuple("CollectionTablesTuple", ["collection", "run", "collection_chain"])
def makeRunTableSpec(
    collectionIdName: str, collectionIdType: type, TimespanReprClass: type[TimespanDatabaseRepresentation]
) -> ddl.TableSpec:
    """Define specification for "run" table.

    Parameters
    ----------
    collectionIdName : `str`
        Name of the column in collections table that identifies it (PK).
    collectionIdType : `type`
        Type of the PK column in the collections table, one of the
        `sqlalchemy` types.
    TimespanReprClass : `type` [ `TimespanDatabaseRepresentation` ]
        Subclass of `TimespanDatabaseRepresentation` that encapsulates how
        timespans are stored in this database.

    Returns
    -------
    spec : `ddl.TableSpec`
        Specification for run table.

    Notes
    -----
    Assumption here and in the code below is that the name of the identifying
    column is the same in both collections and run tables. The names of
    non-identifying columns containing run metadata are fixed.
    """
    result = ddl.TableSpec(
        fields=[
            ddl.FieldSpec(collectionIdName, dtype=collectionIdType, primaryKey=True),
            ddl.FieldSpec("host", dtype=sqlalchemy.String, length=128),
        ],
        foreignKeys=[
            # The run row references the collection row through the same
            # column name on both sides.
            _makeCollectionForeignKey(collectionIdName, collectionIdName, onDelete="CASCADE"),
        ],
    )
    # The timespan columns depend on the database-specific representation, so
    # they are appended after the fixed columns; all of them are nullable.
    for fieldSpec in TimespanReprClass.makeFieldSpecs(nullable=True):
        result.fields.add(fieldSpec)
    return result
def makeCollectionChainTableSpec(collectionIdName: str, collectionIdType: type) -> ddl.TableSpec:
    """Define specification for "collection_chain" table.

    Parameters
    ----------
    collectionIdName : `str`
        Name of the column in collections table that identifies it (PK).
    collectionIdType : `type`
        Type of the PK column in the collections table, one of the
        `sqlalchemy` types.

    Returns
    -------
    spec : `ddl.TableSpec`
        Specification for collection chain table.

    Notes
    -----
    Collection chain is simply an ordered one-to-many relation between
    collections. The names of the columns in the table are fixed and
    also hardcoded in the code below.
    """
    return ddl.TableSpec(
        fields=[
            # (parent, position) is the composite primary key.
            ddl.FieldSpec("parent", dtype=collectionIdType, primaryKey=True),
            ddl.FieldSpec("position", dtype=sqlalchemy.SmallInteger, primaryKey=True),
            ddl.FieldSpec("child", dtype=collectionIdType, nullable=False),
        ],
        foreignKeys=[
            # Only the parent reference is declared with onDelete="CASCADE";
            # the child reference is created without an onDelete argument.
            _makeCollectionForeignKey("parent", collectionIdName, onDelete="CASCADE"),
            _makeCollectionForeignKey("child", collectionIdName),
        ],
    )
# Type variable for the collection key type; it parametrizes
# `DefaultCollectionManager` and the `CollectionRecord` types it handles.
K = TypeVar("K")
class DefaultCollectionManager(CollectionManager[K]):
    """Default `CollectionManager` implementation.

    This implementation uses the record classes imported from the
    ``interfaces`` module and is based on the same assumptions about schema
    outlined in the record classes.

    Parameters
    ----------
    db : `Database`
        Interface to the underlying database engine and namespace.
    tables : `CollectionTablesTuple`
        Named tuple of SQLAlchemy table objects.
    collectionIdName : `str`
        Name of the column in collections table that identifies it (PK).
    caching_context : `CachingContext`
        Caching context to use.
    registry_schema_version : `VersionTuple` or `None`, optional
        The version of the registry schema.

    Notes
    -----
    Records are cached in memory only while
    ``caching_context.collection_records`` is not `None`; when it is `None`
    every lookup goes to the database. Calling `refresh` clears that cache.
    """

    def __init__(
        self,
        db: Database,
        tables: CollectionTablesTuple,
        collectionIdName: str,
        *,
        caching_context: CachingContext,
        registry_schema_version: VersionTuple | None = None,
    ):
        super().__init__(registry_schema_version=registry_schema_version)
        self._db = db
        self._tables = tables
        self._collectionIdName = collectionIdName
        self._caching_context = caching_context

    def refresh(self) -> None:
        # Docstring inherited from CollectionManager.
        if self._caching_context.collection_records is not None:
            self._caching_context.collection_records.clear()

    def _fetch_all(self) -> list[CollectionRecord[K]]:
        """Retrieve all records into cache if not done so yet.

        Returns
        -------
        records : `list` [ `CollectionRecord` ]
            All collection records, taken from the cache when it is marked
            as full and from the database otherwise.
        """
        if self._caching_context.collection_records is not None:
            if self._caching_context.collection_records.full:
                return list(self._caching_context.collection_records.records())
        # Passing None asks the subclass to fetch every collection.
        records = self._fetch_by_key(None)
        if self._caching_context.collection_records is not None:
            self._caching_context.collection_records.set(records, full=True)
        return records

    def register(
        self, name: str, type: CollectionType, doc: str | None = None
    ) -> tuple[CollectionRecord[K], bool]:
        # Docstring inherited from CollectionManager.
        # ``registered`` stays False when a record with this name already
        # exists; it is only updated from the result of the sync call below.
        registered = False
        record = self._getByName(name)
        if record is None:
            row, inserted_or_updated = self._db.sync(
                self._tables.collection,
                keys={"name": name},
                compared={"type": int(type)},
                extra={"doc": doc},
                returning=[self._collectionIdName],
            )
            assert isinstance(inserted_or_updated, bool)
            registered = inserted_or_updated
            assert row is not None
            collection_id = cast(K, row[self._collectionIdName])
            if type is CollectionType.RUN:
                # RUN collections have a companion row in the run table that
                # carries the host and timespan columns.
                TimespanReprClass = self._db.getTimespanRepresentation()
                row, _ = self._db.sync(
                    self._tables.run,
                    keys={self._collectionIdName: collection_id},
                    returning=("host",) + TimespanReprClass.getFieldNames(),
                )
                assert row is not None
                record = RunRecord[K](
                    key=collection_id,
                    name=name,
                    host=row["host"],
                    timespan=TimespanReprClass.extract(row),
                )
            elif type is CollectionType.CHAINED:
                # A newly registered chain starts out with no children.
                record = ChainedCollectionRecord[K](
                    key=collection_id,
                    name=name,
                    children=[],
                )
            else:
                record = CollectionRecord[K](key=collection_id, name=name, type=type)
            self._addCachedRecord(record)
        return record, registered

    def remove(self, name: str) -> None:
        # Docstring inherited from CollectionManager.
        record = self._getByName(name)
        if record is None:
            raise MissingCollectionError(f"No collection with name '{name}' found.")
        # This may raise
        self._db.delete(
            self._tables.collection, [self._collectionIdName], {self._collectionIdName: record.key}
        )
        # The cache is only touched after the delete call has returned.
        self._removeCachedRecord(record)

    def find(self, name: str) -> CollectionRecord[K]:
        # Docstring inherited from CollectionManager.
        result = self._getByName(name)
        if result is None:
            raise MissingCollectionError(f"No collection with name '{name}' found.")
        return result

    def _find_many(self, names: Iterable[str]) -> list[CollectionRecord[K]]:
        """Return multiple records given their names.

        Parameters
        ----------
        names : `~collections.abc.Iterable` [ `str` ]
            Names of the collections to look up; duplicates are allowed.

        Returns
        -------
        records : `list` [ `CollectionRecord` ]
            One record per entry of ``names``, in the same order.

        Raises
        ------
        MissingCollectionError
            Raised if any of the names does not match a known collection.
        """
        names = list(names)
        # Results are collected in a local dict instead of being re-read from
        # the cache at the end, to protect against potential races in cache
        # updates.
        records: dict[str, CollectionRecord | None] = {}
        if self._caching_context.collection_records is not None:
            for name in names:
                records[name] = self._caching_context.collection_records.get_by_name(name)
        else:
            records = dict.fromkeys(names)
        # Dict keys are unique, so each unresolved name is requested once even
        # if it appears several times in ``names``.
        fetch_names = [name for name, record in records.items() if record is None]
        if fetch_names:
            for record in self._fetch_by_name(fetch_names):
                records[record.name] = record
                self._addCachedRecord(record)
        missing_names = [name for name, record in records.items() if record is None]
        if len(missing_names) == 1:
            raise MissingCollectionError(f"No collection with name '{missing_names[0]}' found.")
        elif len(missing_names) > 1:
            raise MissingCollectionError(f"No collections with names '{' '.join(missing_names)}' found.")
        return [cast(CollectionRecord[K], records[name]) for name in names]

    def __getitem__(self, key: Any) -> CollectionRecord[K]:
        # Docstring inherited from CollectionManager.
        if self._caching_context.collection_records is not None:
            if (record := self._caching_context.collection_records.get_by_key(key)) is not None:
                return record
        records = self._fetch_by_key([key])
        if not records:
            raise MissingCollectionError(f"Collection with key '{key}' not found.")
        record = records[0]
        self._addCachedRecord(record)
        return record

    def resolve_wildcard(
        self,
        wildcard: CollectionWildcard,
        *,
        collection_types: Set[CollectionType] = CollectionType.all(),
        done: set[str] | None = None,
        flatten_chains: bool = True,
        include_chains: bool | None = None,
    ) -> list[CollectionRecord[K]]:
        # Docstring inherited
        if done is None:
            done = set()
        # When not given explicitly, chains are included exactly when they
        # are not being flattened.
        include_chains = include_chains if include_chains is not None else not flatten_chains

        def resolve_nested(record: CollectionRecord, done: set[str]) -> Iterator[CollectionRecord[K]]:
            # ``done`` holds names already handled; a name is added both when
            # a record matches ``collection_types`` and when a chain is
            # expanded, so each record is processed at most once.
            if record.name in done:
                return
            if record.type in collection_types:
                done.add(record.name)
                if record.type is not CollectionType.CHAINED or include_chains:
                    yield record
            if flatten_chains and record.type is CollectionType.CHAINED:
                done.add(record.name)
                for child in self._find_many(cast(ChainedCollectionRecord[K], record).children):
                    # flake8 can't tell that we only delete this closure when
                    # we're totally done with it.
                    yield from resolve_nested(child, done)  # noqa: F821

        result: list[CollectionRecord[K]] = []

        # An Ellipsis in ``patterns`` means "match everything".
        if wildcard.patterns is ...:
            for record in self._fetch_all():
                result.extend(resolve_nested(record, done))
            del resolve_nested
            return result
        # Explicit names are resolved before pattern matches, so they come
        # first in the result.
        if wildcard.strings:
            for record in self._find_many(wildcard.strings):
                result.extend(resolve_nested(record, done))
        if wildcard.patterns:
            for record in self._fetch_all():
                if any(p.fullmatch(record.name) for p in wildcard.patterns):
                    result.extend(resolve_nested(record, done))
        del resolve_nested
        return result

    def getDocumentation(self, key: K) -> str | None:
        # Docstring inherited from CollectionManager.
        sql = (
            sqlalchemy.sql.select(self._tables.collection.columns.doc)
            .select_from(self._tables.collection)
            .where(self._tables.collection.columns[self._collectionIdName] == key)
        )
        with self._db.query(sql) as sql_result:
            return sql_result.scalar()

    def setDocumentation(self, key: K, doc: str | None) -> None:
        # Docstring inherited from CollectionManager.
        # NOTE(review): the second argument presumably maps the PK column to
        # the row entry ("key") holding the value to match - confirm against
        # the ``Database.update`` contract.
        self._db.update(self._tables.collection, {self._collectionIdName: "key"}, {"key": key, "doc": doc})

    def _addCachedRecord(self, record: CollectionRecord[K]) -> None:
        """Add single record to cache.

        Does nothing when the caching context has no collection record cache.
        """
        if self._caching_context.collection_records is not None:
            self._caching_context.collection_records.add(record)

    def _removeCachedRecord(self, record: CollectionRecord[K]) -> None:
        """Remove single record from cache.

        Does nothing when the caching context has no collection record cache.
        """
        if self._caching_context.collection_records is not None:
            self._caching_context.collection_records.discard(record)

    def _getByName(self, name: str) -> CollectionRecord[K] | None:
        """Find collection record given collection name.

        Returns
        -------
        record : `CollectionRecord` or `None`
            Record taken from the cache when present there, otherwise the
            first record fetched from the database; `None` if nothing was
            found.
        """
        if self._caching_context.collection_records is not None:
            if (record := self._caching_context.collection_records.get_by_name(name)) is not None:
                return record
        records = self._fetch_by_name([name])
        for record in records:
            self._addCachedRecord(record)
        return records[0] if records else None

    @abstractmethod
    def _fetch_by_name(self, names: Iterable[str]) -> list[CollectionRecord[K]]:
        """Fetch collection record from database given its name."""
        raise NotImplementedError()

    @abstractmethod
    def _fetch_by_key(self, collection_ids: Iterable[K] | None) -> list[CollectionRecord[K]]:
        """Fetch collection record from database given its key, or fetch all
        collections if argument is None.
        """
        raise NotImplementedError()

    def update_chain(
        self, chain: ChainedCollectionRecord[K], children: Iterable[str], flatten: bool = False
    ) -> ChainedCollectionRecord[K]:
        # Docstring inherited from CollectionManager.
        # ``children`` is converted to a wildcard only here, and that wildcard
        # is reused below, because ``children`` may be a one-shot iterator
        # that a second pass would find empty.
        children_as_wildcard = CollectionWildcard.from_names(children)
        # Walk every chain reachable from the new children; finding ``chain``
        # itself among them means the update would create a cycle.
        for record in self.resolve_wildcard(
            children_as_wildcard,
            flatten_chains=True,
            include_chains=True,
            collection_types={CollectionType.CHAINED},
        ):
            if record == chain:
                raise ValueError(f"Cycle in collection chaining when defining '{chain.name}'.")
        if flatten:
            # Replace the requested children with their flattened expansion.
            children_as_wildcard = CollectionWildcard.from_names(
                tuple(
                    record.name
                    for record in self.resolve_wildcard(children_as_wildcard, flatten_chains=True)
                )
            )

        rows = []
        names = []
        for position, child in enumerate(self.resolve_wildcard(children_as_wildcard, flatten_chains=False)):
            rows.append(
                {
                    "parent": chain.key,
                    "child": child.key,
                    "position": position,
                }
            )
            names.append(child.name)
        # Delete the old chain rows and insert the new ones inside a single
        # transaction block.
        with self._db.transaction():
            self._db.delete(self._tables.collection_chain, ["parent"], {"parent": chain.key})
            self._db.insert(self._tables.collection_chain, *rows)

        record = ChainedCollectionRecord[K](chain.key, chain.name, children=tuple(names))
        self._addCachedRecord(record)
        return record