Coverage for python/lsst/daf/butler/registry/collections/_base.py: 88%
155 statements
« prev ^ index » next coverage.py v7.2.5, created at 2023-05-02 18:18 -0700
« prev ^ index » next coverage.py v7.2.5, created at 2023-05-02 18:18 -0700
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ()
25import itertools
26from abc import abstractmethod
27from collections import namedtuple
28from typing import TYPE_CHECKING, Any, Dict, Generic, Iterable, Iterator, Optional, Tuple, Type, TypeVar
30import sqlalchemy
32from ...core import DimensionUniverse, Timespan, TimespanDatabaseRepresentation, ddl
33from .._collectionType import CollectionType
34from .._exceptions import MissingCollectionError
35from ..interfaces import ChainedCollectionRecord, CollectionManager, CollectionRecord, RunRecord
36from ..wildcards import CollectionSearch
38if TYPE_CHECKING: 38 ↛ 39line 38 didn't jump to line 39, because the condition on line 38 was never true
39 from ..interfaces import Database, DimensionRecordStorageManager
def _makeCollectionForeignKey(
    sourceColumnName: str, collectionIdName: str, **kwargs: Any
) -> ddl.ForeignKeySpec:
    """Build a foreign key specification pointing at the collections table.

    Parameters
    ----------
    sourceColumnName : `str`
        Column in the referring table that holds the collection ID.
    collectionIdName : `str`
        Primary-key column of the collections table being referenced.
    **kwargs
        Forwarded unchanged to the `ddl.ForeignKeySpec` constructor.

    Returns
    -------
    spec : `ddl.ForeignKeySpec`
        Foreign key specification.

    Notes
    -----
    The referenced table name is hard-coded as "collection", and the
    collection primary key is assumed to be a single column.
    """
    # Single-column key on both sides.
    source = (sourceColumnName,)
    target = (collectionIdName,)
    return ddl.ForeignKeySpec("collection", source=source, target=target, **kwargs)
# Bundle of the three SQLAlchemy tables (collection, run, collection_chain)
# handed to the collection manager.
CollectionTablesTuple = namedtuple("CollectionTablesTuple", "collection run collection_chain")
def makeRunTableSpec(
    collectionIdName: str, collectionIdType: type, TimespanReprClass: Type[TimespanDatabaseRepresentation]
) -> ddl.TableSpec:
    """Build the table specification for the "run" table.

    Parameters
    ----------
    collectionIdName : `str`
        Name of the primary-key column of the collections table; the run
        table uses the same name for its own primary-key column.
    collectionIdType
        Type of that primary-key column, one of the `sqlalchemy` types.
    TimespanReprClass : `type` [ `TimespanDatabaseRepresentation` ]
        Subclass of `TimespanDatabaseRepresentation` that encapsulates how
        timespans are stored in this database.

    Returns
    -------
    spec : `ddl.TableSpec`
        Specification for run table.

    Notes
    -----
    The identifying column shares its name with the collections table and
    carries a foreign key to it with ``onDelete="CASCADE"``. The names of
    the non-identifying metadata columns ("host" and the timespan columns)
    are fixed.
    """
    primaryKeyField = ddl.FieldSpec(collectionIdName, dtype=collectionIdType, primaryKey=True)
    hostField = ddl.FieldSpec("host", dtype=sqlalchemy.String, length=128)
    parentCollectionKey = _makeCollectionForeignKey(collectionIdName, collectionIdName, onDelete="CASCADE")
    spec = ddl.TableSpec(fields=[primaryKeyField, hostField], foreignKeys=[parentCollectionKey])
    # The timespan columns depend on the database-specific representation,
    # so they are appended after the fixed columns.
    for timespanField in TimespanReprClass.makeFieldSpecs(nullable=True):
        spec.fields.add(timespanField)
    return spec
