Coverage for python/lsst/daf/butler/registry/obscore/_manager.py: 20%
144 statements
« prev ^ index » next coverage.py v6.5.0, created at 2022-10-07 02:47 -0700
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ["ObsCoreLiveTableManager"]
26import json
27import re
28import uuid
29from collections import defaultdict
30from collections.abc import Mapping
31from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Type, cast
33import sqlalchemy
34from lsst.daf.butler import (
35 Config,
36 DataCoordinate,
37 DataCoordinateIterable,
38 DatasetRef,
39 Dimension,
40 DimensionUniverse,
41)
42from lsst.sphgeom import Region
43from lsst.utils.iteration import chunk_iterable
45from ..interfaces import ObsCoreTableManager, VersionTuple
46from ._config import ConfigCollectionType, ObsCoreManagerConfig
47from ._records import ExposureRegionFactory, RecordFactory
48from ._schema import ObsCoreSchema
50if TYPE_CHECKING:
51 from ..interfaces import (
52 CollectionRecord,
53 Database,
54 DatasetRecordStorageManager,
55 DimensionRecordStorageManager,
56 StaticTablesContext,
57 )
59_VERSION = VersionTuple(0, 0, 1)
class _ExposureRegionFactory(ExposureRegionFactory):
    """Look up the spatial region of an exposure via its matching visit
    dimension records.
    """

    def __init__(self, dimensions: DimensionRecordStorageManager):
        self.dimensions = dimensions
        self.universe = dimensions.universe
        self.exposure = self.universe["exposure"]
        self.visit = self.universe["visit"]

    def exposure_region(self, dataId: DataCoordinate) -> Optional[Region]:
        # Docstring is inherited from a base class.
        definition_storage = self.dimensions.get(self.universe["visit_definition"])
        if definition_storage is None:
            return None
        exposure_data_id = dataId.subset(self.exposure.graph)
        definitions = definition_storage.fetch(DataCoordinateIterable.fromScalar(exposure_data_id))
        # An exposure may map to more than one visit; their regions should
        # agree, so an arbitrary match is good enough.
        definition = next(iter(definitions), None)
        if definition is None:
            return None
        visit_id: int = definition.visit

        # Pick the per-detector region when a detector is part of the data
        # ID, otherwise fall back to the whole-visit region.
        if cast(Dimension, self.universe["detector"]) in dataId:
            region_storage = self.dimensions.get(self.universe["visit_detector_region"])
            key = {
                "instrument": dataId["instrument"],
                "visit": visit_id,
                "detector": dataId["detector"],
            }
        else:
            region_storage = self.dimensions.get(self.visit)
            key = {
                "instrument": dataId["instrument"],
                "visit": visit_id,
            }
        if region_storage is None:
            return None
        region_data_id = DataCoordinate.standardize(key, universe=self.universe)
        region_records = region_storage.fetch(DataCoordinateIterable.fromScalar(region_data_id))
        region_record = next(iter(region_records), None)
        if region_record is not None:
            return region_record.region

        return None
class ObsCoreLiveTableManager(ObsCoreTableManager):
    """Manager for a live ObsCore table; implements the methods that keep
    the table records up to date.
    """

    def __init__(
        self,
        *,
        db: Database,
        table: sqlalchemy.schema.Table,
        schema: ObsCoreSchema,
        universe: DimensionUniverse,
        config: ObsCoreManagerConfig,
        dimensions: DimensionRecordStorageManager,
    ):
        self.db = db
        self.table = table
        self.schema = schema
        self.universe = universe
        self.config = config
        # Record factory resolves exposure regions through the dimension
        # record storage.
        self.record_factory = RecordFactory(
            config, schema, universe, _ExposureRegionFactory(dimensions)
        )

    @classmethod
    def initialize(
        cls,
        db: Database,
        context: StaticTablesContext,
        *,
        universe: DimensionUniverse,
        config: Mapping,
        datasets: Type[DatasetRecordStorageManager],
        dimensions: DimensionRecordStorageManager,
    ) -> ObsCoreTableManager:
        # Docstring inherited from base class.
        obscore_config = ObsCoreManagerConfig.parse_obj(Config(config))
        schema = ObsCoreSchema(config=obscore_config, datasets=datasets)
        table = context.addTable(obscore_config.table_name, schema.table_spec)
        collection_type = obscore_config.collection_type
        if collection_type is ConfigCollectionType.TAGGED:
            # Configuration validation guarantees that there is exactly one
            # collection for TAGGED type.
            assert obscore_config.collections is not None, "Collections must be defined"
            return _TaggedObsCoreTableManager(
                db=db,
                table=table,
                schema=schema,
                universe=universe,
                config=obscore_config,
                tagged_collection=obscore_config.collections[0],
                dimensions=dimensions,
            )
        if collection_type is ConfigCollectionType.RUN:
            return _RunObsCoreTableManager(
                db=db,
                table=table,
                schema=schema,
                universe=universe,
                config=obscore_config,
                dimensions=dimensions,
            )
        raise ValueError(f"Unexpected value of collection_type: {collection_type}")

    def config_json(self) -> str:
        """Dump configuration in JSON format.

        Returns
        -------
        json : `str`
            Configuration serialized in JSON format.
        """
        return json.dumps(self.config.dict())

    @classmethod
    def currentVersion(cls) -> Optional[VersionTuple]:
        # Docstring inherited from base class.
        return _VERSION

    def schemaDigest(self) -> Optional[str]:
        # Docstring inherited from base class.
        return None
