Coverage for python/lsst/daf/butler/registry/collections/_base.py: 91%
181 statements
« prev ^ index » next coverage.py v7.2.5, created at 2023-05-05 03:17 -0700
« prev ^ index » next coverage.py v7.2.5, created at 2023-05-05 03:17 -0700
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ()
25import itertools
26from abc import abstractmethod
27from collections import namedtuple
28from collections.abc import Iterable, Iterator, Set
29from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast
31import sqlalchemy
32from lsst.utils.ellipsis import Ellipsis
34from ...core import DimensionUniverse, Timespan, TimespanDatabaseRepresentation, ddl
35from .._collectionType import CollectionType
36from .._exceptions import MissingCollectionError
37from ..interfaces import ChainedCollectionRecord, CollectionManager, CollectionRecord, RunRecord, VersionTuple
38from ..wildcards import CollectionWildcard
40if TYPE_CHECKING:
41 from ..interfaces import Database, DimensionRecordStorageManager
def _makeCollectionForeignKey(
    sourceColumnName: str, collectionIdName: str, **kwargs: Any
) -> ddl.ForeignKeySpec:
    """Build a foreign key specification that points at the collections table.

    Parameters
    ----------
    sourceColumnName : `str`
        Name of the referring column in the table that owns the foreign key.
    collectionIdName : `str`
        Name of the primary-key column of the collections table.
    **kwargs
        Extra keyword arguments forwarded unchanged to `ddl.ForeignKeySpec`
        (for example ``onDelete``).

    Returns
    -------
    spec : `ddl.ForeignKeySpec`
        The foreign key specification.

    Notes
    -----
    The target table name is hard-coded as ``"collection"``, and the
    collections table is assumed to have a single-column primary key.
    """
    target_table = "collection"
    source_columns = (sourceColumnName,)
    target_columns = (collectionIdName,)
    return ddl.ForeignKeySpec(target_table, source=source_columns, target=target_columns, **kwargs)
# Named tuple grouping the three collection-related tables: the collections
# table itself, the "run" table and the "collection_chain" table.
CollectionTablesTuple = namedtuple("CollectionTablesTuple", "collection run collection_chain")
def makeRunTableSpec(
    collectionIdName: str, collectionIdType: type, TimespanReprClass: type[TimespanDatabaseRepresentation]
) -> ddl.TableSpec:
    """Build the table specification for the "run" table.

    Parameters
    ----------
    collectionIdName : `str`
        Name of the primary-key column of the collections table; the run
        table uses the same name for its own primary-key column.
    collectionIdType
        Type of that primary-key column, one of the `sqlalchemy` types.
    TimespanReprClass : `type` [ `TimespanDatabaseRepresentation` ]
        Subclass of `TimespanDatabaseRepresentation` describing how
        timespans are stored in this database.

    Returns
    -------
    spec : `ddl.TableSpec`
        Specification of the run table.

    Notes
    -----
    The identifying column shares its name with the collections table's
    primary key, and code elsewhere relies on this. All other column names
    are fixed: ``host`` plus whatever columns ``TimespanReprClass``
    contributes.
    """
    id_field = ddl.FieldSpec(collectionIdName, dtype=collectionIdType, primaryKey=True)
    host_field = ddl.FieldSpec("host", dtype=sqlalchemy.String, length=128)
    # ON DELETE CASCADE: removing the collection row removes its run row too.
    parent_fk = _makeCollectionForeignKey(collectionIdName, collectionIdName, onDelete="CASCADE")
    spec = ddl.TableSpec(fields=[id_field, host_field], foreignKeys=[parent_fk])
    # Append the (nullable) timespan columns after the fixed columns.
    for timespan_field in TimespanReprClass.makeFieldSpecs(nullable=True):
        spec.fields.add(timespan_field)
    return spec
