Coverage for python/lsst/daf/butler/registry/collections/_base.py: 92%
182 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-10-27 09:43 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2023-10-27 09:43 +0000
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
27from __future__ import annotations
29from ... import ddl
# Empty: this module exports no names through ``from ... import *``.
__all__ = ()
33import itertools
34from abc import abstractmethod
35from collections import namedtuple
36from collections.abc import Iterable, Iterator, Set
37from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast
39import sqlalchemy
41from ..._timespan import Timespan, TimespanDatabaseRepresentation
42from ...dimensions import DimensionUniverse
43from .._collection_type import CollectionType
44from .._exceptions import MissingCollectionError
45from ..interfaces import ChainedCollectionRecord, CollectionManager, CollectionRecord, RunRecord, VersionTuple
46from ..wildcards import CollectionWildcard
48if TYPE_CHECKING:
49 from ..interfaces import Database, DimensionRecordStorageManager
def _makeCollectionForeignKey(
    sourceColumnName: str, collectionIdName: str, **kwargs: Any
) -> ddl.ForeignKeySpec:
    """Build a foreign key specification that points at the collections table.

    Parameters
    ----------
    sourceColumnName : `str`
        Column of the referring table that holds the collection ID.
    collectionIdName : `str`
        Primary-key column of the collections table being referenced.
    **kwargs
        Forwarded unchanged to `ddl.ForeignKeySpec` (e.g. ``onDelete``).

    Returns
    -------
    spec : `ddl.ForeignKeySpec`
        Foreign key specification.

    Notes
    -----
    The referenced table name is hard-coded as ``"collection"``, and the
    collection primary key is assumed to be a single column.
    """
    source = (sourceColumnName,)
    target = (collectionIdName,)
    return ddl.ForeignKeySpec("collection", source=source, target=target, **kwargs)
# Bundle of the three SQLAlchemy tables used by the collection manager:
# the main collection table, the run table and the chain-membership table.
CollectionTablesTuple = namedtuple("CollectionTablesTuple", "collection run collection_chain")
def makeRunTableSpec(
    collectionIdName: str, collectionIdType: type, TimespanReprClass: type[TimespanDatabaseRepresentation]
) -> ddl.TableSpec:
    """Build the table specification for the "run" table.

    Parameters
    ----------
    collectionIdName : `str`
        Name of the column in collections table that identifies it (PK).
    collectionIdType
        Type of the PK column in the collections table, one of the
        `sqlalchemy` types.
    TimespanReprClass : `type` [ `TimespanDatabaseRepresentation` ]
        Subclass of `TimespanDatabaseRepresentation` that encapsulates how
        timespans are stored in this database.

    Returns
    -------
    spec : `ddl.TableSpec`
        Specification for run table.

    Notes
    -----
    The identifying column has the same name in both the collections and run
    tables; the run-metadata columns (``host`` plus the timespan columns) use
    fixed names.
    """
    fields = [
        ddl.FieldSpec(collectionIdName, dtype=collectionIdType, primaryKey=True),
        ddl.FieldSpec("host", dtype=sqlalchemy.String, length=128),
    ]
    # Deleting the collection row cascades to its run row.
    runToCollection = _makeCollectionForeignKey(collectionIdName, collectionIdName, onDelete="CASCADE")
    spec = ddl.TableSpec(fields=fields, foreignKeys=[runToCollection])
    # Timespan column layout depends on the database representation class,
    # so those fields are appended after the fixed ones.
    for timespanField in TimespanReprClass.makeFieldSpecs(nullable=True):
        spec.fields.add(timespanField)
    return spec
