Coverage for python/lsst/daf/butler/registry/obscore/_manager.py: 18%
172 statements
# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

__all__ = ["ObsCoreLiveTableManager"]

import re
import warnings
from collections import defaultdict
from collections.abc import Collection, Iterable, Iterator, Mapping
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any

import sqlalchemy
from lsst.daf.butler import Config, DataCoordinate, DatasetRef, DimensionRecordColumnTag, DimensionUniverse
from lsst.daf.relation import Join
from lsst.sphgeom import Region
from lsst.utils.introspection import find_outside_stacklevel
from lsst.utils.iteration import chunk_iterable

from ..interfaces import ObsCoreTableManager, VersionTuple
from ._config import ConfigCollectionType, ObsCoreManagerConfig
from ._records import ExposureRegionFactory, Record, RecordFactory
from ._schema import ObsCoreSchema
from ._spatial import RegionTypeError, RegionTypeWarning, SpatialObsCorePlugin

if TYPE_CHECKING:
    from ..interfaces import (
        CollectionRecord,
        Database,
        DatasetRecordStorageManager,
        DimensionRecordStorageManager,
        StaticTablesContext,
    )
    from ..queries import SqlQueryContext

_VERSION = VersionTuple(0, 0, 1)


class _ExposureRegionFactory(ExposureRegionFactory):
    """Find an exposure region from matching visit dimension records."""

    def __init__(self, dimensions: DimensionRecordStorageManager):
        self.dimensions = dimensions
        self.universe = dimensions.universe
        self.exposure_dimensions = self.universe["exposure"].minimal_group
        self.exposure_detector_dimensions = self.universe.conform(["exposure", "detector"])

    def exposure_region(self, dataId: DataCoordinate, context: SqlQueryContext) -> Region | None:
        # Docstring is inherited from a base class.
        # Make a relation that starts with visit_definition (mapping between
        # exposure and visit).
        relation = context.make_initial_relation()
        visit_definition_storage = self.dimensions.get(self.universe["visit_definition"])
        if visit_definition_storage is None:
            return None
        relation = visit_definition_storage.join(relation, Join(), context)
        # Join in a table with either visit+detector regions or visit regions.
        if "detector" in dataId.dimensions:
            visit_detector_region_storage = self.dimensions.get(self.universe["visit_detector_region"])
            if visit_detector_region_storage is None:
                return None
            relation = visit_detector_region_storage.join(relation, Join(), context)
            constraint_data_id = dataId.subset(self.exposure_detector_dimensions)
            region_tag = DimensionRecordColumnTag("visit_detector_region", "region")
        else:
            visit_storage = self.dimensions.get(self.universe["visit"])
            if visit_storage is None:
                return None
            relation = visit_storage.join(relation, Join(), context)
            constraint_data_id = dataId.subset(self.exposure_dimensions)
            region_tag = DimensionRecordColumnTag("visit", "region")
        # Constrain the relation to match the given exposure and (if present)
        # detector IDs.
        relation = relation.with_rows_satisfying(
            context.make_data_coordinate_predicate(constraint_data_id, full=False)
        )
        # If we get more than one result (because the exposure belongs to
        # multiple visits), just pick an arbitrary one.
        relation = relation[:1]
        # Run the query and extract the region, if the query has any results.
        for row in context.fetch_iterable(relation):
            return row[region_tag]
        return None
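
# A minimal usage sketch of the factory above (hypothetical setup: the
# ``dimensions`` manager and ``context`` are assumed to come from an active
# registry, and ``data_id`` is an exposure-level ``DataCoordinate``):
#
#     factory = _ExposureRegionFactory(dimensions)
#     region = factory.exposure_region(data_id, context)
#     if region is None:
#         # No visit_definition/visit records matched this exposure.
#         pass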


class ObsCoreLiveTableManager(ObsCoreTableManager):
    """A manager class for the ObsCore table; implements methods for updating
    the records in that table.
    """

    def __init__(
        self,
        *,
        db: Database,
        table: sqlalchemy.schema.Table,
        schema: ObsCoreSchema,
        universe: DimensionUniverse,
        config: ObsCoreManagerConfig,
        dimensions: DimensionRecordStorageManager,
        spatial_plugins: Collection[SpatialObsCorePlugin],
        registry_schema_version: VersionTuple | None = None,
    ):
        super().__init__(registry_schema_version=registry_schema_version)
        self.db = db
        self.table = table
        self.schema = schema
        self.universe = universe
        self.config = config
        self.spatial_plugins = spatial_plugins
        exposure_region_factory = _ExposureRegionFactory(dimensions)
        self.record_factory = RecordFactory(
            config, schema, universe, spatial_plugins, exposure_region_factory
        )
        self.tagged_collection: str | None = None
        self.run_patterns: list[re.Pattern] = []
        if config.collection_type is ConfigCollectionType.TAGGED:
            assert (
                config.collections is not None and len(config.collections) == 1
            ), "Exactly one collection name required for tagged type."
            self.tagged_collection = config.collections[0]
        elif config.collection_type is ConfigCollectionType.RUN:
            if config.collections:
                for coll in config.collections:
                    try:
                        self.run_patterns.append(re.compile(coll))
                    except re.error as exc:
                        raise ValueError(f"Failed to compile regex: {coll!r}") from exc
        else:
            raise ValueError(f"Unexpected value of collection_type: {config.collection_type}")
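
    # A hedged configuration sketch for the RUN mode handled above; only the
    # fields referenced in this module are shown, the pattern string is
    # illustrative, and other required config fields are elided:
    #
    #     obscore_config = ObsCoreManagerConfig.model_validate(
    #         {
    #             "collection_type": "RUN",
    #             "collections": ["HSC/runs/.*"],  # compiled as regexes
    #         }
    #     )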

    @classmethod
    def initialize(
        cls,
        db: Database,
        context: StaticTablesContext,
        *,
        universe: DimensionUniverse,
        config: Mapping,
        datasets: type[DatasetRecordStorageManager],
        dimensions: DimensionRecordStorageManager,
        registry_schema_version: VersionTuple | None = None,
    ) -> ObsCoreTableManager:
        # Docstring inherited from base class.
        config_data = Config(config)
        obscore_config = ObsCoreManagerConfig.model_validate(config_data)

        # Instantiate all spatial plugins.
        spatial_plugins = SpatialObsCorePlugin.load_plugins(obscore_config.spatial_plugins, db)

        schema = ObsCoreSchema(config=obscore_config, spatial_plugins=spatial_plugins, datasets=datasets)

        # Generate table specification for main obscore table.
        table_spec = schema.table_spec
        for plugin in spatial_plugins:
            plugin.extend_table_spec(table_spec)
        table = context.addTable(obscore_config.table_name, table_spec)

        return ObsCoreLiveTableManager(
            db=db,
            table=table,
            schema=schema,
            universe=universe,
            config=obscore_config,
            dimensions=dimensions,
            spatial_plugins=spatial_plugins,
            registry_schema_version=registry_schema_version,
        )

    def config_json(self) -> str:
        """Dump configuration in JSON format.

        Returns
        -------
        json : `str`
            Configuration serialized in JSON format.
        """
        return self.config.model_dump_json()

    @classmethod
    def currentVersions(cls) -> list[VersionTuple]:
        # Docstring inherited from base class.
        return [_VERSION]

    def add_datasets(self, refs: Iterable[DatasetRef], context: SqlQueryContext) -> int:
        # Docstring inherited from base class.

        # Only makes sense for RUN collection types.
        if self.config.collection_type is not ConfigCollectionType.RUN:
            return 0

        obscore_refs: Iterable[DatasetRef]
        if self.run_patterns:
            # Check each dataset's run against the configured run list. We
            # want to reduce the number of calls to _check_dataset_run, which
            # may be expensive. Input references normally arrive grouped by
            # run, so if there are multiple references they should share the
            # same run; rather than relying on that, we group them by run
            # again.
            refs_by_run: dict[str, list[DatasetRef]] = defaultdict(list)
            for ref in refs:
                # Record factory will filter dataset types, but to reduce
                # collection checks we also pre-filter it here.
                if ref.datasetType.name not in self.config.dataset_types:
                    continue

                assert ref.run is not None, "Run cannot be None"
                refs_by_run[ref.run].append(ref)

            good_refs: list[DatasetRef] = []
            for run, run_refs in refs_by_run.items():
                if not self._check_dataset_run(run):
                    continue
                good_refs.extend(run_refs)
            obscore_refs = good_refs

        else:
            # Take all refs, no collection check.
            obscore_refs = refs

        return self._populate(obscore_refs, context)
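
    # Hedged usage sketch: only refs whose dataset type appears in the
    # configured ``dataset_types`` and whose RUN collection matches one of
    # ``run_patterns`` contribute to the returned count (``manager``,
    # ``refs``, and ``context`` are assumed to come from the caller):
    #
    #     inserted = manager.add_datasets(refs, context)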

    def associate(
        self, refs: Iterable[DatasetRef], collection: CollectionRecord, context: SqlQueryContext
    ) -> int:
        # Docstring inherited from base class.

        # Only works when collection type is TAGGED.
        if self.tagged_collection is None:
            return 0

        if collection.name == self.tagged_collection:
            return self._populate(refs, context)
        else:
            return 0

    def disassociate(self, refs: Iterable[DatasetRef], collection: CollectionRecord) -> int:
        # Docstring inherited from base class.

        # Only works when collection type is TAGGED.
        if self.tagged_collection is None:
            return 0

        count = 0
        if collection.name == self.tagged_collection:
            # Sorting may improve performance.
            dataset_ids = sorted(ref.id for ref in refs)
            if dataset_ids:
                fk_field = self.schema.dataset_fk
                assert fk_field is not None, "Cannot be None by construction"
                # The list of IDs can be large; delete in chunks to avoid
                # overly long queries.
                for ids in chunk_iterable(dataset_ids):
                    where = self.table.columns[fk_field.name].in_(ids)
                    count += self.db.deleteWhere(self.table, where)
        return count

    def _populate(self, refs: Iterable[DatasetRef], context: SqlQueryContext) -> int:
        """Populate the obscore table with data from the given datasets."""
        records: list[Record] = []
        for ref in refs:
            record = self.record_factory(ref, context)
            if record is not None:
                records.append(record)

        if records:
            # Ignore potential conflicts with existing datasets.
            return self.db.ensure(self.table, *records, primary_key_only=True)
        else:
            return 0

    def _check_dataset_run(self, run: str) -> bool:
        """Check that the specified run collection matches known patterns."""
        if not self.run_patterns:
            # Empty list means take anything.
            return True

        # Try each pattern in turn.
        return any(pattern.fullmatch(run) for pattern in self.run_patterns)
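
    # Note that ``re.fullmatch`` anchors the pattern at both ends: a
    # configured pattern like "HSC/runs/.*" (illustrative) matches the run
    # "HSC/runs/RC2" in full, while a bare "HSC" would not match it at all.
    #
    #     >>> bool(re.compile("HSC/runs/.*").fullmatch("HSC/runs/RC2"))
    #     True
    #     >>> bool(re.compile("HSC").fullmatch("HSC/runs/RC2"))
    #     False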

    def update_exposure_regions(self, instrument: str, region_data: Iterable[tuple[int, int, Region]]) -> int:
        # Docstring inherited from base class.
        instrument_column = self.schema.dimension_column("instrument")
        exposure_column = self.schema.dimension_column("exposure")
        detector_column = self.schema.dimension_column("detector")
        if instrument_column is None or exposure_column is None or detector_column is None:
            # Not all needed columns are in the table.
            return 0

        update_rows: list[Record] = []
        for exposure, detector, region in region_data:
            try:
                record = self.record_factory.make_spatial_records(region)
            except RegionTypeError as exc:
                warnings.warn(
                    f"Failed to convert region for exposure={exposure} detector={detector}: {exc}",
                    category=RegionTypeWarning,
                    stacklevel=find_outside_stacklevel("lsst.daf.butler"),
                )
                continue
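
            # The string keys below are placeholder names, not column names;
            # ``where_dict``, built after this loop, maps the actual table
            # columns to these keys so that ``db.update`` can bind each row's
            # values.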
            record.update(
                {
                    "instrument_column": instrument,
                    "exposure_column": exposure,
                    "detector_column": detector,
                }
            )
            update_rows.append(record)

        where_dict: dict[str, str] = {
            instrument_column: "instrument_column",
            exposure_column: "exposure_column",
            detector_column: "detector_column",
        }

        count = self.db.update(self.table, where_dict, *update_rows)
        return count
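
    # Hedged usage sketch (identifiers are illustrative): ``region_data``
    # yields ``(exposure_id, detector_id, region)`` tuples for a single
    # instrument.
    #
    #     updated = manager.update_exposure_regions(
    #         "HSC", [(exposure_id, detector_id, region)]
    #     )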

    @contextmanager
    def query(
        self, columns: Iterable[str | sqlalchemy.sql.expression.ColumnElement] | None = None, /, **kwargs: Any
    ) -> Iterator[sqlalchemy.engine.CursorResult]:
        # Docstring inherited from base class.
        if columns is not None:
            column_elements: list[sqlalchemy.sql.ColumnElement] = []
            for column in columns:
                if isinstance(column, str):
                    column_elements.append(self.table.columns[column])
                else:
                    column_elements.append(column)
            query = sqlalchemy.sql.select(*column_elements).select_from(self.table)
        else:
            query = self.table.select()

        if kwargs:
            query = query.where(
                sqlalchemy.sql.expression.and_(
                    *[self.table.columns[column] == value for column, value in kwargs.items()]
                )
            )
        with self.db.query(query) as result:
            yield result
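
    # A hedged usage sketch of ``query``: select a subset of columns and
    # filter on equality via keyword arguments (the column names shown are
    # illustrative and depend on the configured obscore schema):
    #
    #     with manager.query(["s_ra", "s_dec"], instrument_name="HSC") as result:
    #         for row in result:
    #             print(row)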