Coverage for python/lsst/daf/butler/registry/obscore/_manager.py: 16%
173 statements
« prev ^ index » next coverage.py v7.2.5, created at 2023-05-05 03:17 -0700
« prev ^ index » next coverage.py v7.2.5, created at 2023-05-05 03:17 -0700
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ["ObsCoreLiveTableManager"]
26import json
27import re
28import uuid
29import warnings
30from collections import defaultdict
31from collections.abc import Collection, Iterable, Iterator, Mapping
32from contextlib import contextmanager
33from typing import TYPE_CHECKING, Any, Type, cast
35import sqlalchemy
36from lsst.daf.butler import Config, DataCoordinate, DatasetRef, DimensionRecordColumnTag, DimensionUniverse
37from lsst.daf.relation import Join
38from lsst.sphgeom import Region
39from lsst.utils.iteration import chunk_iterable
41from ..interfaces import ObsCoreTableManager, VersionTuple
42from ._config import ConfigCollectionType, ObsCoreManagerConfig
43from ._records import ExposureRegionFactory, Record, RecordFactory
44from ._schema import ObsCoreSchema
45from ._spatial import RegionTypeError, RegionTypeWarning, SpatialObsCorePlugin
47if TYPE_CHECKING:
48 from ..interfaces import (
49 CollectionRecord,
50 Database,
51 DatasetRecordStorageManager,
52 DimensionRecordStorageManager,
53 StaticTablesContext,
54 )
55 from ..queries import SqlQueryContext
# Version of the obscore table schema managed by this module; this is the
# single version reported by `ObsCoreLiveTableManager.currentVersions`.
_VERSION = VersionTuple(0, 0, 1)
class _ExposureRegionFactory(ExposureRegionFactory):
    """Exposure region lookup backed by visit dimension records.

    The region of an exposure is taken from a visit that the exposure belongs
    to (found through ``visit_definition``).

    - If the data ID includes a detector, the ``visit_detector_region``
      region is used.
    - Otherwise the ``visit`` region is used.

    Parameters
    ----------
    dimensions : `DimensionRecordStorageManager`
        Manager providing record storage for the dimension elements queried
        here.
    """

    def __init__(self, dimensions: DimensionRecordStorageManager):
        universe = dimensions.universe
        self.dimensions = dimensions
        self.universe = universe
        # Dimension sets used to reduce an incoming data ID to the exposure
        # (and, optionally, detector) keys that constrain the query.
        self.exposure_dimensions = universe["exposure"].graph
        self.exposure_detector_dimensions = universe.extract(["exposure", "detector"])

    def exposure_region(self, dataId: DataCoordinate, context: SqlQueryContext) -> Region | None:
        # Docstring is inherited from a base class.
        # Seed the query with visit_definition, which links exposures to
        # visits; if that element has no storage there is nothing to find.
        relation = context.make_initial_relation()
        definition_storage = self.dimensions.get(self.universe["visit_definition"])
        if definition_storage is None:
            return None
        relation = definition_storage.join(relation, Join(), context)

        # Choose where the region comes from: per-detector visit regions when
        # the data ID names a detector, whole-visit regions otherwise.
        if "detector" in dataId.names:
            element_name = "visit_detector_region"
            constraint_dimensions = self.exposure_detector_dimensions
        else:
            element_name = "visit"
            constraint_dimensions = self.exposure_dimensions
        region_storage = self.dimensions.get(self.universe[element_name])
        if region_storage is None:
            return None
        relation = region_storage.join(relation, Join(), context)
        region_tag = DimensionRecordColumnTag(element_name, "region")

        # Keep only rows for this exposure (and detector, when present).
        predicate = context.make_data_coordinate_predicate(
            dataId.subset(constraint_dimensions), full=False
        )
        relation = relation.with_rows_satisfying(predicate)
        # An exposure can belong to more than one visit; any one will do.
        relation = relation[:1]

        # Region from the first row, or None when the query is empty.
        rows = context.fetch_iterable(relation)
        return next((row[region_tag] for row in rows), None)