def makeCollectionChainTableSpec(collectionIdName: str, collectionIdType: type) -> ddl.TableSpec:
    """Build the table specification for the "collection_chain" table.

    Parameters
    ----------
    collectionIdName : `str`
        Name of the column in collections table that identifies it (PK).
    collectionIdType
        Type of the PK column in the collections table, one of the
        `sqlalchemy` types.

    Returns
    -------
    spec : `ddl.TableSpec`
        Specification for collection chain table.

    Notes
    -----
    A collection chain is an ordered one-to-many relation between
    collections: each row links a ``parent`` chain to one ``child`` at a given
    ``position``. These column names are fixed and are also hard-coded in the
    code that reads and writes the table.
    """
    parentField = ddl.FieldSpec("parent", dtype=collectionIdType, primaryKey=True)
    positionField = ddl.FieldSpec("position", dtype=sqlalchemy.SmallInteger, primaryKey=True)
    childField = ddl.FieldSpec("child", dtype=collectionIdType, nullable=False)
    # Deleting a parent collection cascades to its chain rows; the child key
    # gets no ``onDelete`` argument, so the ForeignKeySpec default applies.
    parentKey = _makeCollectionForeignKey("parent", collectionIdName, onDelete="CASCADE")
    childKey = _makeCollectionForeignKey("child", collectionIdName)
    return ddl.TableSpec(
        fields=[parentField, positionField, childField],
        foreignKeys=[parentKey, childKey],
    )
class DefaultRunRecord(RunRecord):
    """Default `RunRecord` implementation.

    This class assumes the same run table definition as produced by
    `makeRunTableSpec`. The only non-fixed name in the schema is the PK
    column name, which must be passed to the constructor.

    Parameters
    ----------
    db : `Database`
        Registry database.
    key
        Unique collection ID, can be the same as ``name`` if ``name`` is used
        for identification. Usually this is an integer or string, but can be
        other database-specific type.
    name : `str`
        Run collection name.
    table : `sqlalchemy.schema.Table`
        Table for run records.
    idColumnName : `str`
        Name of the identifying column in run table.
    host : `str`, optional
        Name of the host where run was produced.
    timespan : `Timespan`, optional
        Timespan for this run. If `None`, an unbounded timespan
        (``begin=None, end=None``) is stored instead.
    """

    def __init__(
        self,
        db: Database,
        key: Any,
        name: str,
        *,
        table: sqlalchemy.schema.Table,
        idColumnName: str,
        host: str | None = None,
        timespan: Timespan | None = None,
    ):
        super().__init__(key=key, name=name, type=CollectionType.RUN)
        self._db = db
        self._table = table
        self._host = host
        # A missing timespan is normalized to an unbounded one.
        if timespan is None:
            timespan = Timespan(begin=None, end=None)
        self._timespan = timespan
        self._idName = idColumnName

    def update(self, host: str | None = None, timespan: Timespan | None = None) -> None:
        # Docstring inherited from RunRecord.
        if timespan is None:
            timespan = Timespan(begin=None, end=None)
        # ``Database.update`` takes ``where`` as a mapping from column name to
        # the *name of a key* in each row dict, the same convention used by
        # ``DefaultCollectionManager.setDocumentation``. The ID therefore goes
        # under a separate "key" entry. Putting the ID value itself in
        # ``where`` would produce a bind parameter named after the collection
        # ID. Reusing the ID column name as the key would clash with the
        # bind names SQLAlchemy generates for the SET clause.
        row = {
            "key": self.key,
            "host": host,
        }
        # Adds the database-specific timespan columns to ``row``.
        self._db.getTimespanRepresentation().update(timespan, result=row)
        count = self._db.update(self._table, {self._idName: "key"}, row)
        if count != 1:
            raise RuntimeError(f"Run update affected {count} records; expected exactly one.")
        # Only refresh the in-memory copy after the database update succeeded.
        self._host = host
        self._timespan = timespan

    @property
    def host(self) -> str | None:
        # Docstring inherited from RunRecord.
        return self._host

    @property
    def timespan(self) -> Timespan:
        # Docstring inherited from RunRecord.
        return self._timespan