def makeCollectionChainTableSpec(collectionIdName: str, collectionIdType: type) -> ddl.TableSpec:
    """Build the table specification for the "collection_chain" table.

    Parameters
    ----------
    collectionIdName : `str`
        Name of the primary-key column of the collections table.
    collectionIdType
        Type of that primary-key column, one of the `sqlalchemy` types.

    Returns
    -------
    spec : `ddl.TableSpec`
        Specification for collection chain table.

    Notes
    -----
    A chain is an ordered one-to-many relation between collections: each
    row links a "parent" collection to a "child" collection at a given
    "position". These column names are fixed and are also hard-coded in
    the record classes below.
    """
    parentField = ddl.FieldSpec("parent", dtype=collectionIdType, primaryKey=True)
    positionField = ddl.FieldSpec("position", dtype=sqlalchemy.SmallInteger, primaryKey=True)
    childField = ddl.FieldSpec("child", dtype=collectionIdType, nullable=False)
    # Only the parent link cascades on delete; the child foreign key is
    # declared without an ON DELETE action.
    parentKey = _makeCollectionForeignKey("parent", collectionIdName, onDelete="CASCADE")
    childKey = _makeCollectionForeignKey("child", collectionIdName)
    return ddl.TableSpec(
        fields=[parentField, positionField, childField],
        foreignKeys=[parentKey, childKey],
    )
class DefaultRunRecord(RunRecord):
    """Default `RunRecord` implementation.

    This class assumes the same run table definition as produced by
    `makeRunTableSpec` method. The only non-fixed name in the schema
    is the PK column name, this needs to be passed in a constructor.

    Parameters
    ----------
    db : `Database`
        Registry database.
    key
        Unique collection ID, can be the same as ``name`` if ``name`` is used
        for identification. Usually this is an integer or string, but can be
        other database-specific type.
    name : `str`
        Run collection name.
    table : `sqlalchemy.schema.Table`
        Table for run records.
    idColumnName : `str`
        Name of the identifying column in run table.
    host : `str`, optional
        Name of the host where run was produced.
    timespan : `Timespan`, optional
        Timespan for this run. An unbounded ``Timespan(begin=None, end=None)``
        is used when not given.
    """

    def __init__(
        self,
        db: Database,
        key: Any,
        name: str,
        *,
        table: sqlalchemy.schema.Table,
        idColumnName: str,
        host: Optional[str] = None,
        timespan: Optional[Timespan] = None,
    ):
        super().__init__(key=key, name=name, type=CollectionType.RUN)
        self._db = db
        self._table = table
        self._host = host
        if timespan is None:
            timespan = Timespan(begin=None, end=None)
        self._timespan = timespan
        self._idName = idColumnName

    def update(self, host: Optional[str] = None, timespan: Optional[Timespan] = None) -> None:
        # Docstring inherited from RunRecord.
        if timespan is None:
            timespan = Timespan(begin=None, end=None)
        # `Database.update` takes a mapping from the names of the columns to
        # match on to the keys in the row dictionaries that hold the values
        # to match. The match value therefore goes under its own key
        # ("key"), distinct from the column name, exactly as
        # `DefaultCollectionManager.setDocumentation` does. The primary key
        # itself is not re-assigned, so it is not part of the SET values.
        row: Dict[str, Any] = {
            "key": self.key,
            "host": host,
        }
        self._db.getTimespanRepresentation().update(timespan, result=row)
        count = self._db.update(self._table, {self._idName: "key"}, row)
        if count != 1:
            raise RuntimeError(f"Run update affected {count} records; expected exactly one.")
        # Only mirror the new values locally once the database write succeeded.
        self._host = host
        self._timespan = timespan

    @property
    def host(self) -> Optional[str]:
        # Docstring inherited from RunRecord.
        return self._host

    @property
    def timespan(self) -> Timespan:
        # Docstring inherited from RunRecord.
        return self._timespan