class _TaggedObsCoreTableManager(ObsCoreLiveTableManager):
    """ObsCoreTableManager implementation used when
    ``collection_type=TAGGED``.
    """

    def __init__(
        self,
        *,
        db: Database,
        table: sqlalchemy.schema.Table,
        schema: ObsCoreSchema,
        universe: DimensionUniverse,
        config: ObsCoreManagerConfig,
        tagged_collection: str,
        dimensions: DimensionRecordStorageManager,
    ):
        super().__init__(
            db=db,
            table=table,
            schema=schema,
            universe=universe,
            config=config,
            dimensions=dimensions,
        )
        self.tagged_collection = tagged_collection

    def add_datasets(self, refs: Iterable[DatasetRef]) -> None:
        # Docstring inherited from base class.
        # In TAGGED mode datasets enter the table only via association.
        return

    def associate(self, refs: Iterable[DatasetRef], collection: CollectionRecord) -> None:
        # Docstring inherited from base class.
        if collection.name != self.tagged_collection:
            return
        new_records = [rec for ref in refs if (rec := self.record_factory(ref)) is not None]
        if new_records:
            # Ignore potential conflicts with existing datasets.
            self.db.ensure(self.table, *new_records, primary_key_only=True)

    def disassociate(self, refs: Iterable[DatasetRef], collection: CollectionRecord) -> None:
        # Docstring inherited from base class.
        if collection.name != self.tagged_collection:
            return
        # Sorting may improve performance
        dataset_ids = sorted(cast(uuid.UUID, ref.id) for ref in refs)
        if not dataset_ids:
            return
        fk_field = self.schema.dataset_fk
        assert fk_field is not None, "Cannot be None by construction"
        fk_column = self.table.columns[fk_field.name]
        # There may be too many of them, do it in chunks.
        for id_chunk in chunk_iterable(dataset_ids):
            self.db.deleteWhere(self.table, fk_column.in_(id_chunk))
class _RunObsCoreTableManager(ObsCoreLiveTableManager):
    """ObsCoreTableManager implementation used when
    ``collection_type=RUN``.
    """

    def __init__(
        self,
        *,
        db: Database,
        table: sqlalchemy.schema.Table,
        schema: ObsCoreSchema,
        universe: DimensionUniverse,
        config: ObsCoreManagerConfig,
        dimensions: DimensionRecordStorageManager,
    ):
        super().__init__(
            db=db,
            table=table,
            schema=schema,
            universe=universe,
            config=config,
            dimensions=dimensions,
        )
        # Compile configured run-collection patterns up front so a broken
        # regex is reported at construction time.
        self.run_patterns: List[re.Pattern] = []
        for pattern in config.collections or []:
            try:
                self.run_patterns.append(re.compile(pattern))
            except re.error as exc:
                raise ValueError(f"Failed to compile regex: {pattern!r}") from exc

    def add_datasets(self, refs: Iterable[DatasetRef]) -> None:
        # Docstring inherited from base class.
        obscore_refs: Iterable[DatasetRef]
        if self.run_patterns:
            # Check each dataset run against the configured run list. To keep
            # the number of potentially expensive _check_dataset_run calls
            # low, group references by run first. Normally all input
            # references share a single run, but grouping also handles the
            # general case.
            refs_by_run: Dict[str, List[DatasetRef]] = defaultdict(list)
            for ref in refs:
                # Record factory will filter dataset types, but to reduce
                # collection checks we also pre-filter it here.
                if ref.datasetType.name not in self.config.dataset_types:
                    continue
                assert ref.run is not None, "Run cannot be None"
                refs_by_run[ref.run].append(ref)

            selected: List[DatasetRef] = []
            for run, grouped_refs in refs_by_run.items():
                if self._check_dataset_run(run):
                    selected.extend(grouped_refs)
            obscore_refs = selected
        else:
            # No patterns configured: take all refs, no collection check.
            obscore_refs = refs

        # Convert them all to records, dropping refs the factory rejects.
        new_records = [rec for ref in obscore_refs if (rec := self.record_factory(ref)) is not None]
        if new_records:
            # Ignore potential conflicts with existing datasets.
            self.db.ensure(self.table, *new_records, primary_key_only=True)

    def _check_dataset_run(self, run: str) -> bool:
        """Check whether the given run collection matches known patterns.

        An empty pattern list means any run is accepted.
        """
        if not self.run_patterns:
            return True
        # Try each pattern in turn.
        return any(pattern.fullmatch(run) for pattern in self.run_patterns)

    def associate(self, refs: Iterable[DatasetRef], collection: CollectionRecord) -> None:
        # Docstring inherited from base class.
        # Tagged collections play no role in RUN mode.
        return

    def disassociate(self, refs: Iterable[DatasetRef], collection: CollectionRecord) -> None:
        # Docstring inherited from base class.
        return