class DefaultChainedCollectionRecord(ChainedCollectionRecord):
    """Default `ChainedCollectionRecord` implementation.

    This class assumes the chain table layout produced by
    `makeCollectionChainTableSpec`; the ``parent``, ``child`` and
    ``position`` column names are hard-coded in its methods.

    Parameters
    ----------
    db : `Database`
        Registry database.
    key
        Unique collection ID, can be the same as ``name`` if ``name`` is used
        for identification. Usually this is an integer or string, but can be
        other database-specific type.
    name : `str`
        Collection name.
    table : `sqlalchemy.schema.Table`
        Table for chain relationship records.
    universe : `DimensionUniverse`
        Object managing all known dimensions.
    """

    def __init__(
        self,
        db: Database,
        key: Any,
        name: str,
        *,
        table: sqlalchemy.schema.Table,
        universe: DimensionUniverse,
    ):
        super().__init__(key=key, name=name, universe=universe)
        self._db = db
        self._table = table
        self._universe = universe

    def _update(self, manager: CollectionManager, children: tuple[str, ...]) -> None:
        # Docstring inherited from ChainedCollectionRecord.
        # Children are resolved without flattening, so nested chains are kept
        # as direct members of this chain.
        resolvedChildren = manager.resolve_wildcard(
            CollectionWildcard.from_names(children), flatten_chains=False
        )
        newRows = [
            {"parent": self.key, "child": childRecord.key, "position": index}
            for index, childRecord in enumerate(resolvedChildren)
        ]
        # Replace the chain's membership in one transaction: drop the old
        # rows for this parent, then insert the new ordered list.
        with self._db.transaction():
            self._db.delete(self._table, ["parent"], {"parent": self.key})
            self._db.insert(self._table, *newRows)

    def _load(self, manager: CollectionManager) -> tuple[str, ...]:
        # Docstring inherited from ChainedCollectionRecord.
        chainColumns = self._table.columns
        query = (
            sqlalchemy.sql.select(chainColumns.child)
            .select_from(self._table)
            .where(chainColumns.parent == self.key)
            .order_by(chainColumns.position)
        )
        with self._db.query(query) as queryResult:
            # Translate each child key into its name via the manager's cache.
            return tuple(manager[mapping[chainColumns.child]].name for mapping in queryResult.mappings())