def makeCollectionChainTableSpec(collectionIdName: str, collectionIdType: type) -> ddl.TableSpec:
    """Build the table specification for the "collection_chain" table.

    Parameters
    ----------
    collectionIdName : `str`
        Name of the primary-key column of the collections table.
    collectionIdType
        Type of that primary-key column, one of the `sqlalchemy` types.

    Returns
    -------
    spec : `ddl.TableSpec`
        Specification of the collection chain table.

    Notes
    -----
    A chain is an ordered one-to-many relation between collections. The
    column names (``parent``, ``position``, ``child``) are fixed and are
    also hard-coded in the code that reads and writes this table.
    """
    # (parent, position) together form the primary key, so each position
    # within a given chain holds exactly one child.
    parent_field = ddl.FieldSpec("parent", dtype=collectionIdType, primaryKey=True)
    position_field = ddl.FieldSpec("position", dtype=sqlalchemy.SmallInteger, primaryKey=True)
    child_field = ddl.FieldSpec("child", dtype=collectionIdType, nullable=False)
    # Deleting the parent collection cascades to its chain entries; no
    # ``onDelete`` action is requested for the child reference.
    parent_fk = _makeCollectionForeignKey("parent", collectionIdName, onDelete="CASCADE")
    child_fk = _makeCollectionForeignKey("child", collectionIdName)
    return ddl.TableSpec(
        fields=[parent_field, position_field, child_field],
        foreignKeys=[parent_fk, child_fk],
    )
class DefaultRunRecord(RunRecord):
    """Default `RunRecord` implementation.

    This class assumes the run table layout produced by `makeRunTableSpec`.
    The only schema name that is not fixed is the primary-key column name,
    which must be passed to the constructor.

    Parameters
    ----------
    db : `Database`
        Registry database.
    key
        Unique collection ID, can be the same as ``name`` if ``name`` is used
        for identification. Usually this is an integer or string, but can be
        other database-specific type.
    name : `str`
        Run collection name.
    table : `sqlalchemy.schema.Table`
        Table for run records.
    idColumnName : `str`
        Name of the identifying column in run table.
    host : `str`, optional
        Name of the host where run was produced.
    timespan : `Timespan`, optional
        Timespan for this run; an unbounded ``Timespan(begin=None, end=None)``
        is used if not given.
    """

    def __init__(
        self,
        db: Database,
        key: Any,
        name: str,
        *,
        table: sqlalchemy.schema.Table,
        idColumnName: str,
        host: str | None = None,
        timespan: Timespan | None = None,
    ):
        super().__init__(key=key, name=name, type=CollectionType.RUN)
        self._db = db
        self._table = table
        self._host = host
        if timespan is None:
            timespan = Timespan(begin=None, end=None)
        self._timespan = timespan
        self._idName = idColumnName

    def update(self, host: str | None = None, timespan: Timespan | None = None) -> None:
        # Docstring inherited from RunRecord.
        if timespan is None:
            timespan = Timespan(begin=None, end=None)
        # `Database.update` takes a mapping from WHERE column names to the
        # *names of the keys* in each row dict that hold the values to match
        # (the same convention `DefaultCollectionManager.setDocumentation`
        # uses). Passing the key *value* here, as was done before, produced a
        # bind parameter that no row supplies. The bind name must also differ
        # from the primary-key column name, because SQLAlchemy reserves column
        # names for the SET clause of an UPDATE.
        row = {
            "key": self.key,
            "host": host,
        }
        self._db.getTimespanRepresentation().update(timespan, result=row)
        count = self._db.update(self._table, {self._idName: "key"}, row)
        if count != 1:
            raise RuntimeError(f"Run update affected {count} records; expected exactly one.")
        # Only update the in-memory state once the database write succeeded.
        self._host = host
        self._timespan = timespan

    @property
    def host(self) -> str | None:
        # Docstring inherited from RunRecord.
        return self._host

    @property
    def timespan(self) -> Timespan:
        # Docstring inherited from RunRecord.
        return self._timespan
