Coverage for python/lsst/daf/butler/registry/obscore/_manager.py: 17%
166 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-03-31 02:41 -0700
« prev ^ index » next coverage.py v6.5.0, created at 2023-03-31 02:41 -0700
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ["ObsCoreLiveTableManager"]
26import json
27import re
28import uuid
29import warnings
30from collections import defaultdict
31from collections.abc import Collection, Iterable, Iterator, Mapping
32from contextlib import contextmanager
33from typing import TYPE_CHECKING, Any, Type, cast
35import sqlalchemy
36from lsst.daf.butler import Config, DataCoordinate, DatasetRef, DimensionRecordColumnTag, DimensionUniverse
37from lsst.daf.relation import Join
38from lsst.sphgeom import Region
39from lsst.utils.iteration import chunk_iterable
41from ..interfaces import ObsCoreTableManager, VersionTuple
42from ._config import ConfigCollectionType, ObsCoreManagerConfig
43from ._records import ExposureRegionFactory, Record, RecordFactory
44from ._schema import ObsCoreSchema
45from ._spatial import RegionTypeError, RegionTypeWarning, SpatialObsCorePlugin
47if TYPE_CHECKING:
48 from ..interfaces import (
49 CollectionRecord,
50 Database,
51 DatasetRecordStorageManager,
52 DimensionRecordStorageManager,
53 StaticTablesContext,
54 )
55 from ..queries import SqlQueryContext
# Version tuple reported by ``ObsCoreLiveTableManager.currentVersions()``.
57_VERSION = VersionTuple(0, 0, 1)
class _ExposureRegionFactory(ExposureRegionFactory):
    """Look up the region of an exposure via its matching visit records.

    Parameters
    ----------
    dimensions : `DimensionRecordStorageManager`
        Manager that provides the dimension universe and the storage objects
        for the dimension elements joined by `exposure_region`.
    """

    def __init__(self, dimensions: DimensionRecordStorageManager):
        universe = dimensions.universe
        self.dimensions = dimensions
        self.universe = universe
        # Dimension sets used to constrain the query, with and without
        # detector.
        self.exposure_dimensions = universe["exposure"].graph
        self.exposure_detector_dimensions = universe.extract(["exposure", "detector"])

    def exposure_region(self, dataId: DataCoordinate, context: SqlQueryContext) -> Region | None:
        # Docstring is inherited from a base class.

        # Start from visit_definition, which maps exposures to visits.
        relation = context.make_initial_relation()
        definition_storage = self.dimensions.get(self.universe["visit_definition"])
        if definition_storage is None:
            return None
        relation = definition_storage.join(relation, Join(), context)

        # Regions come from visit_detector_region when the data ID has a
        # detector, otherwise from visit; only the element name and the set
        # of constrained dimensions differ between the two cases.
        if "detector" in dataId.names:
            region_element = "visit_detector_region"
            constraint_dimensions = self.exposure_detector_dimensions
        else:
            region_element = "visit"
            constraint_dimensions = self.exposure_dimensions

        region_storage = self.dimensions.get(self.universe[region_element])
        if region_storage is None:
            return None
        relation = region_storage.join(relation, Join(), context)
        constraint_data_id = dataId.subset(constraint_dimensions)
        region_tag = DimensionRecordColumnTag(region_element, "region")

        # Restrict to the requested exposure (and detector, if present).
        predicate = context.make_data_coordinate_predicate(constraint_data_id, full=False)
        relation = relation.with_rows_satisfying(predicate)

        # An exposure may belong to several visits; any one of them will do.
        relation = relation[:1]

        # Execute and return the region of the first row, if there is one.
        first_row = next(iter(context.fetch_iterable(relation)), None)
        if first_row is None:
            return None
        return first_row[region_tag]
