Coverage for python/lsst/daf/butler/registry/collections/_base.py: 88%
152 statements
« prev ^ index » next coverage.py v6.4.4, created at 2022-09-27 01:59 -0700
« prev ^ index » next coverage.py v6.4.4, created at 2022-09-27 01:59 -0700
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ()
25import itertools
26from abc import abstractmethod
27from collections import namedtuple
28from collections.abc import Iterable, Iterator
29from typing import TYPE_CHECKING, Any, Generic, TypeVar
31import sqlalchemy
33from ...core import DimensionUniverse, Timespan, TimespanDatabaseRepresentation, ddl
34from .._collectionType import CollectionType
35from .._exceptions import MissingCollectionError
36from ..interfaces import ChainedCollectionRecord, CollectionManager, CollectionRecord, RunRecord
37from ..wildcards import CollectionSearch
39if TYPE_CHECKING: 39 ↛ 40line 39 didn't jump to line 40, because the condition on line 39 was never true
40 from ..interfaces import Database, DimensionRecordStorageManager
def _makeCollectionForeignKey(
    sourceColumnName: str, collectionIdName: str, **kwargs: Any
) -> ddl.ForeignKeySpec:
    """Build a foreign-key specification that points at the collections table.

    Parameters
    ----------
    sourceColumnName : `str`
        Column in the referring table that holds the collection ID.
    collectionIdName : `str`
        Primary-key column of the collections table being referenced.
    **kwargs
        Extra keyword arguments forwarded unchanged to `ddl.ForeignKeySpec`
        (for example ``onDelete``).

    Returns
    -------
    spec : `ddl.ForeignKeySpec`
        The foreign key specification.

    Notes
    -----
    The referenced table name is hard-coded as ``"collection"``, and the
    collections primary key is assumed to be a single column.
    """
    sourceColumns = (sourceColumnName,)
    targetColumns = (collectionIdName,)
    return ddl.ForeignKeySpec("collection", source=sourceColumns, target=targetColumns, **kwargs)
# Named container for the three SQLAlchemy tables used by the collection
# manager: ``collection``, ``run`` and ``collection_chain``.
CollectionTablesTuple = namedtuple("CollectionTablesTuple", "collection run collection_chain")
def makeRunTableSpec(
    collectionIdName: str, collectionIdType: type, TimespanReprClass: type[TimespanDatabaseRepresentation]
) -> ddl.TableSpec:
    """Return the table specification for the "run" table.

    Parameters
    ----------
    collectionIdName : `str`
        Name of the primary-key column of the collections table. The run
        table uses a column with the same name as its own primary key.
    collectionIdType
        Type of that primary-key column, one of the `sqlalchemy` types.
    TimespanReprClass : `type` [ `TimespanDatabaseRepresentation` ]
        Subclass of `TimespanDatabaseRepresentation` that describes how
        timespans are stored in this database; it supplies the timespan
        columns.

    Returns
    -------
    spec : `ddl.TableSpec`
        Specification for the run table.

    Notes
    -----
    The identifying column shares its name with the collections table. All
    other column names are fixed: ``host`` plus whatever columns
    ``TimespanReprClass`` defines.
    """
    primaryKeyField = ddl.FieldSpec(collectionIdName, dtype=collectionIdType, primaryKey=True)
    hostField = ddl.FieldSpec("host", dtype=sqlalchemy.String, length=128)
    # ON DELETE CASCADE: removing a collection row also removes its run row.
    collectionLink = _makeCollectionForeignKey(collectionIdName, collectionIdName, onDelete="CASCADE")
    spec = ddl.TableSpec(fields=[primaryKeyField, hostField], foreignKeys=[collectionLink])
    # Timespan columns are requested as nullable.
    timespanFields = TimespanReprClass.makeFieldSpecs(nullable=True)
    for timespanField in timespanFields:
        spec.fields.add(timespanField)
    return spec
