Coverage for python/lsst/daf/butler/registry/collections/_base.py: 92%
180 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-10-02 07:59 +0000
« prev ^ index » next coverage.py v7.3.1, created at 2023-10-02 07:59 +0000
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
27from __future__ import annotations
29__all__ = ()
31import itertools
32from abc import abstractmethod
33from collections import namedtuple
34from collections.abc import Iterable, Iterator, Set
35from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast
37import sqlalchemy
39from ...core import DimensionUniverse, Timespan, TimespanDatabaseRepresentation, ddl
40from .._collectionType import CollectionType
41from .._exceptions import MissingCollectionError
42from ..interfaces import ChainedCollectionRecord, CollectionManager, CollectionRecord, RunRecord, VersionTuple
43from ..wildcards import CollectionWildcard
45if TYPE_CHECKING:
46 from ..interfaces import Database, DimensionRecordStorageManager
def _makeCollectionForeignKey(
    sourceColumnName: str, collectionIdName: str, **kwargs: Any
) -> ddl.ForeignKeySpec:
    """Build a foreign key specification that points at the collections
    table.

    Parameters
    ----------
    sourceColumnName : `str`
        Name of the column in the referring table.
    collectionIdName : `str`
        Name of the primary key column of the collections table.
    **kwargs
        Extra keyword arguments forwarded unchanged to `ddl.ForeignKeySpec`.

    Returns
    -------
    spec : `ddl.ForeignKeySpec`
        Foreign key specification.

    Notes
    -----
    The name of the collections table is hard-coded as ``"collection"``, and
    its primary key is assumed to consist of a single column.
    """
    # Both sides of the constraint are single-column tuples.
    source_columns = (sourceColumnName,)
    target_columns = (collectionIdName,)
    return ddl.ForeignKeySpec("collection", source=source_columns, target=target_columns, **kwargs)
# Bundle of the three tables used by the collection managers in this module:
# the collections table itself, the run table and the collection-chain table.
CollectionTablesTuple = namedtuple(
    "CollectionTablesTuple",
    ("collection", "run", "collection_chain"),
)
def makeRunTableSpec(
    collectionIdName: str, collectionIdType: type, TimespanReprClass: type[TimespanDatabaseRepresentation]
) -> ddl.TableSpec:
    """Build the specification of the "run" table.

    Parameters
    ----------
    collectionIdName : `str`
        Name of the primary key column of the collections table.
    collectionIdType
        Type of the primary key column of the collections table, one of the
        `sqlalchemy` types.
    TimespanReprClass : `type` [ `TimespanDatabaseRepresentation` ]
        Subclass of `TimespanDatabaseRepresentation` that encapsulates how
        timespans are stored in this database.

    Returns
    -------
    spec : `ddl.TableSpec`
        Specification for run table.

    Notes
    -----
    The identifying column has the same name in the run table as in the
    collections table. The names of the remaining (run metadata) columns
    are fixed.
    """
    id_field = ddl.FieldSpec(collectionIdName, dtype=collectionIdType, primaryKey=True)
    host_field = ddl.FieldSpec("host", dtype=sqlalchemy.String, length=128)
    # Deleting a collection also deletes its row in the run table.
    collection_fk = _makeCollectionForeignKey(collectionIdName, collectionIdName, onDelete="CASCADE")
    spec = ddl.TableSpec(fields=[id_field, host_field], foreignKeys=[collection_fk])
    # The timespan columns depend on the database-specific representation.
    for timespan_field in TimespanReprClass.makeFieldSpecs(nullable=True):
        spec.fields.add(timespan_field)
    return spec
def makeCollectionChainTableSpec(collectionIdName: str, collectionIdType: type) -> ddl.TableSpec:
    """Build the specification of the "collection_chain" table.

    Parameters
    ----------
    collectionIdName : `str`
        Name of the primary key column of the collections table.
    collectionIdType
        Type of the primary key column of the collections table, one of the
        `sqlalchemy` types.

    Returns
    -------
    spec : `ddl.TableSpec`
        Specification for collection chain table.

    Notes
    -----
    A collection chain is an ordered one-to-many relation between
    collections. The column names (``parent``, ``position``, ``child``) are
    fixed here and hard-coded in the code that uses this table.
    """
    # (parent, position) is the primary key.
    chain_fields = [
        ddl.FieldSpec("parent", dtype=collectionIdType, primaryKey=True),
        ddl.FieldSpec("position", dtype=sqlalchemy.SmallInteger, primaryKey=True),
        ddl.FieldSpec("child", dtype=collectionIdType, nullable=False),
    ]
    # Only the parent reference cascades on delete; the child reference has
    # no ``onDelete`` action.
    chain_foreign_keys = [
        _makeCollectionForeignKey("parent", collectionIdName, onDelete="CASCADE"),
        _makeCollectionForeignKey("child", collectionIdName),
    ]
    return ddl.TableSpec(fields=chain_fields, foreignKeys=chain_foreign_keys)