class DefaultChainedCollectionRecord(ChainedCollectionRecord):
    """Default `ChainedCollectionRecord` implementation.

    Relies on the chain table layout produced by
    `makeCollectionChainTableSpec`; its column names (``parent``,
    ``position``, ``child``) are hard-coded in the methods below.

    Parameters
    ----------
    db : `Database`
        Registry database.
    key
        Unique collection ID, can be the same as ``name`` if ``name`` is used
        for identification. Usually this is an integer or string, but can be
        other database-specific type.
    name : `str`
        Collection name.
    table : `sqlalchemy.schema.Table`
        Table for chain relationship records.
    universe : `DimensionUniverse`
        Object managing all known dimensions.
    """

    def __init__(
        self,
        db: Database,
        key: Any,
        name: str,
        *,
        table: sqlalchemy.schema.Table,
        universe: DimensionUniverse,
    ):
        super().__init__(key=key, name=name, universe=universe)
        self._db = db
        self._table = table
        self._universe = universe

    def _update(self, manager: CollectionManager, children: tuple[str, ...]) -> None:
        # Docstring inherited from ChainedCollectionRecord.
        # Resolve the child names without flattening nested chains, so the
        # chain table stores only the direct children, in the given order.
        child_records = manager.resolve_wildcard(
            CollectionWildcard.from_names(children), flatten_chains=False
        )
        new_rows = [
            {"parent": self.key, "child": child.key, "position": index}
            for index, child in zip(itertools.count(), child_records)
        ]
        # Replace the chain contents inside one transaction: drop every
        # existing entry for this parent, then insert the new ones.
        with self._db.transaction():
            self._db.delete(self._table, ["parent"], {"parent": self.key})
            self._db.insert(self._table, *new_rows)

    def _load(self, manager: CollectionManager) -> tuple[str, ...]:
        # Docstring inherited from ChainedCollectionRecord.
        columns = self._table.columns
        query = (
            sqlalchemy.sql.select(columns.child)
            .select_from(self._table)
            .where(columns.parent == self.key)
            .order_by(columns.position)
        )
        with self._db.query(query) as sql_result:
            # Translate each stored child key back into its collection name,
            # preserving the chain's position order.
            return tuple(manager[row[columns.child]].name for row in sql_result.mappings())
# Type of the collection primary key (usually an integer or a string).
K = TypeVar("K")