def makeCollectionChainTableSpec(collectionIdName: str, collectionIdType: type) -> ddl.TableSpec:
    """Return the table specification for the "collection_chain" table.

    Parameters
    ----------
    collectionIdName : `str`
        Name of the primary-key column of the collections table.
    collectionIdType
        Type of that primary-key column, one of the `sqlalchemy` types.

    Returns
    -------
    spec : `ddl.TableSpec`
        Specification for the collection chain table.

    Notes
    -----
    A chain is an ordered one-to-many relation between collections: each row
    links a ``parent`` collection to a ``child`` at a given ``position``.
    These column names are fixed and are relied on by the record classes
    in this module.
    """
    # (parent, position) together form the primary key.
    parentField = ddl.FieldSpec("parent", dtype=collectionIdType, primaryKey=True)
    positionField = ddl.FieldSpec("position", dtype=sqlalchemy.SmallInteger, primaryKey=True)
    childField = ddl.FieldSpec("child", dtype=collectionIdType, nullable=False)
    # Deleting a parent collection cascades to its chain rows. The child
    # link is created without an onDelete argument.
    parentLink = _makeCollectionForeignKey("parent", collectionIdName, onDelete="CASCADE")
    childLink = _makeCollectionForeignKey("child", collectionIdName)
    return ddl.TableSpec(
        fields=[parentField, positionField, childField],
        foreignKeys=[parentLink, childLink],
    )
class DefaultRunRecord(RunRecord):
    """Default `RunRecord` implementation.

    This class assumes the run table layout produced by `makeRunTableSpec`.
    The only schema name that is not fixed is the primary-key column name,
    which must be passed to the constructor.

    Parameters
    ----------
    db : `Database`
        Registry database.
    key
        Unique collection ID, can be the same as ``name`` if ``name`` is used
        for identification. Usually this is an integer or string, but can be
        other database-specific type.
    name : `str`
        Run collection name.
    table : `sqlalchemy.schema.Table`
        Table for run records.
    idColumnName : `str`
        Name of the identifying column in run table.
    host : `str`, optional
        Name of the host where run was produced.
    timespan : `Timespan`, optional
        Timespan for this run. If `None`, an unbounded
        ``Timespan(begin=None, end=None)`` is stored instead.
    """

    def __init__(
        self,
        db: Database,
        key: Any,
        name: str,
        *,
        table: sqlalchemy.schema.Table,
        idColumnName: str,
        host: str | None = None,
        timespan: Timespan | None = None,
    ):
        super().__init__(key=key, name=name, type=CollectionType.RUN)
        self._db = db
        self._table = table
        self._host = host
        # Normalize "no timespan" so the `timespan` property always returns
        # a Timespan instance.
        if timespan is None:
            timespan = Timespan(begin=None, end=None)
        self._timespan = timespan
        self._idName = idColumnName

    def update(self, host: str | None = None, timespan: Timespan | None = None) -> None:
        # Docstring inherited from RunRecord.
        if timespan is None:
            timespan = Timespan(begin=None, end=None)
        # ``Database.update`` takes a ``where`` mapping from column name to
        # the key in each row dict that holds the value to match. This is the
        # same convention ``DefaultCollectionManager.setDocumentation`` uses.
        # The ID goes under a separate "key" entry instead of the column name,
        # because SQLAlchemy rejects a bind parameter that shares its name
        # with a column being SET.
        row = {
            "key": self.key,
            "host": host,
        }
        self._db.getTimespanRepresentation().update(timespan, result=row)
        count = self._db.update(self._table, {self._idName: "key"}, row)
        if count != 1:
            raise RuntimeError(f"Run update affected {count} records; expected exactly one.")
        # Update in-memory state only after the database write succeeded.
        self._host = host
        self._timespan = timespan

    @property
    def host(self) -> str | None:
        # Docstring inherited from RunRecord.
        return self._host

    @property
    def timespan(self) -> Timespan:
        # Docstring inherited from RunRecord.
        return self._timespan