# Type of the collection key used to index the manager's record cache
# (usually an integer ID or the collection name).
K = TypeVar("K")
class DefaultCollectionManager(Generic[K], CollectionManager):
    """Default `CollectionManager` implementation.

    This implementation uses record classes defined in this module and is
    based on the same assumptions about schema outlined in the record classes.

    Parameters
    ----------
    db : `Database`
        Interface to the underlying database engine and namespace.
    tables : `CollectionTablesTuple`
        Named tuple of SQLAlchemy table objects.
    collectionIdName : `str`
        Name of the column in collections table that identifies it (PK).
    dimensions : `DimensionRecordStorageManager`
        Manager object for the dimensions in this `Registry`.
    registry_schema_version : `VersionTuple` or `None`, optional
        Registry schema version, passed unchanged to the `CollectionManager`
        constructor.

    Notes
    -----
    Implementation uses "aggressive" pre-fetching and caching of the records
    in memory. Memory cache is synchronized from database when `refresh`
    method is called.
    """

    def __init__(
        self,
        db: Database,
        tables: CollectionTablesTuple,
        collectionIdName: str,
        *,
        dimensions: DimensionRecordStorageManager,
        registry_schema_version: VersionTuple | None = None,
    ):
        super().__init__(registry_schema_version=registry_schema_version)
        self._db = db
        self._tables = tables
        self._collectionIdName = collectionIdName
        self._records: dict[K, CollectionRecord] = {}  # indexed by record ID
        self._dimensions = dimensions

    def refresh(self) -> None:
        # Docstring inherited from CollectionManager.
        collection_table = self._tables.collection
        run_table = self._tables.run
        # Outer join so non-RUN collections (no run row) are still returned.
        sql = sqlalchemy.sql.select(*(list(collection_table.columns) + list(run_table.columns))).select_from(
            collection_table.join(run_table, isouter=True)
        )
        # Put found records into a temporary instead of updating self._records
        # in place, for exception safety.
        records: list[CollectionRecord] = []
        chains: list[ChainedCollectionRecord] = []
        TimespanReprClass = self._db.getTimespanRepresentation()
        # Fetch every row up front; records are built after the query
        # context has exited.
        with self._db.query(sql) as sql_result:
            sql_rows = sql_result.mappings().fetchall()
        for row in sql_rows:
            collection_id = row[collection_table.columns[self._collectionIdName]]
            name = row[collection_table.columns.name]
            # Named ``collection_type`` to avoid shadowing the ``type`` builtin.
            collection_type = CollectionType(row["type"])
            record: CollectionRecord
            if collection_type is CollectionType.RUN:
                record = DefaultRunRecord(
                    key=collection_id,
                    name=name,
                    db=self._db,
                    table=run_table,
                    idColumnName=self._collectionIdName,
                    host=row[run_table.columns.host],
                    timespan=TimespanReprClass.extract(row),
                )
            elif collection_type is CollectionType.CHAINED:
                record = DefaultChainedCollectionRecord(
                    db=self._db,
                    key=collection_id,
                    table=self._tables.collection_chain,
                    name=name,
                    universe=self._dimensions.universe,
                )
                chains.append(record)
            else:
                record = CollectionRecord(key=collection_id, name=name, type=collection_type)
            records.append(record)
        self._setRecordCache(records)
        # Chains look up their children by key through ``self[...]``, so they
        # can only be refreshed after every record is in the cache.
        for chain in chains:
            try:
                chain.refresh(self)
            except MissingCollectionError:
                # This indicates a race condition in which some other client
                # created a new collection and added it as a child of this
                # (pre-existing) chain between the time we fetched all
                # collections and the time we queried for parent-child
                # relationships.
                # Because that's some other unrelated client, we shouldn't care
                # about that parent collection anyway, so we just drop it on
                # the floor (a manual refresh can be used to get it back).
                self._removeCachedRecord(chain)

    def register(
        self, name: str, type: CollectionType, doc: str | None = None
    ) -> tuple[CollectionRecord, bool]:
        # Docstring inherited from CollectionManager.
        registered = False
        record = self._getByName(name)
        if record is None:
            # ``compared`` makes ``sync`` check the stored type against the
            # requested one when the row already exists.
            row, inserted_or_updated = self._db.sync(
                self._tables.collection,
                keys={"name": name},
                compared={"type": int(type)},
                extra={"doc": doc},
                returning=[self._collectionIdName],
            )
            assert isinstance(inserted_or_updated, bool)
            registered = inserted_or_updated
            assert row is not None
            collection_id = row[self._collectionIdName]
            if type is CollectionType.RUN:
                TimespanReprClass = self._db.getTimespanRepresentation()
                row, _ = self._db.sync(
                    self._tables.run,
                    keys={self._collectionIdName: collection_id},
                    returning=("host",) + TimespanReprClass.getFieldNames(),
                )
                assert row is not None
                record = DefaultRunRecord(
                    db=self._db,
                    key=collection_id,
                    name=name,
                    table=self._tables.run,
                    idColumnName=self._collectionIdName,
                    host=row["host"],
                    timespan=TimespanReprClass.extract(row),
                )
            elif type is CollectionType.CHAINED:
                record = DefaultChainedCollectionRecord(
                    db=self._db,
                    key=collection_id,
                    name=name,
                    table=self._tables.collection_chain,
                    universe=self._dimensions.universe,
                )
            else:
                record = CollectionRecord(key=collection_id, name=name, type=type)
            self._addCachedRecord(record)
        return record, registered

    def remove(self, name: str) -> None:
        # Docstring inherited from CollectionManager.
        record = self._getByName(name)
        if record is None:
            raise MissingCollectionError(f"No collection with name '{name}' found.")
        # This may raise
        self._db.delete(
            self._tables.collection, [self._collectionIdName], {self._collectionIdName: record.key}
        )
        # The cache is only updated once the database delete has succeeded.
        self._removeCachedRecord(record)

    def find(self, name: str) -> CollectionRecord:
        # Docstring inherited from CollectionManager.
        result = self._getByName(name)
        if result is None:
            raise MissingCollectionError(f"No collection with name '{name}' found.")
        return result

    def __getitem__(self, key: Any) -> CollectionRecord:
        # Docstring inherited from CollectionManager.
        try:
            return self._records[key]
        except KeyError as err:
            raise MissingCollectionError(f"Collection with key '{key}' not found.") from err

    def resolve_wildcard(
        self,
        wildcard: CollectionWildcard,
        *,
        collection_types: Set[CollectionType] = CollectionType.all(),
        done: set[str] | None = None,
        flatten_chains: bool = True,
        include_chains: bool | None = None,
    ) -> list[CollectionRecord]:
        # Docstring inherited
        # ``done`` collects the names already visited; it is shared across the
        # whole resolution so each collection is yielded at most once.
        if done is None:
            done = set()
        include_chains = include_chains if include_chains is not None else not flatten_chains

        def resolve_nested(record: CollectionRecord, done: set[str]) -> Iterator[CollectionRecord]:
            # Yield ``record`` (if its type is requested) and, when flattening,
            # recursively yield the members of CHAINED collections.
            if record.name in done:
                return
            if record.type in collection_types:
                done.add(record.name)
                if record.type is not CollectionType.CHAINED or include_chains:
                    yield record
            if flatten_chains and record.type is CollectionType.CHAINED:
                done.add(record.name)
                for name in cast(ChainedCollectionRecord, record).children:
                    # flake8 can't tell that we only delete this closure when
                    # we're totally done with it.
                    yield from resolve_nested(self.find(name), done)  # noqa: F821

        result: list[CollectionRecord] = []

        # ``resolve_nested`` refers to itself recursively, which forms a
        # reference cycle through its closure cell; deleting the local name
        # below breaks that cycle once resolution is complete.
        if wildcard.patterns is ...:
            for record in self._records.values():
                result.extend(resolve_nested(record, done))
            del resolve_nested
            return result
        for name in wildcard.strings:
            result.extend(resolve_nested(self.find(name), done))
        if wildcard.patterns:
            for record in self._records.values():
                if any(p.fullmatch(record.name) for p in wildcard.patterns):
                    result.extend(resolve_nested(record, done))
        del resolve_nested
        return result

    def getDocumentation(self, key: Any) -> str | None:
        # Docstring inherited from CollectionManager.
        sql = (
            sqlalchemy.sql.select(self._tables.collection.columns.doc)
            .select_from(self._tables.collection)
            .where(self._tables.collection.columns[self._collectionIdName] == key)
        )
        with self._db.query(sql) as sql_result:
            return sql_result.scalar()

    def setDocumentation(self, key: Any, doc: str | None) -> None:
        # Docstring inherited from CollectionManager.
        # ``where`` maps the ID column to the row-dict key "key" holding the
        # lookup value.
        self._db.update(self._tables.collection, {self._collectionIdName: "key"}, {"key": key, "doc": doc})

    def _setRecordCache(self, records: Iterable[CollectionRecord]) -> None:
        """Set internal record cache to contain given records,
        old cached records will be removed.
        """
        self._records = {record.key: record for record in records}

    def _addCachedRecord(self, record: CollectionRecord) -> None:
        """Add single record to cache."""
        self._records[record.key] = record

    def _removeCachedRecord(self, record: CollectionRecord) -> None:
        """Remove single record from cache.

        Raises `KeyError` if the record's key is not cached.
        """
        del self._records[record.key]

    @abstractmethod
    def _getByName(self, name: str) -> CollectionRecord | None:
        """Find collection record given collection name.

        Returns `None` when no cached record has that name.
        """
        raise NotImplementedError()