class ObsCoreLiveTableManager(ObsCoreTableManager):
    """A manager class for ObsCore table, implements methods for updating the
    records in that table.

    Parameters
    ----------
    db : `Database`
        Database used to insert, update, delete and query obscore records.
    table : `sqlalchemy.schema.Table`
        The obscore table managed by this instance.
    schema : `ObsCoreSchema`
        Description of the obscore table schema.
    universe : `DimensionUniverse`
        Dimension universe of the registry.
    config : `ObsCoreManagerConfig`
        Manager configuration. Its ``collection_type`` and ``collections``
        decide which datasets end up in the table.
    dimensions : `DimensionRecordStorageManager`
        Dimension record storage, used to look up exposure regions from
        visit records.
    spatial_plugins : `~collections.abc.Collection` [ `SpatialObsCorePlugin` ]
        Plugins that contribute spatial columns and records.
    registry_schema_version : `VersionTuple`, optional
        Registry schema version, forwarded to the base class.

    Raises
    ------
    ValueError
        Raised in any of these cases:

        - ``config.collection_type`` is TAGGED and ``config.collections``
          does not contain exactly one name;
        - a RUN collection pattern is not a valid regular expression;
        - ``config.collection_type`` has an unexpected value.
    """

    def __init__(
        self,
        *,
        db: Database,
        table: sqlalchemy.schema.Table,
        schema: ObsCoreSchema,
        universe: DimensionUniverse,
        config: ObsCoreManagerConfig,
        dimensions: DimensionRecordStorageManager,
        spatial_plugins: Collection[SpatialObsCorePlugin],
        registry_schema_version: VersionTuple | None = None,
    ):
        super().__init__(registry_schema_version=registry_schema_version)
        self.db = db
        self.table = table
        self.schema = schema
        self.universe = universe
        self.config = config
        self.spatial_plugins = spatial_plugins
        exposure_region_factory = _ExposureRegionFactory(dimensions)
        self.record_factory = RecordFactory(
            config, schema, universe, spatial_plugins, exposure_region_factory
        )
        self.tagged_collection: str | None = None
        self.run_patterns: list[re.Pattern] = []
        if config.collection_type is ConfigCollectionType.TAGGED:
            # Validate explicitly rather than with ``assert``. An assert is
            # stripped under ``python -O``, and a bad configuration would then
            # fail later with an unrelated TypeError/IndexError.
            if config.collections is None or len(config.collections) != 1:
                raise ValueError("Exactly one collection name required for tagged type.")
            self.tagged_collection = config.collections[0]
        elif config.collection_type is ConfigCollectionType.RUN:
            if config.collections:
                for coll in config.collections:
                    try:
                        self.run_patterns.append(re.compile(coll))
                    except re.error as exc:
                        raise ValueError(f"Failed to compile regex: {coll!r}") from exc
        else:
            raise ValueError(f"Unexpected value of collection_type: {config.collection_type}")

    @classmethod
    def initialize(
        cls,
        db: Database,
        context: StaticTablesContext,
        *,
        universe: DimensionUniverse,
        config: Mapping,
        datasets: Type[DatasetRecordStorageManager],
        dimensions: DimensionRecordStorageManager,
        registry_schema_version: VersionTuple | None = None,
    ) -> ObsCoreTableManager:
        # Docstring inherited from base class.
        config_data = Config(config)
        obscore_config = ObsCoreManagerConfig.parse_obj(config_data)

        # Instantiate all spatial plugins.
        spatial_plugins = SpatialObsCorePlugin.load_plugins(obscore_config.spatial_plugins, db)

        schema = ObsCoreSchema(config=obscore_config, spatial_plugins=spatial_plugins, datasets=datasets)

        # Generate table specification for main obscore table and let each
        # spatial plugin add its columns. The same (extended) spec object is
        # passed to addTable, so plugin extensions cannot be lost even if
        # ``schema.table_spec`` does not always return the same object.
        table_spec = schema.table_spec
        for plugin in spatial_plugins:
            plugin.extend_table_spec(table_spec)
        table = context.addTable(obscore_config.table_name, table_spec)

        return ObsCoreLiveTableManager(
            db=db,
            table=table,
            schema=schema,
            universe=universe,
            config=obscore_config,
            dimensions=dimensions,
            spatial_plugins=spatial_plugins,
            registry_schema_version=registry_schema_version,
        )

    def config_json(self) -> str:
        """Dump configuration in JSON format.

        Returns
        -------
        json : `str`
            Configuration serialized in JSON format.
        """
        return json.dumps(self.config.dict())

    @classmethod
    def currentVersions(cls) -> list[VersionTuple]:
        # Docstring inherited from base class.
        return [_VERSION]

    def add_datasets(self, refs: Iterable[DatasetRef], context: SqlQueryContext) -> int:
        # Docstring inherited from base class.

        # Only makes sense for RUN collection types
        if self.config.collection_type is not ConfigCollectionType.RUN:
            return 0

        obscore_refs: Iterable[DatasetRef]
        if self.run_patterns:
            # Check each dataset run against configured run list. We want to
            # reduce number of calls to _check_dataset_run, which may be
            # expensive. Normally references are grouped by run, if there are
            # multiple input references, they should have the same run.
            # Instead of just checking that, we group them by run again.
            refs_by_run: dict[str, list[DatasetRef]] = defaultdict(list)
            for ref in refs:
                # Record factory will filter dataset types, but to reduce
                # collection checks we also pre-filter it here.
                if ref.datasetType.name not in self.config.dataset_types:
                    continue

                assert ref.run is not None, "Run cannot be None"
                refs_by_run[ref.run].append(ref)

            good_refs: list[DatasetRef] = []
            for run, run_refs in refs_by_run.items():
                if not self._check_dataset_run(run):
                    continue
                good_refs.extend(run_refs)
            obscore_refs = good_refs

        else:
            # Take all refs, no collection check.
            obscore_refs = refs

        return self._populate(obscore_refs, context)

    def associate(
        self, refs: Iterable[DatasetRef], collection: CollectionRecord, context: SqlQueryContext
    ) -> int:
        # Docstring inherited from base class.

        # Only works when collection type is TAGGED
        if self.tagged_collection is None:
            return 0

        if collection.name == self.tagged_collection:
            return self._populate(refs, context)
        else:
            return 0

    def disassociate(self, refs: Iterable[DatasetRef], collection: CollectionRecord) -> int:
        # Docstring inherited from base class.

        # Only works when collection type is TAGGED
        if self.tagged_collection is None:
            return 0

        count = 0
        if collection.name == self.tagged_collection:
            # Sorting may improve performance
            dataset_ids = sorted(cast(uuid.UUID, ref.id) for ref in refs)
            if dataset_ids:
                fk_field = self.schema.dataset_fk
                assert fk_field is not None, "Cannot be None by construction"
                # There may be too many of them, do it in chunks so that each
                # DELETE uses a bounded IN (...) list.
                for ids in chunk_iterable(dataset_ids):
                    where = self.table.columns[fk_field.name].in_(ids)
                    count += self.db.deleteWhere(self.table, where)
        return count

    def _populate(self, refs: Iterable[DatasetRef], context: SqlQueryContext) -> int:
        """Populate obscore table with the data from given datasets.

        Parameters
        ----------
        refs : `~collections.abc.Iterable` [ `DatasetRef` ]
            Datasets to convert into obscore records. Refs for which the
            record factory returns `None` are skipped.
        context : `SqlQueryContext`
            Query context passed to the record factory.

        Returns
        -------
        count : `int`
            Value returned by ``Database.ensure``, or 0 if no records were
            produced.
        """
        records: list[Record] = []
        for ref in refs:
            record = self.record_factory(ref, context)
            if record is not None:
                records.append(record)

        if records:
            # Ignore potential conflicts with existing datasets.
            return self.db.ensure(self.table, *records, primary_key_only=True)
        else:
            return 0

    def _check_dataset_run(self, run: str) -> bool:
        """Check that specified run collection matches known patterns.

        Parameters
        ----------
        run : `str`
            Name of the RUN collection.

        Returns
        -------
        matches : `bool`
            `True` if no patterns are configured or if any pattern matches
            the whole ``run`` name.
        """

        if not self.run_patterns:
            # Empty list means take anything.
            return True

        # Try each pattern in turn.
        return any(pattern.fullmatch(run) for pattern in self.run_patterns)

    def update_exposure_regions(self, instrument: str, region_data: Iterable[tuple[int, int, Region]]) -> int:
        # Docstring inherited from base class.
        instrument_column = self.schema.dimension_column("instrument")
        exposure_column = self.schema.dimension_column("exposure")
        detector_column = self.schema.dimension_column("detector")
        if instrument_column is None or exposure_column is None or detector_column is None:
            # Not all needed columns are in the table.
            return 0

        update_rows: list[Record] = []
        for exposure, detector, region in region_data:
            try:
                record = self.record_factory.make_spatial_records(region)
            except RegionTypeError as exc:
                warnings.warn(
                    f"Failed to convert region for exposure={exposure} detector={detector}: {exc}",
                    category=RegionTypeWarning,
                )
                continue

            # Store the row-selection values under placeholder keys; the
            # ``where_dict`` below maps the real table columns to these keys.
            record.update(
                {
                    "instrument_column": instrument,
                    "exposure_column": exposure,
                    "detector_column": detector,
                }
            )
            update_rows.append(record)

        # Keys are table column names; values are the keys in each row dict
        # that hold the value to match. NOTE(review): this relies on the
        # ``where`` contract of ``Database.update`` -- confirm there.
        where_dict: dict[str, str] = {
            instrument_column: "instrument_column",
            exposure_column: "exposure_column",
            detector_column: "detector_column",
        }

        count = self.db.update(self.table, where_dict, *update_rows)
        return count

    @contextmanager
    def query(
        self, columns: Iterable[str | sqlalchemy.sql.expression.ColumnElement] | None = None, /, **kwargs: Any
    ) -> Iterator[sqlalchemy.engine.CursorResult]:
        # Docstring inherited from base class.
        if columns is not None:
            # String entries are resolved to columns of the obscore table;
            # anything else is used as a column expression as-is.
            column_elements: list[sqlalchemy.sql.ColumnElement] = []
            for column in columns:
                if isinstance(column, str):
                    column_elements.append(self.table.columns[column])
                else:
                    column_elements.append(column)
            query = sqlalchemy.sql.select(*column_elements).select_from(self.table)
        else:
            query = self.table.select()

        if kwargs:
            # Each keyword becomes an equality constraint, combined with AND.
            query = query.where(
                sqlalchemy.sql.expression.and_(
                    *[self.table.columns[column] == value for column, value in kwargs.items()]
                )
            )
        with self.db.query(query) as result:
            yield result