Coverage for python/lsst/daf/butler/registry/obscore/_manager.py: 17%
168 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-03-28 04:40 -0700
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ["ObsCoreLiveTableManager"]
26import json
27import re
28import uuid
29import warnings
30from collections import defaultdict
31from collections.abc import Collection, Iterable, Iterator, Mapping
32from contextlib import contextmanager
33from typing import TYPE_CHECKING, Any, Type, cast
35import sqlalchemy
36from lsst.daf.butler import Config, DataCoordinate, DatasetRef, DimensionRecordColumnTag, DimensionUniverse
37from lsst.daf.relation import Join
38from lsst.sphgeom import Region
39from lsst.utils.iteration import chunk_iterable
41from ..interfaces import ObsCoreTableManager, VersionTuple
42from ._config import ConfigCollectionType, ObsCoreManagerConfig
43from ._records import ExposureRegionFactory, Record, RecordFactory
44from ._schema import ObsCoreSchema
45from ._spatial import RegionTypeError, RegionTypeWarning, SpatialObsCorePlugin
47if TYPE_CHECKING:
48 from ..interfaces import (
49 CollectionRecord,
50 Database,
51 DatasetRecordStorageManager,
52 DimensionRecordStorageManager,
53 StaticTablesContext,
54 )
55 from ..queries import SqlQueryContext
# Current schema version of the live obscore table managed by this module.
_VERSION = VersionTuple(0, 0, 1)
class _ExposureRegionFactory(ExposureRegionFactory):
    """Find exposure region from a matching visit dimensions records."""

    def __init__(self, dimensions: DimensionRecordStorageManager):
        self.dimensions = dimensions
        self.universe = dimensions.universe
        # Dimension subsets used below to constrain the region lookup.
        self.exposure_dimensions = self.universe["exposure"].graph
        self.exposure_detector_dimensions = self.universe.extract(["exposure", "detector"])

    def exposure_region(self, dataId: DataCoordinate, context: SqlQueryContext) -> Region | None:
        # Docstring is inherited from a base class.
        # Begin with the visit_definition element, which provides the mapping
        # between an exposure and the visit(s) it belongs to.
        rel = context.make_initial_relation()
        visit_definition = self.dimensions.get(self.universe["visit_definition"])
        if visit_definition is None:
            return None
        rel = visit_definition.join(rel, Join(), context)
        # Join in the table that actually carries the region: per-detector
        # regions when the data ID includes a detector, per-visit otherwise.
        if "detector" in dataId.names:
            region_storage = self.dimensions.get(self.universe["visit_detector_region"])
            if region_storage is None:
                return None
            rel = region_storage.join(rel, Join(), context)
            constraint = dataId.subset(self.exposure_detector_dimensions)
            region_tag = DimensionRecordColumnTag("visit_detector_region", "region")
        else:
            region_storage = self.dimensions.get(self.universe["visit"])
            if region_storage is None:
                return None
            rel = region_storage.join(rel, Join(), context)
            constraint = dataId.subset(self.exposure_dimensions)
            region_tag = DimensionRecordColumnTag("visit", "region")
        # Restrict rows to the requested exposure (and detector, if present).
        rel = rel.with_rows_satisfying(
            context.make_data_coordinate_predicate(constraint, full=False)
        )
        # An exposure may belong to more than one visit; any single match is
        # acceptable, so keep just one row.
        rel = rel[:1]
        first_row = next(iter(context.fetch_iterable(rel)), None)
        return None if first_row is None else first_row[region_tag]