class DefaultChainedCollectionRecord(ChainedCollectionRecord):
    """Default `ChainedCollectionRecord` implementation.

    Assumes the chain table layout produced by `makeCollectionChainTableSpec`;
    the ``parent``, ``child`` and ``position`` column names are hard-coded in
    the methods below.

    Parameters
    ----------
    db : `Database`
        Registry database.
    key
        Unique collection ID, can be the same as ``name`` if ``name`` is used
        for identification. Usually this is an integer or string, but can be
        other database-specific type.
    name : `str`
        Collection name.
    table : `sqlalchemy.schema.Table`
        Table for chain relationship records.
    universe : `DimensionUniverse`
        Object managing all known dimensions.
    """

    def __init__(
        self,
        db: Database,
        key: Any,
        name: str,
        *,
        table: sqlalchemy.schema.Table,
        universe: DimensionUniverse,
    ):
        super().__init__(key=key, name=name, universe=universe)
        self._db = db
        self._table = table
        self._universe = universe

    def _update(self, manager: CollectionManager, children: CollectionSearch) -> None:
        # Docstring inherited from ChainedCollectionRecord.
        # Children are numbered 0, 1, 2, ... in iteration order; nested
        # chains are kept as single entries (flattenChains=False).
        orderedChildren = children.iter(manager, flattenChains=False)
        rows = [
            {"parent": self.key, "child": child.key, "position": index}
            for index, child in enumerate(orderedChildren)
        ]
        # Replace the existing links for this parent within one transaction.
        with self._db.transaction():
            self._db.delete(self._table, ["parent"], {"parent": self.key})
            self._db.insert(self._table, *rows)

    def _load(self, manager: CollectionManager) -> CollectionSearch:
        # Docstring inherited from ChainedCollectionRecord.
        columns = self._table.columns
        query = (
            sqlalchemy.sql.select(columns.child)
            .select_from(self._table)
            .where(columns.parent == self.key)
            .order_by(columns.position)
        )
        childNames = []
        for row in self._db.query(query):
            # Child keys are resolved to records (and then names) via the manager.
            childKey = row._mapping[columns.child]
            childNames.append(manager[childKey].name)
        return CollectionSearch.fromExpression(childNames)
# Type of the collection primary key used to index the record cache (an
# integer ID, or the collection name itself when names serve as keys).
K = TypeVar("K")