class DefaultCollectionManager(Generic[K], CollectionManager):
    """Default `CollectionManager` implementation.

    This implementation uses the record classes defined in this module and
    relies on the same schema assumptions as those record classes.

    Parameters
    ----------
    db : `Database`
        Interface to the underlying database engine and namespace.
    tables : `CollectionTablesTuple`
        Named tuple of SQLAlchemy table objects.
    collectionIdName : `str`
        Name of the column in collections table that identifies it (PK).
    dimensions : `DimensionRecordStorageManager`
        Manager object for the dimensions in this `Registry`.
    registry_schema_version : `VersionTuple`, optional
        Registry schema version, forwarded to the base class constructor.

    Notes
    -----
    All collection records are pre-fetched and cached in memory. `refresh`
    rebuilds the whole cache from the database; `register` and `remove`
    update it one record at a time.
    """

    def __init__(
        self,
        db: Database,
        tables: CollectionTablesTuple,
        collectionIdName: str,
        *,
        dimensions: DimensionRecordStorageManager,
        registry_schema_version: VersionTuple | None = None,
    ):
        super().__init__(registry_schema_version=registry_schema_version)
        self._db = db
        self._tables = tables
        self._collectionIdName = collectionIdName
        # In-memory record cache, keyed by collection ID (not by name).
        self._records: dict[K, CollectionRecord] = {}
        self._dimensions = dimensions

    def refresh(self) -> None:
        # Docstring inherited from CollectionManager.
        collection_table = self._tables.collection
        run_table = self._tables.run
        # Outer join: collections that are not RUNs have no run-table row but
        # must still be returned.
        query = sqlalchemy.sql.select(
            *collection_table.columns, *run_table.columns
        ).select_from(collection_table.join(run_table, isouter=True))
        # Collect the new records in local lists rather than mutating
        # self._records in place, so an exception leaves the cache untouched.
        new_records = []
        chain_records = []
        timespan_repr = self._db.getTimespanRepresentation()
        with self._db.query(query) as sql_result:
            fetched_rows = sql_result.mappings().fetchall()
        collection_columns = collection_table.columns
        for row in fetched_rows:
            # Both joined tables carry the ID column, so look it up through
            # the collection table's column object to disambiguate.
            collection_id = row[collection_columns[self._collectionIdName]]
            name = row[collection_columns.name]
            collection_type = CollectionType(row["type"])
            record: CollectionRecord
            if collection_type is CollectionType.RUN:
                record = DefaultRunRecord(
                    db=self._db,
                    key=collection_id,
                    name=name,
                    table=run_table,
                    idColumnName=self._collectionIdName,
                    host=row[run_table.columns.host],
                    timespan=timespan_repr.extract(row),
                )
            elif collection_type is CollectionType.CHAINED:
                chain = DefaultChainedCollectionRecord(
                    db=self._db,
                    key=collection_id,
                    name=name,
                    table=self._tables.collection_chain,
                    universe=self._dimensions.universe,
                )
                chain_records.append(chain)
                record = chain
            else:
                record = CollectionRecord(key=collection_id, name=name, type=collection_type)
            new_records.append(record)
        self._setRecordCache(new_records)
        # Chains can only be resolved once every collection is in the cache.
        for chain in chain_records:
            try:
                chain.refresh(self)
            except MissingCollectionError:
                # Race condition: another client created a new collection and
                # added it as a child of this (pre-existing) chain between
                # fetching all collections and querying parent-child
                # relationships. That chain belongs to an unrelated client, so
                # drop it from the cache; a manual refresh can bring it back.
                self._removeCachedRecord(chain)

    def register(
        self, name: str, type: CollectionType, doc: str | None = None
    ) -> tuple[CollectionRecord, bool]:
        # Docstring inherited from CollectionManager.
        cached = self._getByName(name)
        if cached is not None:
            # Already known locally: nothing is written to the database.
            return cached, False
        row, inserted_or_updated = self._db.sync(
            self._tables.collection,
            keys={"name": name},
            compared={"type": int(type)},
            extra={"doc": doc},
            returning=[self._collectionIdName],
        )
        assert isinstance(inserted_or_updated, bool)
        assert row is not None
        collection_id = row[self._collectionIdName]
        record: CollectionRecord
        if type is CollectionType.RUN:
            timespan_repr = self._db.getTimespanRepresentation()
            # Make sure the run-table row exists and read back its metadata.
            run_row, _ = self._db.sync(
                self._tables.run,
                keys={self._collectionIdName: collection_id},
                returning=("host",) + timespan_repr.getFieldNames(),
            )
            assert run_row is not None
            record = DefaultRunRecord(
                db=self._db,
                key=collection_id,
                name=name,
                table=self._tables.run,
                idColumnName=self._collectionIdName,
                host=run_row["host"],
                timespan=timespan_repr.extract(run_row),
            )
        elif type is CollectionType.CHAINED:
            record = DefaultChainedCollectionRecord(
                db=self._db,
                key=collection_id,
                name=name,
                table=self._tables.collection_chain,
                universe=self._dimensions.universe,
            )
        else:
            record = CollectionRecord(key=collection_id, name=name, type=type)
        self._addCachedRecord(record)
        return record, inserted_or_updated

    def remove(self, name: str) -> None:
        # Docstring inherited from CollectionManager.
        record = self._getByName(name)
        if record is None:
            raise MissingCollectionError(f"No collection with name '{name}' found.")
        # The delete may raise; the cache is only updated after it succeeds.
        id_name = self._collectionIdName
        self._db.delete(self._tables.collection, [id_name], {id_name: record.key})
        self._removeCachedRecord(record)

    def find(self, name: str) -> CollectionRecord:
        # Docstring inherited from CollectionManager.
        if (record := self._getByName(name)) is None:
            raise MissingCollectionError(f"No collection with name '{name}' found.")
        return record

    def __getitem__(self, key: Any) -> CollectionRecord:
        # Docstring inherited from CollectionManager.
        try:
            record = self._records[key]
        except KeyError as err:
            raise MissingCollectionError(f"Collection with key '{key}' not found.") from err
        return record

    def resolve_wildcard(
        self,
        wildcard: CollectionWildcard,
        *,
        collection_types: Set[CollectionType] = CollectionType.all(),
        done: set[str] | None = None,
        flatten_chains: bool = True,
        include_chains: bool | None = None,
    ) -> list[CollectionRecord]:
        # Docstring inherited
        if done is None:
            done = set()
        if include_chains is None:
            include_chains = not flatten_chains

        def expand(record: CollectionRecord, done: set[str]) -> Iterator[CollectionRecord]:
            # Yield ``record`` (if it passes the type filters) and, when
            # flattening, the records of its children, skipping names that
            # have already been visited.
            if record.name in done:
                return
            is_chain = record.type is CollectionType.CHAINED
            if record.type in collection_types:
                done.add(record.name)
                if include_chains or not is_chain:
                    yield record
            if flatten_chains and is_chain:
                done.add(record.name)
                for child_name in cast(ChainedCollectionRecord, record).children:
                    # flake8 can't tell that we only delete this closure when
                    # we're totally done with it.
                    yield from expand(self.find(child_name), done)  # noqa: F821

        matches: list[CollectionRecord] = []
        if wildcard.patterns is Ellipsis:
            # Match-everything wildcard: explicit strings are not consulted.
            for record in self._records.values():
                matches.extend(expand(record, done))
        else:
            for name in wildcard.strings:
                matches.extend(expand(self.find(name), done))
            if wildcard.patterns:
                for record in self._records.values():
                    if any(pattern.fullmatch(record.name) for pattern in wildcard.patterns):
                        matches.extend(expand(record, done))
        # Drop the self-referencing closure explicitly once we are done.
        del expand
        return matches

    def getDocumentation(self, key: Any) -> str | None:
        # Docstring inherited from CollectionManager.
        collection_table = self._tables.collection
        query = (
            sqlalchemy.sql.select(collection_table.columns.doc)
            .select_from(collection_table)
            .where(collection_table.columns[self._collectionIdName] == key)
        )
        with self._db.query(query) as sql_result:
            return sql_result.scalar()

    def setDocumentation(self, key: Any, doc: str | None) -> None:
        # Docstring inherited from CollectionManager.
        # The WHERE mapping points the ID column at the row entry named "key".
        where = {self._collectionIdName: "key"}
        self._db.update(self._tables.collection, where, {"key": key, "doc": doc})

    def _setRecordCache(self, records: Iterable[CollectionRecord]) -> None:
        """Replace the whole record cache with the given records.

        Any previously cached records are discarded.
        """
        self._records = {record.key: record for record in records}

    def _addCachedRecord(self, record: CollectionRecord) -> None:
        """Add (or overwrite) a single record in the cache."""
        self._records[record.key] = record

    def _removeCachedRecord(self, record: CollectionRecord) -> None:
        """Remove a single record from the cache (raises `KeyError` if absent)."""
        del self._records[record.key]

    @abstractmethod
    def _getByName(self, name: str) -> CollectionRecord | None:
        """Find collection record given collection name."""
        raise NotImplementedError()