class DefaultRunRecord(RunRecord):
    """Default `RunRecord` implementation.

    This class assumes the same run table definition as produced by
    `makeRunTableSpec`. The only non-fixed name in the schema is the PK
    column name, which must be passed to the constructor.

    Parameters
    ----------
    db : `Database`
        Registry database.
    key
        Unique collection ID, can be the same as ``name`` if ``name`` is used
        for identification. Usually this is an integer or string, but can be
        other database-specific type.
    name : `str`
        Run collection name.
    table : `sqlalchemy.schema.Table`
        Table for run records.
    idColumnName : `str`
        Name of the identifying column in run table.
    host : `str`, optional
        Name of the host where run was produced.
    timespan : `Timespan`, optional
        Timespan for this run. If `None`, a timespan with both bounds set
        to `None` is used.
    """

    def __init__(
        self,
        db: Database,
        key: Any,
        name: str,
        *,
        table: sqlalchemy.schema.Table,
        idColumnName: str,
        host: str | None = None,
        timespan: Timespan | None = None,
    ):
        super().__init__(key=key, name=name, type=CollectionType.RUN)
        self._db = db
        self._table = table
        self._host = host
        if timespan is None:
            timespan = Timespan(begin=None, end=None)
        self._timespan = timespan
        self._idName = idColumnName

    def update(self, host: str | None = None, timespan: Timespan | None = None) -> None:
        # Docstring inherited from RunRecord.
        if timespan is None:
            timespan = Timespan(begin=None, end=None)
        # The ``where`` argument of `Database.update` maps a column name to
        # the *key in the row dict* that holds the value to match, not to the
        # value itself (same convention as
        # `DefaultCollectionManager.setDocumentation`). The match value is
        # therefore stored under a separate "key" entry, which also keeps it
        # apart from the columns being assigned.
        row: dict[str, Any] = {
            "key": self.key,
            "host": host,
        }
        self._db.getTimespanRepresentation().update(timespan, result=row)
        count = self._db.update(self._table, {self._idName: "key"}, row)
        if count != 1:
            raise RuntimeError(f"Run update affected {count} records; expected exactly one.")
        # Only update in-memory state after the database write succeeded.
        self._host = host
        self._timespan = timespan

    @property
    def host(self) -> str | None:
        # Docstring inherited from RunRecord.
        return self._host

    @property
    def timespan(self) -> Timespan:
        # Docstring inherited from RunRecord.
        return self._timespan
class DefaultChainedCollectionRecord(ChainedCollectionRecord):
    """Default `ChainedCollectionRecord` implementation.

    This class relies on the chain table layout produced by
    `makeCollectionChainTableSpec`; the column names ``parent``, ``child``
    and ``position`` are hard-coded in its methods.

    Parameters
    ----------
    db : `Database`
        Registry database.
    key
        Unique collection ID, can be the same as ``name`` if ``name`` is used
        for identification. Usually this is an integer or string, but can be
        other database-specific type.
    name : `str`
        Collection name.
    table : `sqlalchemy.schema.Table`
        Table for chain relationship records.
    universe : `DimensionUniverse`
        Object managing all known dimensions.
    """

    def __init__(
        self,
        db: Database,
        key: Any,
        name: str,
        *,
        table: sqlalchemy.schema.Table,
        universe: DimensionUniverse,
    ):
        super().__init__(key=key, name=name, universe=universe)
        self._db = db
        self._table = table
        self._universe = universe

    def _update(self, manager: CollectionManager, children: tuple[str, ...]) -> None:
        # Docstring inherited from ChainedCollectionRecord.
        wildcard = CollectionWildcard.from_names(children)
        child_records = manager.resolve_wildcard(wildcard, flatten_chains=False)
        # One row per child; ``position`` records the order in the chain.
        new_rows = [
            {
                "parent": self.key,
                "child": child_record.key,
                "position": index,
            }
            for index, child_record in enumerate(child_records)
        ]
        # Replace the whole chain definition within one transaction.
        with self._db.transaction():
            self._db.delete(self._table, ["parent"], {"parent": self.key})
            self._db.insert(self._table, *new_rows)

    def _load(self, manager: CollectionManager) -> tuple[str, ...]:
        # Docstring inherited from ChainedCollectionRecord.
        chain_columns = self._table.columns
        child_column = chain_columns.child
        query = sqlalchemy.sql.select(child_column).select_from(self._table)
        query = query.where(chain_columns.parent == self.key)
        query = query.order_by(chain_columns.position)
        with self._db.query(query) as sql_result:
            # Translate child keys to names via the manager.
            child_names = [manager[row[child_column]].name for row in sql_result.mappings()]
            return tuple(child_names)