class DefaultCollectionManager(Generic[K], CollectionManager):
    """Default `CollectionManager` implementation.

    Builds on the record classes defined in this module and shares their
    assumptions about the table schema.

    Parameters
    ----------
    db : `Database`
        Interface to the underlying database engine and namespace.
    tables : `CollectionTablesTuple`
        Named tuple of SQLAlchemy table objects.
    collectionIdName : `str`
        Name of the column in collections table that identifies it (PK).
    dimensions : `DimensionRecordStorageManager`
        Manager object for the dimensions in this `Registry`.

    Notes
    -----
    All collection records are pre-fetched and cached in memory
    ("aggressive" caching). The cache is re-read from the database only when
    `refresh` is called; `register` and `remove` keep it in step with
    changes made through this manager. Subclasses must implement
    `_getByName`.
    """

    def __init__(
        self,
        db: Database,
        tables: CollectionTablesTuple,
        collectionIdName: str,
        *,
        dimensions: DimensionRecordStorageManager,
    ):
        super().__init__()
        self._db = db
        self._tables = tables
        self._collectionIdName = collectionIdName
        self._dimensions = dimensions
        # In-memory record cache, keyed by collection primary key.
        self._records: dict[K, CollectionRecord] = {}

    def refresh(self) -> None:
        # Docstring inherited from CollectionManager.
        collectionTable = self._tables.collection
        runTable = self._tables.run
        # Outer join, so collections without a run row are still returned
        # (with NULL run columns).
        query = sqlalchemy.sql.select(*collectionTable.columns, *runTable.columns).select_from(
            collectionTable.join(runTable, isouter=True)
        )
        TimespanReprClass = self._db.getTimespanRepresentation()
        # Build the new cache contents locally and install them only after
        # every row has been converted, so an exception part-way through
        # leaves the existing cache untouched.
        newRecords: list[CollectionRecord] = []
        chainRecords: list[DefaultChainedCollectionRecord] = []
        for row in self._db.query(query).mappings():
            collectionId = row[collectionTable.columns[self._collectionIdName]]
            collectionName = row[collectionTable.columns.name]
            collectionType = CollectionType(row["type"])
            record: CollectionRecord
            if collectionType is CollectionType.CHAINED:
                chain = DefaultChainedCollectionRecord(
                    db=self._db,
                    key=collectionId,
                    name=collectionName,
                    table=self._tables.collection_chain,
                    universe=self._dimensions.universe,
                )
                chainRecords.append(chain)
                record = chain
            elif collectionType is CollectionType.RUN:
                record = DefaultRunRecord(
                    db=self._db,
                    key=collectionId,
                    name=collectionName,
                    table=runTable,
                    idColumnName=self._collectionIdName,
                    host=row[runTable.columns.host],
                    timespan=TimespanReprClass.extract(row),
                )
            else:
                record = CollectionRecord(key=collectionId, name=collectionName, type=collectionType)
            newRecords.append(record)
        self._setRecordCache(newRecords)
        # Chains are resolved only after the cache is populated, because
        # their children are looked up through this manager.
        for chain in chainRecords:
            try:
                chain.refresh(self)
            except MissingCollectionError:
                # Race: another client created a new collection and attached
                # it to this pre-existing chain between our collection query
                # and the parent/child query. That chain is the other
                # client's concern, so drop it from the cache instead of
                # failing; a later manual refresh will bring it back.
                self._removeCachedRecord(chain)

    def register(
        self, name: str, type: CollectionType, doc: str | None = None
    ) -> tuple[CollectionRecord, bool]:
        # Docstring inherited from CollectionManager.
        existing = self._getByName(name)
        if existing is not None:
            # Already cached: the database is not touched and the collection
            # is reported as not newly registered.
            return existing, False
        idName = self._collectionIdName
        row, insertedOrUpdated = self._db.sync(
            self._tables.collection,
            keys={"name": name},
            compared={"type": int(type)},
            extra={"doc": doc},
            returning=[idName],
        )
        assert isinstance(insertedOrUpdated, bool)
        assert row is not None
        collectionId = row[idName]
        record: CollectionRecord
        if type is CollectionType.RUN:
            TimespanReprClass = self._db.getTimespanRepresentation()
            runRow, _ = self._db.sync(
                self._tables.run,
                keys={idName: collectionId},
                returning=("host",) + TimespanReprClass.getFieldNames(),
            )
            assert runRow is not None
            record = DefaultRunRecord(
                db=self._db,
                key=collectionId,
                name=name,
                table=self._tables.run,
                idColumnName=idName,
                host=runRow["host"],
                timespan=TimespanReprClass.extract(runRow),
            )
        elif type is CollectionType.CHAINED:
            record = DefaultChainedCollectionRecord(
                db=self._db,
                key=collectionId,
                name=name,
                table=self._tables.collection_chain,
                universe=self._dimensions.universe,
            )
        else:
            record = CollectionRecord(key=collectionId, name=name, type=type)
        self._addCachedRecord(record)
        return record, insertedOrUpdated

    def remove(self, name: str) -> None:
        # Docstring inherited from CollectionManager.
        record = self._getByName(name)
        if record is None:
            raise MissingCollectionError(f"No collection with name '{name}' found.")
        idName = self._collectionIdName
        # The delete may raise, e.g. when this collection is still a child in
        # a chain (that foreign key is declared without ON DELETE CASCADE).
        self._db.delete(self._tables.collection, [idName], {idName: record.key})
        self._removeCachedRecord(record)

    def find(self, name: str) -> CollectionRecord:
        # Docstring inherited from CollectionManager.
        record = self._getByName(name)
        if record is not None:
            return record
        raise MissingCollectionError(f"No collection with name '{name}' found.")

    def __getitem__(self, key: Any) -> CollectionRecord:
        # Docstring inherited from CollectionManager.
        try:
            record = self._records[key]
        except KeyError as err:
            raise MissingCollectionError(f"Collection with key '{key}' not found.") from err
        return record

    def __iter__(self) -> Iterator[CollectionRecord]:
        # Lazily iterate over all cached collection records.
        for record in self._records.values():
            yield record

    def getDocumentation(self, key: Any) -> str | None:
        # Docstring inherited from CollectionManager.
        collectionTable = self._tables.collection
        idColumn = collectionTable.columns[self._collectionIdName]
        query = (
            sqlalchemy.sql.select(collectionTable.columns.doc)
            .select_from(collectionTable)
            .where(idColumn == key)
        )
        return self._db.query(query).scalar()

    def setDocumentation(self, key: Any, doc: str | None) -> None:
        # Docstring inherited from CollectionManager.
        # ``where`` maps the ID column to the row-dict entry ("key") holding
        # the value to match.
        where = {self._collectionIdName: "key"}
        self._db.update(self._tables.collection, where, {"key": key, "doc": doc})

    def _setRecordCache(self, records: Iterable[CollectionRecord]) -> None:
        """Replace the whole record cache with ``records``.

        Any previously cached records are discarded.
        """
        self._records = {}
        self._records.update((record.key, record) for record in records)

    def _addCachedRecord(self, record: CollectionRecord) -> None:
        """Insert (or overwrite) a single record in the cache."""
        self._records[record.key] = record

    def _removeCachedRecord(self, record: CollectionRecord) -> None:
        """Drop a single record from the cache; raises `KeyError` if absent."""
        self._records.pop(record.key)

    @abstractmethod
    def _getByName(self, name: str) -> CollectionRecord | None:
        """Find collection record given collection name."""
        raise NotImplementedError()