# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

__all__ = ["ObsCoreLiveTableManager"]

import json
import re
import uuid
from collections import defaultdict
from collections.abc import Collection, Mapping
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Type, cast

import sqlalchemy
from lsst.daf.butler import (
    Config,
    DataCoordinate,
    DataCoordinateIterable,
    DatasetRef,
    Dimension,
    DimensionUniverse,
)
from lsst.sphgeom import Region
from lsst.utils.iteration import chunk_iterable

from ..interfaces import ObsCoreTableManager, VersionTuple
from ._config import ConfigCollectionType, ObsCoreManagerConfig
from ._records import ExposureRegionFactory, Record, RecordFactory
from ._schema import ObsCoreSchema
from ._spatial import SpatialObsCorePlugin

if TYPE_CHECKING:
    from ..interfaces import (
        CollectionRecord,
        Database,
        DatasetRecordStorageManager,
        DimensionRecordStorageManager,
        StaticTablesContext,
    )

_VERSION = VersionTuple(0, 0, 1)


class _ExposureRegionFactory(ExposureRegionFactory):
64 """Find exposure region from a matching visit dimensions records."""

    def __init__(self, dimensions: DimensionRecordStorageManager):
        self.dimensions = dimensions
        self.universe = dimensions.universe
        self.exposure = self.universe["exposure"]
        self.visit = self.universe["visit"]

    def exposure_region(self, dataId: DataCoordinate) -> Optional[Region]:
        # Docstring is inherited from a base class.
        visit_definition_storage = self.dimensions.get(self.universe["visit_definition"])
        if visit_definition_storage is None:
            return None
        exposureDataId = dataId.subset(self.exposure.graph)
        records = visit_definition_storage.fetch(DataCoordinateIterable.fromScalar(exposureDataId))
        # There may be more than one visit per exposure; they should have the
        # same region, so we use an arbitrary one.
        record = next(iter(records), None)
        if record is None:
            return None
        visit: int = record.visit

        detector = cast(Dimension, self.universe["detector"])
        if detector in dataId:
            visit_detector_region_storage = self.dimensions.get(self.universe["visit_detector_region"])
            if visit_detector_region_storage is None:
                return None
            visitDataId = DataCoordinate.standardize(
                {
                    "instrument": dataId["instrument"],
                    "visit": visit,
                    "detector": dataId["detector"],
                },
                universe=self.universe,
            )
            records = visit_detector_region_storage.fetch(DataCoordinateIterable.fromScalar(visitDataId))
            record = next(iter(records), None)
            if record is not None:
                return record.region

        else:
            visit_storage = self.dimensions.get(self.visit)
            if visit_storage is None:
                return None
            visitDataId = DataCoordinate.standardize(
                {
                    "instrument": dataId["instrument"],
                    "visit": visit,
                },
                universe=self.universe,
            )
            records = visit_storage.fetch(DataCoordinateIterable.fromScalar(visitDataId))
            record = next(iter(records), None)
            if record is not None:
                return record.region

        return None
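
# For orientation, a minimal usage sketch of the factory above; the data ID
# values here are hypothetical:
#
#     factory = _ExposureRegionFactory(dimensions)
#     dataId = DataCoordinate.standardize(
#         {"instrument": "HSC", "exposure": 903342, "detector": 10},
#         universe=universe,
#     )
#     region = factory.exposure_region(dataId)  # lsst.sphgeom Region or None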


class ObsCoreLiveTableManager(ObsCoreTableManager):
    """A manager class for the ObsCore table; implements methods for
    updating the records in that table.
    """

    def __init__(
        self,
        *,
        db: Database,
        table: sqlalchemy.schema.Table,
        schema: ObsCoreSchema,
        universe: DimensionUniverse,
        config: ObsCoreManagerConfig,
        dimensions: DimensionRecordStorageManager,
        spatial_plugins: Collection[SpatialObsCorePlugin],
    ):
        self.db = db
        self.table = table
        self.schema = schema
        self.universe = universe
        self.config = config
        self.spatial_plugins = spatial_plugins
        exposure_region_factory = _ExposureRegionFactory(dimensions)
        self.record_factory = RecordFactory(
            config, schema, universe, spatial_plugins, exposure_region_factory
        )
        self.tagged_collection: Optional[str] = None
        self.run_patterns: list[re.Pattern] = []
        if config.collection_type is ConfigCollectionType.TAGGED:
            assert (
                config.collections is not None and len(config.collections) == 1
            ), "Exactly one collection name required for tagged type."
            self.tagged_collection = config.collections[0]
        elif config.collection_type is ConfigCollectionType.RUN:
            if config.collections:
                for coll in config.collections:
                    try:
                        self.run_patterns.append(re.compile(coll))
                    except re.error as exc:
                        raise ValueError(f"Failed to compile regex: {coll!r}") from exc
        else:
            raise ValueError(f"Unexpected value of collection_type: {config.collection_type}")
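
    # For illustration only: with collection_type = RUN, datasets are accepted
    # when their run collection full-matches one of the configured patterns;
    # with collection_type = TAGGED, exactly one tagged collection is named.
    # A hypothetical configuration fragment (values made up):
    #
    #     collection_type: RUN
    #     collections: ["HSC/runs/.*", "DECam/runs/.*"]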

    @classmethod
    def initialize(
        cls,
        db: Database,
        context: StaticTablesContext,
        *,
        universe: DimensionUniverse,
        config: Mapping,
        datasets: Type[DatasetRecordStorageManager],
        dimensions: DimensionRecordStorageManager,
    ) -> ObsCoreTableManager:
        # Docstring inherited from base class.
        config_data = Config(config)
        obscore_config = ObsCoreManagerConfig.parse_obj(config_data)

        # Instantiate all spatial plugins.
        spatial_plugins = SpatialObsCorePlugin.load_plugins(obscore_config.spatial_plugins, db)

        schema = ObsCoreSchema(config=obscore_config, spatial_plugins=spatial_plugins, datasets=datasets)

        # Generate table specification for main obscore table.
        table_spec = schema.table_spec
        for plugin in spatial_plugins:
            plugin.extend_table_spec(table_spec)
        table = context.addTable(obscore_config.table_name, table_spec)

        return ObsCoreLiveTableManager(
            db=db,
            table=table,
            schema=schema,
            universe=universe,
            config=obscore_config,
            dimensions=dimensions,
            spatial_plugins=spatial_plugins,
        )

    def config_json(self) -> str:
        """Dump configuration in JSON format.

        Returns
        -------
        json : `str`
            Configuration serialized in JSON format.
        """
        return json.dumps(self.config.dict())
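
    # Round-trip sketch, with `manager` being a hypothetical instance of this
    # class (parse_obj is the same pydantic-style entry point used in
    # initialize() above):
    #
    #     data = json.loads(manager.config_json())
    #     config = ObsCoreManagerConfig.parse_obj(data)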

    @classmethod
    def currentVersion(cls) -> Optional[VersionTuple]:
        # Docstring inherited from base class.
        return _VERSION

    def schemaDigest(self) -> Optional[str]:
        # Docstring inherited from base class.
        return None

    def add_datasets(self, refs: Iterable[DatasetRef]) -> int:
        # Docstring inherited from base class.

        # Only makes sense for RUN collection types
        if self.config.collection_type is not ConfigCollectionType.RUN:
            return 0

        obscore_refs: Iterable[DatasetRef]
        if self.run_patterns:
            # Check each dataset run against the configured run list. We want
            # to reduce the number of calls to _check_dataset_run, which may
            # be expensive. Normally references are already grouped by run; if
            # there are multiple input references, they should all have the
            # same run. Rather than just checking that, we group them by run
            # again.
            refs_by_run: Dict[str, List[DatasetRef]] = defaultdict(list)
            for ref in refs:
                # The record factory will filter by dataset type, but we also
                # pre-filter here to reduce the number of collection checks.
                if ref.datasetType.name not in self.config.dataset_types:
                    continue

                assert ref.run is not None, "Run cannot be None"
                refs_by_run[ref.run].append(ref)

            good_refs: List[DatasetRef] = []
            for run, run_refs in refs_by_run.items():
                if not self._check_dataset_run(run):
                    continue
                good_refs.extend(run_refs)
            obscore_refs = good_refs

        else:
            # Take all refs, no collection check.
            obscore_refs = refs

        return self._populate(obscore_refs)

    def associate(self, refs: Iterable[DatasetRef], collection: CollectionRecord) -> int:
        # Docstring inherited from base class.

        # Only works when collection type is TAGGED
        if self.tagged_collection is None:
            return 0

        if collection.name == self.tagged_collection:
            return self._populate(refs)
        else:
            return 0

    def disassociate(self, refs: Iterable[DatasetRef], collection: CollectionRecord) -> int:
        # Docstring inherited from base class.

        # Only works when collection type is TAGGED
        if self.tagged_collection is None:
            return 0

        count = 0
        if collection.name == self.tagged_collection:
            # Sorting may improve performance
            dataset_ids = sorted(cast(uuid.UUID, ref.id) for ref in refs)
            if dataset_ids:
                fk_field = self.schema.dataset_fk
                assert fk_field is not None, "Cannot be None by construction"
                # There may be too many of them, do it in chunks.
                for ids in chunk_iterable(dataset_ids):
                    where = self.table.columns[fk_field.name].in_(ids)
                    count += self.db.deleteWhere(self.table, where)

        return count
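
    # For illustration: chunk_iterable splits an iterable into bounded chunks
    # so that the generated IN clause stays a manageable size. A sketch,
    # assuming the chunk_size keyword (check lsst.utils.iteration for the
    # actual signature and default):
    #
    #     >>> from lsst.utils.iteration import chunk_iterable
    #     >>> [len(c) for c in chunk_iterable(range(2500), chunk_size=1000)]
    #     [1000, 1000, 500]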

    def _populate(self, refs: Iterable[DatasetRef]) -> int:
        """Populate obscore table with the data from given datasets."""
        records: List[Record] = []
        for ref in refs:
            record = self.record_factory(ref)
            if record is not None:
                records.append(record)

        if records:
            # Ignore potential conflicts with existing datasets.
            return self.db.ensure(self.table, *records, primary_key_only=True)
        else:
            return 0
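
    # Note: Database.ensure is expected to skip rows whose primary key already
    # exists and to return the number of rows actually inserted; see its
    # docstring for the authoritative contract.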

    def _check_dataset_run(self, run: str) -> bool:
        """Check that the specified run collection matches known patterns."""

        if not self.run_patterns:
            # Empty list means take anything.
            return True

        # Try each pattern in turn.
        return any(pattern.fullmatch(run) for pattern in self.run_patterns)
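
    # fullmatch (rather than search or match) means a pattern must consume the
    # entire run name. For example, with a hypothetical pattern:
    #
    #     >>> import re
    #     >>> pattern = re.compile("HSC/runs/.*")
    #     >>> bool(pattern.fullmatch("HSC/runs/RC2/w_2023_19"))
    #     True
    #     >>> bool(pattern.fullmatch("u/someone/HSC/runs/RC2"))
    #     False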