# Type variable for the collection ID (primary key) type; it is the key type
# of the in-memory record cache kept by `DefaultCollectionManager`.
K = TypeVar("K")
class DefaultCollectionManager(Generic[K], CollectionManager):
    """Default `CollectionManager` implementation.

    This implementation is built on the record classes defined in this
    module and shares their assumptions about the database schema.

    Parameters
    ----------
    db : `Database`
        Interface to the underlying database engine and namespace.
    tables : `CollectionTablesTuple`
        Named tuple of SQLAlchemy table objects.
    collectionIdName : `str`
        Name of the column in collections table that identifies it (PK).
    dimensions : `DimensionRecordStorageManager`
        Manager object for the dimensions in this `Registry`.
    registry_schema_version : `VersionTuple`, optional
        Passed through to the `CollectionManager` constructor.

    Notes
    -----
    Records are pre-fetched and cached in memory. The cache is rebuilt from
    the database whenever `refresh` is called.
    """

    def __init__(
        self,
        db: Database,
        tables: CollectionTablesTuple,
        collectionIdName: str,
        *,
        dimensions: DimensionRecordStorageManager,
        registry_schema_version: VersionTuple | None = None,
    ):
        super().__init__(registry_schema_version=registry_schema_version)
        self._db = db
        self._tables = tables
        self._collectionIdName = collectionIdName
        self._records: dict[K, CollectionRecord] = {}  # indexed by record ID
        self._dimensions = dimensions

    def refresh(self) -> None:
        # Docstring inherited from CollectionManager.
        collection_table = self._tables.collection
        run_table = self._tables.run
        # Outer join so that non-RUN collections are fetched as well.
        sql = sqlalchemy.sql.select(*collection_table.columns, *run_table.columns).select_from(
            collection_table.join(run_table, isouter=True)
        )
        # Collect into temporaries rather than mutating the cache in place,
        # for exception safety.
        fetched: list[CollectionRecord] = []
        chained: list[DefaultChainedCollectionRecord] = []
        TimespanReprClass = self._db.getTimespanRepresentation()
        with self._db.query(sql) as sql_result:
            sql_rows = sql_result.mappings().fetchall()
        for row in sql_rows:
            collection_id = row[collection_table.columns[self._collectionIdName]]
            name = row[collection_table.columns.name]
            collection_type = CollectionType(row["type"])
            record: CollectionRecord
            if collection_type is CollectionType.RUN:
                record = DefaultRunRecord(
                    key=collection_id,
                    name=name,
                    db=self._db,
                    table=run_table,
                    idColumnName=self._collectionIdName,
                    host=row[run_table.columns.host],
                    timespan=TimespanReprClass.extract(row),
                )
            elif collection_type is CollectionType.CHAINED:
                chain_record = DefaultChainedCollectionRecord(
                    db=self._db,
                    key=collection_id,
                    table=self._tables.collection_chain,
                    name=name,
                    universe=self._dimensions.universe,
                )
                chained.append(chain_record)
                record = chain_record
            else:
                record = CollectionRecord(key=collection_id, name=name, type=collection_type)
            fetched.append(record)
        self._setRecordCache(fetched)
        # Chains are refreshed only after the cache has been populated.
        for chain in chained:
            try:
                chain.refresh(self)
            except MissingCollectionError:
                # Race condition: another client created a new collection
                # and added it as a child of this (pre-existing) chain after
                # we fetched all collections but before we queried the
                # parent-child relationships. That unrelated client's parent
                # collection is not something we should care about, so it is
                # dropped from the cache (a manual refresh brings it back).
                self._removeCachedRecord(chain)

    def register(
        self, name: str, type: CollectionType, doc: str | None = None
    ) -> tuple[CollectionRecord, bool]:
        # Docstring inherited from CollectionManager.
        cached = self._getByName(name)
        if cached is not None:
            # Already known; nothing is written to the database.
            return cached, False
        row, inserted_or_updated = self._db.sync(
            self._tables.collection,
            keys={"name": name},
            compared={"type": int(type)},
            extra={"doc": doc},
            returning=[self._collectionIdName],
        )
        assert isinstance(inserted_or_updated, bool)
        assert row is not None
        collection_id = row[self._collectionIdName]
        record: CollectionRecord
        if type is CollectionType.RUN:
            TimespanReprClass = self._db.getTimespanRepresentation()
            run_row, _ = self._db.sync(
                self._tables.run,
                keys={self._collectionIdName: collection_id},
                returning=("host",) + TimespanReprClass.getFieldNames(),
            )
            assert run_row is not None
            record = DefaultRunRecord(
                db=self._db,
                key=collection_id,
                name=name,
                table=self._tables.run,
                idColumnName=self._collectionIdName,
                host=run_row["host"],
                timespan=TimespanReprClass.extract(run_row),
            )
        elif type is CollectionType.CHAINED:
            record = DefaultChainedCollectionRecord(
                db=self._db,
                key=collection_id,
                name=name,
                table=self._tables.collection_chain,
                universe=self._dimensions.universe,
            )
        else:
            record = CollectionRecord(key=collection_id, name=name, type=type)
        self._addCachedRecord(record)
        return record, inserted_or_updated

    def remove(self, name: str) -> None:
        # Docstring inherited from CollectionManager.
        target = self._getByName(name)
        if target is None:
            raise MissingCollectionError(f"No collection with name '{name}' found.")
        id_name = self._collectionIdName
        # This may raise
        self._db.delete(self._tables.collection, [id_name], {id_name: target.key})
        self._removeCachedRecord(target)

    def find(self, name: str) -> CollectionRecord:
        # Docstring inherited from CollectionManager.
        found = self._getByName(name)
        if found is not None:
            return found
        raise MissingCollectionError(f"No collection with name '{name}' found.")

    def __getitem__(self, key: Any) -> CollectionRecord:
        # Docstring inherited from CollectionManager.
        try:
            record = self._records[key]
        except KeyError as err:
            raise MissingCollectionError(f"Collection with key '{key}' not found.") from err
        return record

    def resolve_wildcard(
        self,
        wildcard: CollectionWildcard,
        *,
        collection_types: Set[CollectionType] = CollectionType.all(),
        done: set[str] | None = None,
        flatten_chains: bool = True,
        include_chains: bool | None = None,
    ) -> list[CollectionRecord]:
        # Docstring inherited
        if done is None:
            done = set()
        if include_chains is None:
            # By default chains are reported only when they are not expanded.
            include_chains = not flatten_chains

        def expand(record: CollectionRecord, done: set[str]) -> Iterator[CollectionRecord]:
            # Yield ``record`` and, when flattening, its children
            # (recursively), skipping names already in ``done``.
            if record.name in done:
                return
            is_chain = record.type is CollectionType.CHAINED
            if record.type in collection_types:
                done.add(record.name)
                if include_chains or not is_chain:
                    yield record
            if flatten_chains and is_chain:
                done.add(record.name)
                for child_name in cast(ChainedCollectionRecord, record).children:
                    # flake8 can't tell that we only delete this closure when
                    # we're totally done with it.
                    yield from expand(self.find(child_name), done)  # noqa: F821

        matched: list[CollectionRecord] = []

        if wildcard.patterns is ...:
            # Ellipsis matches every known collection.
            for record in self._records.values():
                matched.extend(expand(record, done))
        else:
            # Explicit names first, then regular-expression matches.
            for name in wildcard.strings:
                matched.extend(expand(self.find(name), done))
            if wildcard.patterns:
                for record in self._records.values():
                    if any(p.fullmatch(record.name) for p in wildcard.patterns):
                        matched.extend(expand(record, done))
        del expand
        return matched

    def getDocumentation(self, key: Any) -> str | None:
        # Docstring inherited from CollectionManager.
        collection_table = self._tables.collection
        id_column = collection_table.columns[self._collectionIdName]
        sql = sqlalchemy.sql.select(collection_table.columns.doc).select_from(collection_table)
        sql = sql.where(id_column == key)
        with self._db.query(sql) as sql_result:
            return sql_result.scalar()

    def setDocumentation(self, key: Any, doc: str | None) -> None:
        # Docstring inherited from CollectionManager.
        # The ID column is matched against the value stored under "key" in
        # the row dict.
        where = {self._collectionIdName: "key"}
        row = {"key": key, "doc": doc}
        self._db.update(self._tables.collection, where, row)

    def _setRecordCache(self, records: Iterable[CollectionRecord]) -> None:
        """Replace the internal record cache with the given records;
        previously cached records are discarded.
        """
        self._records = {}
        self._records.update((record.key, record) for record in records)

    def _addCachedRecord(self, record: CollectionRecord) -> None:
        """Add single record to cache."""
        self._records[record.key] = record

    def _removeCachedRecord(self, record: CollectionRecord) -> None:
        """Remove single record from cache."""
        self._records.pop(record.key)

    @abstractmethod
    def _getByName(self, name: str) -> CollectionRecord | None:
        """Find collection record given collection name."""
        raise NotImplementedError()