class ObsCoreLiveTableManager(ObsCoreTableManager):
    """A manager class for ObsCore table, implements methods for updating the
    records in that table.

    Parameters
    ----------
    db : `Database`
        Interface to the underlying database engine and namespace.
    table : `sqlalchemy.schema.Table`
        The obscore table managed by this instance.
    schema : `ObsCoreSchema`
        Description of the obscore table schema.
    universe : `DimensionUniverse`
        All known dimensions.
    config : `ObsCoreManagerConfig`
        Validated obscore manager configuration.
    dimensions : `DimensionRecordStorageManager`
        Manager for dimension records, used to look up exposure regions.
    spatial_plugins : `~collections.abc.Collection` of `SpatialObsCorePlugin`
        Plugins that add spatial columns/records to the table.

    Raises
    ------
    ValueError
        Raised if a configured run collection regex does not compile, or if
        ``config.collection_type`` has an unexpected value.
    """

    def __init__(
        self,
        *,
        db: Database,
        table: sqlalchemy.schema.Table,
        schema: ObsCoreSchema,
        universe: DimensionUniverse,
        config: ObsCoreManagerConfig,
        dimensions: DimensionRecordStorageManager,
        spatial_plugins: Collection[SpatialObsCorePlugin],
    ):
        self.db = db
        self.table = table
        self.schema = schema
        self.universe = universe
        self.config = config
        self.spatial_plugins = spatial_plugins
        exposure_region_factory = _ExposureRegionFactory(dimensions)
        self.record_factory = RecordFactory(
            config, schema, universe, spatial_plugins, exposure_region_factory
        )
        # Exactly one of the two attributes below is populated, depending on
        # the configured collection type.
        self.tagged_collection: str | None = None
        self.run_patterns: list[re.Pattern] = []
        if config.collection_type is ConfigCollectionType.TAGGED:
            assert (
                config.collections is not None and len(config.collections) == 1
            ), "Exactly one collection name required for tagged type."
            self.tagged_collection = config.collections[0]
        elif config.collection_type is ConfigCollectionType.RUN:
            if config.collections:
                for coll in config.collections:
                    try:
                        self.run_patterns.append(re.compile(coll))
                    except re.error as exc:
                        raise ValueError(f"Failed to compile regex: {coll!r}") from exc
        else:
            raise ValueError(f"Unexpected value of collection_type: {config.collection_type}")

    @classmethod
    def initialize(
        cls,
        db: Database,
        context: StaticTablesContext,
        *,
        universe: DimensionUniverse,
        config: Mapping,
        datasets: Type[DatasetRecordStorageManager],
        dimensions: DimensionRecordStorageManager,
    ) -> ObsCoreTableManager:
        # Docstring inherited from base class.
        config_data = Config(config)
        obscore_config = ObsCoreManagerConfig.parse_obj(config_data)

        # Instantiate all spatial plugins.
        spatial_plugins = SpatialObsCorePlugin.load_plugins(obscore_config.spatial_plugins, db)

        schema = ObsCoreSchema(config=obscore_config, spatial_plugins=spatial_plugins, datasets=datasets)

        # Generate table specification for main obscore table.
        table_spec = schema.table_spec
        for plugin in spatial_plugins:
            plugin.extend_table_spec(table_spec)
        # Pass the extended local spec rather than re-reading
        # ``schema.table_spec``: the plugin-added columns are only guaranteed
        # to be present in ``table_spec`` itself.
        table = context.addTable(obscore_config.table_name, table_spec)

        return ObsCoreLiveTableManager(
            db=db,
            table=table,
            schema=schema,
            universe=universe,
            config=obscore_config,
            dimensions=dimensions,
            spatial_plugins=spatial_plugins,
        )

    def config_json(self) -> str:
        """Dump configuration in JSON format.

        Returns
        -------
        json : `str`
            Configuration serialized in JSON format.
        """
        return json.dumps(self.config.dict())

    @classmethod
    def currentVersion(cls) -> VersionTuple | None:
        # Docstring inherited from base class.
        return _VERSION

    def add_datasets(self, refs: Iterable[DatasetRef], context: SqlQueryContext) -> int:
        # Docstring inherited from base class.

        # Only makes sense for RUN collection types
        if self.config.collection_type is not ConfigCollectionType.RUN:
            return 0

        obscore_refs: Iterable[DatasetRef]
        if self.run_patterns:
            # Check each dataset run against configured run list. We want to
            # reduce number of calls to _check_dataset_run, which may be
            # expensive. Normally references are grouped by run, if there are
            # multiple input references, they should have the same run.
            # Instead of just checking that, we group them by run again.
            refs_by_run: dict[str, list[DatasetRef]] = defaultdict(list)
            for ref in refs:
                # Record factory will filter dataset types, but to reduce
                # collection checks we also pre-filter it here.
                if ref.datasetType.name not in self.config.dataset_types:
                    continue

                assert ref.run is not None, "Run cannot be None"
                refs_by_run[ref.run].append(ref)

            good_refs: list[DatasetRef] = []
            for run, run_refs in refs_by_run.items():
                if not self._check_dataset_run(run):
                    continue
                good_refs.extend(run_refs)
            obscore_refs = good_refs

        else:
            # Take all refs, no collection check.
            obscore_refs = refs

        return self._populate(obscore_refs, context)

    def associate(
        self, refs: Iterable[DatasetRef], collection: CollectionRecord, context: SqlQueryContext
    ) -> int:
        # Docstring inherited from base class.

        # Only works when collection type is TAGGED
        if self.tagged_collection is None:
            return 0

        if collection.name == self.tagged_collection:
            return self._populate(refs, context)
        else:
            return 0

    def disassociate(self, refs: Iterable[DatasetRef], collection: CollectionRecord) -> int:
        # Docstring inherited from base class.

        # Only works when collection type is TAGGED
        if self.tagged_collection is None:
            return 0

        count = 0
        if collection.name == self.tagged_collection:
            # Sorting may improve performance
            dataset_ids = sorted(cast(uuid.UUID, ref.id) for ref in refs)
            if dataset_ids:
                fk_field = self.schema.dataset_fk
                assert fk_field is not None, "Cannot be None by construction"
                # There may be too many of them, do it in chunks.
                for ids in chunk_iterable(dataset_ids):
                    where = self.table.columns[fk_field.name].in_(ids)
                    count += self.db.deleteWhere(self.table, where)
        return count

    def _populate(self, refs: Iterable[DatasetRef], context: SqlQueryContext) -> int:
        """Populate obscore table with the data from given datasets.

        Returns the number of rows actually inserted; datasets for which the
        record factory returns `None` are skipped.
        """
        records: list[Record] = []
        for ref in refs:
            record = self.record_factory(ref, context)
            if record is not None:
                records.append(record)

        if not records:
            return 0
        # Ignore potential conflicts with existing datasets.
        return self.db.ensure(self.table, *records, primary_key_only=True)

    def _check_dataset_run(self, run: str) -> bool:
        """Check that specified run collection matches know patterns."""
        if not self.run_patterns:
            # Empty list means take anything.
            return True

        # Try each pattern in turn.
        return any(pattern.fullmatch(run) for pattern in self.run_patterns)

    def update_exposure_regions(self, instrument: str, region_data: Iterable[tuple[int, int, Region]]) -> int:
        # Docstring inherited from base class.
        instrument_column = self.schema.dimension_column("instrument")
        exposure_column = self.schema.dimension_column("exposure")
        detector_column = self.schema.dimension_column("detector")
        if instrument_column is None or exposure_column is None or detector_column is None:
            # Not all needed columns are in the table.
            return 0

        update_rows: list[Record] = []
        for exposure, detector, region in region_data:
            try:
                record = self.record_factory.make_spatial_records(region)
            except RegionTypeError as exc:
                warnings.warn(
                    f"Failed to convert region for exposure={exposure} detector={detector}: {exc}",
                    category=RegionTypeWarning,
                )
                continue

            record.update(
                {
                    "instrument_column": instrument,
                    "exposure_column": exposure,
                    "detector_column": detector,
                }
            )
            update_rows.append(record)

        if not update_rows:
            # Nothing to update, e.g. every region failed to convert; avoid
            # issuing an UPDATE with no rows.
            return 0

        # Keys of this dict are column names to constrain on, values are the
        # names of the corresponding placeholder keys in each update row.
        where_dict: dict[str, str] = {
            instrument_column: "instrument_column",
            exposure_column: "exposure_column",
            detector_column: "detector_column",
        }

        return self.db.update(self.table, where_dict, *update_rows)

    @contextmanager
    def query(self, **kwargs: Any) -> Iterator[sqlalchemy.engine.CursorResult]:
        """Run a SELECT query against obscore table and return result rows.

        Parameters
        ----------
        **kwargs
            Restriction on values of individual obscore columns. Key is the
            column name, value is the required value of the column. Multiple
            restrictions are ANDed together.
        """
        query = self.table.select()
        if kwargs:
            query = query.where(
                sqlalchemy.sql.expression.and_(
                    *[self.table.columns[column] == value for column, value in kwargs.items()]
                )
            )
        with self.db.query(query) as result:
            yield result