class ObsCoreLiveTableManager(ObsCoreTableManager):
    """A manager class for ObsCore table, implements methods for updating the
    records in that table.

    Parameters
    ----------
    db : `Database`
        Database interface used for all reads and writes of the obscore
        table.
    table : `sqlalchemy.schema.Table`
        The obscore table.
    schema : `ObsCoreSchema`
        Schema of the obscore table.
    universe : `DimensionUniverse`
        Dimension universe, passed on to the record factory.
    config : `ObsCoreManagerConfig`
        Obscore configuration.
    dimensions : `DimensionRecordStorageManager`
        Dimension record manager, used to look up exposure regions.
    spatial_plugins : `~collections.abc.Collection` [`SpatialObsCorePlugin`]
        Spatial plugins, passed on to the record factory.
    registry_schema_version : `VersionTuple` or `None`, optional
        Forwarded unchanged to the base class constructor.

    Raises
    ------
    ValueError
        Raised if a configured run collection pattern is not a valid regular
        expression, or if ``config.collection_type`` is neither ``TAGGED``
        nor ``RUN``.
    """

    def __init__(
        self,
        *,
        db: Database,
        table: sqlalchemy.schema.Table,
        schema: ObsCoreSchema,
        universe: DimensionUniverse,
        config: ObsCoreManagerConfig,
        dimensions: DimensionRecordStorageManager,
        spatial_plugins: Collection[SpatialObsCorePlugin],
        registry_schema_version: VersionTuple | None = None,
    ):
        super().__init__(registry_schema_version=registry_schema_version)
        self.db = db
        self.table = table
        self.schema = schema
        self.universe = universe
        self.config = config
        self.spatial_plugins = spatial_plugins
        exposure_region_factory = _ExposureRegionFactory(dimensions)
        self.record_factory = RecordFactory(
            config, schema, universe, spatial_plugins, exposure_region_factory
        )
        # Exactly one of the two attributes below is meaningful, depending on
        # the configured collection type: a single collection name for
        # TAGGED, or a (possibly empty) list of compiled patterns for RUN.
        self.tagged_collection: str | None = None
        self.run_patterns: list[re.Pattern] = []
        if config.collection_type is ConfigCollectionType.TAGGED:
            assert (
                config.collections is not None and len(config.collections) == 1
            ), "Exactly one collection name required for tagged type."
            self.tagged_collection = config.collections[0]
        elif config.collection_type is ConfigCollectionType.RUN:
            if config.collections:
                for coll in config.collections:
                    try:
                        self.run_patterns.append(re.compile(coll))
                    except re.error as exc:
                        raise ValueError(f"Failed to compile regex: {coll!r}") from exc
        else:
            raise ValueError(f"Unexpected value of collection_type: {config.collection_type}")

    @classmethod
    def initialize(
        cls,
        db: Database,
        context: StaticTablesContext,
        *,
        universe: DimensionUniverse,
        config: Mapping,
        datasets: Type[DatasetRecordStorageManager],
        dimensions: DimensionRecordStorageManager,
        registry_schema_version: VersionTuple | None = None,
    ) -> ObsCoreTableManager:
        # Docstring inherited from base class.
        config_data = Config(config)
        obscore_config = ObsCoreManagerConfig.parse_obj(config_data)

        # Instantiate all spatial plugins.
        spatial_plugins = SpatialObsCorePlugin.load_plugins(obscore_config.spatial_plugins, db)

        schema = ObsCoreSchema(config=obscore_config, spatial_plugins=spatial_plugins, datasets=datasets)

        # Generate table specification for main obscore table and let every
        # plugin add to it.
        table_spec = schema.table_spec
        for plugin in spatial_plugins:
            plugin.extend_table_spec(table_spec)
        # Register the very object the plugins extended, so that their
        # additions cannot be lost by re-reading ``schema.table_spec``.
        table = context.addTable(obscore_config.table_name, table_spec)

        return ObsCoreLiveTableManager(
            db=db,
            table=table,
            schema=schema,
            universe=universe,
            config=obscore_config,
            dimensions=dimensions,
            spatial_plugins=spatial_plugins,
            registry_schema_version=registry_schema_version,
        )

    def config_json(self) -> str:
        """Dump configuration in JSON format.

        Returns
        -------
        json : `str`
            Configuration serialized in JSON format.
        """
        return json.dumps(self.config.dict())

    @classmethod
    def currentVersions(cls) -> list[VersionTuple]:
        # Docstring inherited from base class.
        return [_VERSION]

    def add_datasets(self, refs: Iterable[DatasetRef], context: SqlQueryContext) -> int:
        # Docstring inherited from base class.

        # Only makes sense for RUN collection types
        if self.config.collection_type is not ConfigCollectionType.RUN:
            return 0

        obscore_refs: Iterable[DatasetRef]
        if self.run_patterns:
            # Check each dataset run against configured run list. We want to
            # reduce number of calls to _check_dataset_run, which may be
            # expensive. Normally references are grouped by run, if there are
            # multiple input references, they should have the same run.
            # Instead of just checking that, we group them by run again.
            refs_by_run: dict[str, list[DatasetRef]] = defaultdict(list)
            for ref in refs:
                # Record factory will filter dataset types, but to reduce
                # collection checks we also pre-filter it here.
                if ref.datasetType.name not in self.config.dataset_types:
                    continue

                assert ref.run is not None, "Run cannot be None"
                refs_by_run[ref.run].append(ref)

            good_refs: list[DatasetRef] = []
            for run, run_refs in refs_by_run.items():
                if self._check_dataset_run(run):
                    good_refs.extend(run_refs)
            obscore_refs = good_refs

        else:
            # Take all refs, no collection check.
            obscore_refs = refs

        return self._populate(obscore_refs, context)

    def associate(
        self, refs: Iterable[DatasetRef], collection: CollectionRecord, context: SqlQueryContext
    ) -> int:
        # Docstring inherited from base class.

        # Only works when collection type is TAGGED, and only for the one
        # configured collection.
        if self.tagged_collection is None or collection.name != self.tagged_collection:
            return 0
        return self._populate(refs, context)

    def disassociate(self, refs: Iterable[DatasetRef], collection: CollectionRecord) -> int:
        # Docstring inherited from base class.

        # Only works when collection type is TAGGED, and only for the one
        # configured collection.
        if self.tagged_collection is None or collection.name != self.tagged_collection:
            return 0

        # Sorting may improve performance
        dataset_ids = sorted(cast(uuid.UUID, ref.id) for ref in refs)
        if not dataset_ids:
            return 0

        fk_field = self.schema.dataset_fk
        assert fk_field is not None, "Cannot be None by construction"
        fk_column = self.table.columns[fk_field.name]

        # There may be too many of them, do it in chunks.
        count = 0
        for ids in chunk_iterable(dataset_ids):
            count += self.db.deleteWhere(self.table, fk_column.in_(ids))
        return count

    def _populate(self, refs: Iterable[DatasetRef], context: SqlQueryContext) -> int:
        """Populate obscore table with the data from given datasets.

        Parameters
        ----------
        refs : `~collections.abc.Iterable` [`DatasetRef`]
            Datasets to make records for; those for which the record factory
            returns `None` are skipped.
        context : `SqlQueryContext`
            Query context passed on to the record factory.

        Returns
        -------
        count : `int`
            Value returned by ``Database.ensure`` for the generated records,
            or 0 if no records were generated.
        """
        records: list[Record] = []
        for ref in refs:
            record = self.record_factory(ref, context)
            if record is not None:
                records.append(record)

        if not records:
            return 0
        # Ignore potential conflicts with existing datasets.
        return self.db.ensure(self.table, *records, primary_key_only=True)

    def _check_dataset_run(self, run: str) -> bool:
        """Check that specified run collection matches known patterns.

        Parameters
        ----------
        run : `str`
            Name of the run collection.

        Returns
        -------
        matches : `bool`
            `True` if no patterns are configured, or if the whole name
            matches at least one of the configured patterns.
        """
        if not self.run_patterns:
            # Empty list means take anything.
            return True

        # Try each pattern in turn.
        return any(pattern.fullmatch(run) for pattern in self.run_patterns)

    def update_exposure_regions(self, instrument: str, region_data: Iterable[tuple[int, int, Region]]) -> int:
        # Docstring inherited from base class.
        instrument_column = self.schema.dimension_column("instrument")
        exposure_column = self.schema.dimension_column("exposure")
        detector_column = self.schema.dimension_column("detector")
        if instrument_column is None or exposure_column is None or detector_column is None:
            # Not all needed columns are in the table.
            return 0

        update_rows: list[Record] = []
        for exposure, detector, region in region_data:
            try:
                record = self.record_factory.make_spatial_records(region)
            except RegionTypeError as exc:
                # A region that cannot be converted is reported and skipped,
                # the remaining regions are still processed.
                warnings.warn(
                    f"Failed to convert region for exposure={exposure} detector={detector}: {exc}",
                    category=RegionTypeWarning,
                )
                continue

            # Values identifying the row go into the record under placeholder
            # keys; ``where_dict`` below maps actual column names to those
            # keys.
            record.update(
                {
                    "instrument_column": instrument,
                    "exposure_column": exposure,
                    "detector_column": detector,
                }
            )
            update_rows.append(record)

        where_dict: dict[str, str] = {
            instrument_column: "instrument_column",
            exposure_column: "exposure_column",
            detector_column: "detector_column",
        }

        count = self.db.update(self.table, where_dict, *update_rows)
        return count

    @contextmanager
    def query(self, **kwargs: Any) -> Iterator[sqlalchemy.engine.CursorResult]:
        """Run a SELECT query against obscore table and return result rows.

        Parameters
        ----------
        **kwargs
            Restriction on values of individual obscore columns. Key is the
            column name, value is the required value of the column. Multiple
            restrictions are ANDed together.

        Yields
        ------
        result : `sqlalchemy.engine.CursorResult`
            Query result, usable only inside the ``with`` block.
        """
        query = self.table.select()
        if kwargs:
            query = query.where(
                sqlalchemy.sql.expression.and_(
                    *[self.table.columns[column] == value for column, value in kwargs.items()]
                )
            )
        with self.db.query(query) as result:
            yield result