class DefaultChainedCollectionRecord(ChainedCollectionRecord):
    """Default `ChainedCollectionRecord` implementation.

    This class relies on the chain table layout produced by
    `makeCollectionChainTableSpec`; the "parent", "child" and "position"
    column names are fixed and hard-coded in the methods below.

    Parameters
    ----------
    db : `Database`
        Registry database.
    key
        Unique collection ID, can be the same as ``name`` if ``name`` is used
        for identification. Usually this is an integer or string, but can be
        other database-specific type.
    name : `str`
        Collection name.
    table : `sqlalchemy.schema.Table`
        Table for chain relationship records.
    universe : `DimensionUniverse`
        Object managing all known dimensions.
    """

    def __init__(
        self,
        db: Database,
        key: Any,
        name: str,
        *,
        table: sqlalchemy.schema.Table,
        universe: DimensionUniverse,
    ):
        super().__init__(key=key, name=name, universe=universe)
        self._db = db
        self._table = table
        self._universe = universe

    def _update(self, manager: CollectionManager, children: CollectionSearch) -> None:
        # Docstring inherited from ChainedCollectionRecord.
        # Direct children only (chains are not flattened), numbered from 0 in
        # search order.
        directChildren = children.iter(manager, flattenChains=False)
        links = [
            {"parent": self.key, "child": child.key, "position": index}
            for index, child in enumerate(directChildren)
        ]
        # Replace the whole chain atomically: drop old links, write new ones.
        with self._db.transaction():
            self._db.delete(self._table, ["parent"], {"parent": self.key})
            self._db.insert(self._table, *links)

    def _load(self, manager: CollectionManager) -> CollectionSearch:
        # Docstring inherited from ChainedCollectionRecord.
        columns = self._table.columns
        query = sqlalchemy.sql.select(columns.child).select_from(self._table)
        query = query.where(columns.parent == self.key).order_by(columns.position)
        with self._db.query(query) as result:
            # Translate each child key back to its collection name via the
            # manager, preserving the stored order.
            childNames = tuple(manager[mapping[columns.child]].name for mapping in result.mappings())
            return CollectionSearch.fromExpression(childNames)
