Coverage for python/lsst/daf/butler/registry/obscore/_manager.py: 17%
146 statements
# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

__all__ = ["ObsCoreLiveTableManager"]

import json
import re
import uuid
from collections import defaultdict
from collections.abc import Collection, Mapping
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Type, cast

import sqlalchemy
from lsst.daf.butler import (
    Config,
    DataCoordinate,
    DataCoordinateIterable,
    DatasetRef,
    Dimension,
    DimensionUniverse,
)
from lsst.sphgeom import Region
from lsst.utils.iteration import chunk_iterable

from ..interfaces import ObsCoreTableManager, VersionTuple
from ._config import ConfigCollectionType, ObsCoreManagerConfig
from ._records import ExposureRegionFactory, Record, RecordFactory
from ._schema import ObsCoreSchema
from ._spatial import SpatialObsCorePlugin

if TYPE_CHECKING:
    from ..interfaces import (
        CollectionRecord,
        Database,
        DatasetRecordStorageManager,
        DimensionRecordStorageManager,
        StaticTablesContext,
    )

_VERSION = VersionTuple(0, 0, 1)


class _ExposureRegionFactory(ExposureRegionFactory):
    """Find the exposure region from matching visit dimension records."""

    def __init__(self, dimensions: DimensionRecordStorageManager):
        self.dimensions = dimensions
        self.universe = dimensions.universe
        self.exposure = self.universe["exposure"]
        self.visit = self.universe["visit"]

    def exposure_region(self, dataId: DataCoordinate) -> Optional[Region]:
        # Docstring is inherited from a base class.
        visit_definition_storage = self.dimensions.get(self.universe["visit_definition"])
        if visit_definition_storage is None:
            return None
        exposureDataId = dataId.subset(self.exposure.graph)
        records = visit_definition_storage.fetch(DataCoordinateIterable.fromScalar(exposureDataId))
        # There may be more than one visit per exposure, but they should all
        # have the same region, so we use an arbitrary one.
        record = next(iter(records), None)
        if record is None:
            return None
        visit: int = record.visit

        detector = cast(Dimension, self.universe["detector"])
        if detector in dataId:
            visit_detector_region_storage = self.dimensions.get(self.universe["visit_detector_region"])
            if visit_detector_region_storage is None:
                return None
            visitDataId = DataCoordinate.standardize(
                {
                    "instrument": dataId["instrument"],
                    "visit": visit,
                    "detector": dataId["detector"],
                },
                universe=self.universe,
            )
            records = visit_detector_region_storage.fetch(DataCoordinateIterable.fromScalar(visitDataId))
            record = next(iter(records), None)
            if record is not None:
                return record.region
        else:
            visit_storage = self.dimensions.get(self.visit)
            if visit_storage is None:
                return None
            visitDataId = DataCoordinate.standardize(
                {
                    "instrument": dataId["instrument"],
                    "visit": visit,
                },
                universe=self.universe,
            )
            records = visit_storage.fetch(DataCoordinateIterable.fromScalar(visitDataId))
            record = next(iter(records), None)
            if record is not None:
                return record.region

        return None
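
# A hedged illustration of the lookup above (not part of the original module):
# the factory maps exposure -> visit through "visit_definition" records, then
# reads the region either from "visit_detector_region" (when the data ID has a
# detector) or from the "visit" record itself. The data-ID construction uses
# the same call as the code; the instrument/visit/detector values are made up:
#
#     visit_data_id = DataCoordinate.standardize(
#         {"instrument": "HSC", "visit": 903334, "detector": 16},
#         universe=universe,
#     )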


class ObsCoreLiveTableManager(ObsCoreTableManager):
    """A manager class for the ObsCore table; implements methods for updating
    the records in that table.
    """

    def __init__(
        self,
        *,
        db: Database,
        table: sqlalchemy.schema.Table,
        schema: ObsCoreSchema,
        universe: DimensionUniverse,
        config: ObsCoreManagerConfig,
        dimensions: DimensionRecordStorageManager,
        spatial_plugins: Collection[SpatialObsCorePlugin],
    ):
        self.db = db
        self.table = table
        self.schema = schema
        self.universe = universe
        self.config = config
        self.spatial_plugins = spatial_plugins
        exposure_region_factory = _ExposureRegionFactory(dimensions)
        self.record_factory = RecordFactory(
            config, schema, universe, spatial_plugins, exposure_region_factory
        )
        self.tagged_collection: Optional[str] = None
        self.run_patterns: list[re.Pattern] = []
        if config.collection_type is ConfigCollectionType.TAGGED:
            assert (
                config.collections is not None and len(config.collections) == 1
            ), "Exactly one collection name required for tagged type."
            self.tagged_collection = config.collections[0]
        elif config.collection_type is ConfigCollectionType.RUN:
            if config.collections:
                for coll in config.collections:
                    try:
                        self.run_patterns.append(re.compile(coll))
                    except re.error as exc:
                        raise ValueError(f"Failed to compile regex: {coll!r}") from exc
        else:
            raise ValueError(f"Unexpected value of collection_type: {config.collection_type}")
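
    # A hedged illustration of the two collection_type modes handled above;
    # the keys mirror ObsCoreManagerConfig fields, but the values are made up:
    #
    #     {"collection_type": "TAGGED", "collections": ["obscore-tagged"]}
    #     {"collection_type": "RUN", "collections": ["HSC/runs/.*"]}
    #
    # TAGGED mode requires exactly one literal collection name; RUN mode
    # compiles each entry as a regular expression.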

    @classmethod
    def initialize(
        cls,
        db: Database,
        context: StaticTablesContext,
        *,
        universe: DimensionUniverse,
        config: Mapping,
        datasets: Type[DatasetRecordStorageManager],
        dimensions: DimensionRecordStorageManager,
    ) -> ObsCoreTableManager:
        # Docstring inherited from base class.
        config_data = Config(config)
        obscore_config = ObsCoreManagerConfig.parse_obj(config_data)

        # Instantiate all spatial plugins.
        spatial_plugins = SpatialObsCorePlugin.load_plugins(obscore_config.spatial_plugins, db)

        schema = ObsCoreSchema(config=obscore_config, spatial_plugins=spatial_plugins, datasets=datasets)

        # Generate table specification for the main obscore table.
        table_spec = schema.table_spec
        for plugin in spatial_plugins:
            plugin.extend_table_spec(table_spec)
        table = context.addTable(obscore_config.table_name, table_spec)

        return ObsCoreLiveTableManager(
            db=db,
            table=table,
            schema=schema,
            universe=universe,
            config=obscore_config,
            dimensions=dimensions,
            spatial_plugins=spatial_plugins,
        )

    def config_json(self) -> str:
        """Dump configuration in JSON format.

        Returns
        -------
        json : `str`
            Configuration serialized in JSON format.
        """
        return json.dumps(self.config.dict())
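
    # Minimal usage sketch (the `manager` instance here is hypothetical): the
    # JSON payload round-trips back into a plain dictionary.
    #
    #     config_dict = json.loads(manager.config_json())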

    @classmethod
    def currentVersion(cls) -> Optional[VersionTuple]:
        # Docstring inherited from base class.
        return _VERSION

    def schemaDigest(self) -> Optional[str]:
        # Docstring inherited from base class.
        return None

    def add_datasets(self, refs: Iterable[DatasetRef]) -> int:
        # Docstring inherited from base class.

        # Only makes sense for RUN collection types.
        if self.config.collection_type is not ConfigCollectionType.RUN:
            return 0

        obscore_refs: Iterable[DatasetRef]
        if self.run_patterns:
            # Check each dataset run against the configured run list. We want
            # to reduce the number of calls to _check_dataset_run, which may
            # be expensive. Input references normally share a single run, but
            # instead of relying on that we group them by run explicitly.
            refs_by_run: Dict[str, List[DatasetRef]] = defaultdict(list)
            for ref in refs:

                # Record factory will filter dataset types, but to reduce
                # collection checks we also pre-filter it here.
                if ref.datasetType.name not in self.config.dataset_types:
                    continue

                assert ref.run is not None, "Run cannot be None"
                refs_by_run[ref.run].append(ref)

            good_refs: List[DatasetRef] = []
            for run, run_refs in refs_by_run.items():
                if not self._check_dataset_run(run):
                    continue
                good_refs.extend(run_refs)
            obscore_refs = good_refs

        else:
            # Take all refs, no collection check.
            obscore_refs = refs

        return self._populate(obscore_refs)
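
    # A stdlib-only sketch of the grouping pattern used above, so that the
    # (potentially expensive) per-run check runs once per run rather than
    # once per dataset; the names and runs are illustrative:
    #
    #     by_run: Dict[str, List[str]] = defaultdict(list)
    #     for name, run in [("raw", "HSC/runs/a"), ("calexp", "HSC/runs/a")]:
    #         by_run[run].append(name)
    #     # by_run == {"HSC/runs/a": ["raw", "calexp"]}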

    def associate(self, refs: Iterable[DatasetRef], collection: CollectionRecord) -> int:
        # Docstring inherited from base class.

        # Only works when collection type is TAGGED.
        if self.tagged_collection is None:
            return 0

        if collection.name == self.tagged_collection:
            return self._populate(refs)
        else:
            return 0

    def disassociate(self, refs: Iterable[DatasetRef], collection: CollectionRecord) -> int:
        # Docstring inherited from base class.

        # Only works when collection type is TAGGED.
        if self.tagged_collection is None:
            return 0

        count = 0
        if collection.name == self.tagged_collection:

            # Sorting may improve performance.
            dataset_ids = sorted(cast(uuid.UUID, ref.id) for ref in refs)
            if dataset_ids:
                fk_field = self.schema.dataset_fk
                assert fk_field is not None, "Cannot be None by construction"
                # There may be too many of them, do it in chunks.
                for ids in chunk_iterable(dataset_ids):
                    where = self.table.columns[fk_field.name].in_(ids)
                    count += self.db.deleteWhere(self.table, where)

        return count
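
    # Deleting in chunks keeps each IN clause bounded. For reference,
    # lsst.utils.iteration.chunk_iterable yields fixed-size batches from any
    # iterable; the explicit chunk_size keyword below is an assumption for
    # illustration (the call above relies on the default batch size):
    #
    #     for batch in chunk_iterable(range(10), chunk_size=4):
    #         ...  # at most 4 ids per batch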

    def _populate(self, refs: Iterable[DatasetRef]) -> int:
        """Populate the obscore table with data from the given datasets."""
        records: List[Record] = []
        for ref in refs:
            record = self.record_factory(ref)
            if record is not None:
                records.append(record)

        if records:
            # Ignore potential conflicts with existing datasets.
            return self.db.ensure(self.table, *records, primary_key_only=True)
        else:
            return 0

    def _check_dataset_run(self, run: str) -> bool:
        """Check that the specified run collection matches known patterns."""

        if not self.run_patterns:
            # Empty list means take anything.
            return True

        # Try each pattern in turn.
        return any(pattern.fullmatch(run) for pattern in self.run_patterns)
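
    # Note that fullmatch anchors at both ends: a configured pattern must
    # cover the entire run name. A stdlib-only check (pattern and run names
    # are illustrative):
    #
    #     pattern = re.compile("HSC/runs/.*")
    #     assert pattern.fullmatch("HSC/runs/RC2")
    #     assert not pattern.fullmatch("u/someone/HSC/runs/RC2")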