# Type of the key used to index collection records (the collection ID) in
# `DefaultCollectionManager`.
K = TypeVar("K")
class DefaultCollectionManager(Generic[K], CollectionManager):
    """Default `CollectionManager` implementation.

    This implementation uses record classes defined in this module and is
    based on the same assumptions about schema outlined in the record classes.

    Parameters
    ----------
    db : `Database`
        Interface to the underlying database engine and namespace.
    tables : `CollectionTablesTuple`
        Named tuple of SQLAlchemy table objects.
    collectionIdName : `str`
        Name of the column in collections table that identifies it (PK).
    dimensions : `DimensionRecordStorageManager`
        Manager object for the dimensions in this `Registry`.

    Notes
    -----
    Implementation uses "aggressive" pre-fetching and caching of the records
    in memory. Memory cache is synchronized from database when `refresh`
    method is called. Subclasses must implement `_getByName`.
    """

    def __init__(
        self,
        db: Database,
        tables: CollectionTablesTuple,
        collectionIdName: str,
        *,
        dimensions: DimensionRecordStorageManager,
    ):
        super().__init__()
        self._db = db
        self._tables = tables
        self._collectionIdName = collectionIdName
        self._records: Dict[K, CollectionRecord] = {}  # indexed by record ID
        self._dimensions = dimensions

    def refresh(self) -> None:
        # Docstring inherited from CollectionManager.
        # Outer join so that non-RUN collections (which have no run row) are
        # still returned, with NULL run columns.
        sql = sqlalchemy.sql.select(
            *(list(self._tables.collection.columns) + list(self._tables.run.columns))
        ).select_from(self._tables.collection.join(self._tables.run, isouter=True))
        # Put found records into a temporary instead of updating self._records
        # in place, for exception safety.
        records = []
        chains = []
        TimespanReprClass = self._db.getTimespanRepresentation()
        with self._db.query(sql) as sql_result:
            sql_rows = sql_result.mappings().fetchall()
        for row in sql_rows:
            collection_id = row[self._tables.collection.columns[self._collectionIdName]]
            name = row[self._tables.collection.columns.name]
            # Named so as not to shadow the ``type`` builtin.
            collection_type = CollectionType(row["type"])
            record: CollectionRecord
            if collection_type is CollectionType.RUN:
                record = DefaultRunRecord(
                    key=collection_id,
                    name=name,
                    db=self._db,
                    table=self._tables.run,
                    idColumnName=self._collectionIdName,
                    host=row[self._tables.run.columns.host],
                    timespan=TimespanReprClass.extract(row),
                )
            elif collection_type is CollectionType.CHAINED:
                record = DefaultChainedCollectionRecord(
                    db=self._db,
                    key=collection_id,
                    table=self._tables.collection_chain,
                    name=name,
                    universe=self._dimensions.universe,
                )
                # Chain membership is loaded only after every record is
                # cached, because resolving children looks them up by key.
                chains.append(record)
            else:
                record = CollectionRecord(key=collection_id, name=name, type=collection_type)
            records.append(record)
        self._setRecordCache(records)
        for chain in chains:
            try:
                chain.refresh(self)
            except MissingCollectionError:
                # This indicates a race condition in which some other client
                # created a new collection and added it as a child of this
                # (pre-existing) chain between the time we fetched all
                # collections and the time we queried for parent-child
                # relationships.
                # Because that's some other unrelated client, we shouldn't care
                # about that parent collection anyway, so we just drop it on
                # the floor (a manual refresh can be used to get it back).
                self._removeCachedRecord(chain)

    def register(
        self, name: str, type: CollectionType, doc: Optional[str] = None
    ) -> Tuple[CollectionRecord, bool]:
        # Docstring inherited from CollectionManager.
        registered = False
        record = self._getByName(name)
        if record is None:
            row, inserted_or_updated = self._db.sync(
                self._tables.collection,
                keys={"name": name},
                compared={"type": int(type)},
                extra={"doc": doc},
                returning=[self._collectionIdName],
            )
            assert isinstance(inserted_or_updated, bool)
            registered = inserted_or_updated
            assert row is not None
            collection_id = row[self._collectionIdName]
            if type is CollectionType.RUN:
                TimespanReprClass = self._db.getTimespanRepresentation()
                # Ensure the companion run row exists and read back its
                # metadata columns.
                row, _ = self._db.sync(
                    self._tables.run,
                    keys={self._collectionIdName: collection_id},
                    returning=("host",) + TimespanReprClass.getFieldNames(),
                )
                assert row is not None
                record = DefaultRunRecord(
                    db=self._db,
                    key=collection_id,
                    name=name,
                    table=self._tables.run,
                    idColumnName=self._collectionIdName,
                    host=row["host"],
                    timespan=TimespanReprClass.extract(row),
                )
            elif type is CollectionType.CHAINED:
                record = DefaultChainedCollectionRecord(
                    db=self._db,
                    key=collection_id,
                    name=name,
                    table=self._tables.collection_chain,
                    universe=self._dimensions.universe,
                )
            else:
                record = CollectionRecord(key=collection_id, name=name, type=type)
            self._addCachedRecord(record)
        return record, registered

    def remove(self, name: str) -> None:
        # Docstring inherited from CollectionManager.
        record = self._getByName(name)
        if record is None:
            raise MissingCollectionError(f"No collection with name '{name}' found.")
        # This may raise
        self._db.delete(
            self._tables.collection, [self._collectionIdName], {self._collectionIdName: record.key}
        )
        self._removeCachedRecord(record)

    def find(self, name: str) -> CollectionRecord:
        # Docstring inherited from CollectionManager.
        result = self._getByName(name)
        if result is None:
            raise MissingCollectionError(f"No collection with name '{name}' found.")
        return result

    def __getitem__(self, key: Any) -> CollectionRecord:
        # Docstring inherited from CollectionManager.
        try:
            return self._records[key]
        except KeyError as err:
            raise MissingCollectionError(f"Collection with key '{key}' not found.") from err

    def __iter__(self) -> Iterator[CollectionRecord]:
        yield from self._records.values()

    def getDocumentation(self, key: Any) -> Optional[str]:
        # Docstring inherited from CollectionManager.
        sql = (
            sqlalchemy.sql.select(self._tables.collection.columns.doc)
            .select_from(self._tables.collection)
            .where(self._tables.collection.columns[self._collectionIdName] == key)
        )
        with self._db.query(sql) as sql_result:
            return sql_result.scalar()

    def setDocumentation(self, key: Any, doc: Optional[str]) -> None:
        # Docstring inherited from CollectionManager.
        # The ``where`` mapping points the ID column at the row's "key" entry.
        self._db.update(self._tables.collection, {self._collectionIdName: "key"}, {"key": key, "doc": doc})

    def _setRecordCache(self, records: Iterable[CollectionRecord]) -> None:
        """Set internal record cache to contain given records,
        old cached records will be removed.
        """
        self._records = {}
        for record in records:
            self._records[record.key] = record

    def _addCachedRecord(self, record: CollectionRecord) -> None:
        """Add single record to cache."""
        self._records[record.key] = record

    def _removeCachedRecord(self, record: CollectionRecord) -> None:
        """Remove single record from cache."""
        del self._records[record.key]

    @abstractmethod
    def _getByName(self, name: str) -> Optional[CollectionRecord]:
        """Find collection record given collection name."""
        raise NotImplementedError()