Coverage for python/lsst/daf/butler/registry/obscore/_manager.py: 18%
169 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-16 10:44 +0000
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
28from __future__ import annotations
30__all__ = ["ObsCoreLiveTableManager"]
32import re
33import warnings
34from collections import defaultdict
35from collections.abc import Collection, Iterable, Iterator, Mapping
36from contextlib import contextmanager
37from typing import TYPE_CHECKING, Any
39import sqlalchemy
40from lsst.daf.butler import Config, DataCoordinate, DatasetRef, DimensionRecordColumnTag, DimensionUniverse
41from lsst.daf.relation import Join
42from lsst.sphgeom import Region
43from lsst.utils.introspection import find_outside_stacklevel
44from lsst.utils.iteration import chunk_iterable
46from ..interfaces import ObsCoreTableManager, VersionTuple
47from ._config import ConfigCollectionType, ObsCoreManagerConfig
48from ._records import ExposureRegionFactory, Record, RecordFactory
49from ._schema import ObsCoreSchema
50from ._spatial import RegionTypeError, RegionTypeWarning, SpatialObsCorePlugin
52if TYPE_CHECKING:
53 from ..interfaces import (
54 CollectionRecord,
55 Database,
56 DatasetRecordStorageManager,
57 DimensionRecordStorageManager,
58 StaticTablesContext,
59 )
60 from ..queries import SqlQueryContext
62_VERSION = VersionTuple(0, 0, 1)
65class _ExposureRegionFactory(ExposureRegionFactory):
66 """Find exposure region from a matching visit dimensions records.
68 Parameters
69 ----------
70 dimensions : `DimensionRecordStorageManager`
71 The dimension records storage manager.
72 """
74 def __init__(self, dimensions: DimensionRecordStorageManager):
75 self.dimensions = dimensions
76 self.universe = dimensions.universe
77 self.exposure_dimensions = self.universe["exposure"].minimal_group
78 self.exposure_detector_dimensions = self.universe.conform(["exposure", "detector"])
80 def exposure_region(self, dataId: DataCoordinate, context: SqlQueryContext) -> Region | None:
81 # Docstring is inherited from a base class.
82 # Make a relation that starts with visit_definition (mapping between
83 # exposure and visit).
84 relation = context.make_initial_relation()
85 if "visit_definition" not in self.universe.elements.names:
86 return None
87 relation = self.dimensions.join("visit_definition", relation, Join(), context)
88 # Join in a table with either visit+detector regions or visit regions.
89 if "detector" in dataId.dimensions:
90 if "visit_detector_region" not in self.universe:
91 return None
92 relation = self.dimensions.join("visit_detector_region", relation, Join(), context)
93 constraint_data_id = dataId.subset(self.exposure_detector_dimensions)
94 region_tag = DimensionRecordColumnTag("visit_detector_region", "region")
95 else:
96 if "visit" not in self.universe:
97 return None
98 relation = self.dimensions.join("visit", relation, Join(), context)
99 constraint_data_id = dataId.subset(self.exposure_dimensions)
100 region_tag = DimensionRecordColumnTag("visit", "region")
101 # Constrain the relation to match the given exposure and (if present)
102 # detector IDs.
103 relation = relation.with_rows_satisfying(
104 context.make_data_coordinate_predicate(constraint_data_id, full=False)
105 )
106 # If we get more than one result (because the exposure belongs to
107 # multiple visits), just pick an arbitrary one.
108 relation = relation[:1]
109 # Run the query and extract the region, if the query has any results.
110 for row in context.fetch_iterable(relation):
111 return row[region_tag]
112 return None
115class ObsCoreLiveTableManager(ObsCoreTableManager):
116 """A manager class for ObsCore table, implements methods for updating the
117 records in that table.
119 Parameters
120 ----------
121 db : `Database`
122 The active database.
123 table : `sqlalchemy.schema.Table`
124 The ObsCore table.
125 schema : `ObsCoreSchema`
126 The relevant schema.
127 universe : `DimensionUniverse`
128 The dimension universe.
129 config : `ObsCoreManagerConfig`
130 The config controlling the manager.
131 dimensions : `DimensionRecordStorageManager`
132 The storage manager for the dimension records.
133 spatial_plugins : `~collections.abc.Collection` of `SpatialObsCorePlugin`
134 Spatial plugins.
135 registry_schema_version : `VersionTuple` or `None`, optional
136 Version of registry schema.
137 """
139 def __init__(
140 self,
141 *,
142 db: Database,
143 table: sqlalchemy.schema.Table,
144 schema: ObsCoreSchema,
145 universe: DimensionUniverse,
146 config: ObsCoreManagerConfig,
147 dimensions: DimensionRecordStorageManager,
148 spatial_plugins: Collection[SpatialObsCorePlugin],
149 registry_schema_version: VersionTuple | None = None,
150 ):
151 super().__init__(registry_schema_version=registry_schema_version)
152 self.db = db
153 self.table = table
154 self.schema = schema
155 self.universe = universe
156 self.config = config
157 self.spatial_plugins = spatial_plugins
158 exposure_region_factory = _ExposureRegionFactory(dimensions)
159 self.record_factory = RecordFactory(
160 config, schema, universe, spatial_plugins, exposure_region_factory
161 )
162 self.tagged_collection: str | None = None
163 self.run_patterns: list[re.Pattern] = []
164 if config.collection_type is ConfigCollectionType.TAGGED:
165 assert (
166 config.collections is not None and len(config.collections) == 1
167 ), "Exactly one collection name required for tagged type."
168 self.tagged_collection = config.collections[0]
169 elif config.collection_type is ConfigCollectionType.RUN:
170 if config.collections:
171 for coll in config.collections:
172 try:
173 self.run_patterns.append(re.compile(coll))
174 except re.error as exc:
175 raise ValueError(f"Failed to compile regex: {coll!r}") from exc
176 else:
177 raise ValueError(f"Unexpected value of collection_type: {config.collection_type}")
179 @classmethod
180 def initialize(
181 cls,
182 db: Database,
183 context: StaticTablesContext,
184 *,
185 universe: DimensionUniverse,
186 config: Mapping,
187 datasets: type[DatasetRecordStorageManager],
188 dimensions: DimensionRecordStorageManager,
189 registry_schema_version: VersionTuple | None = None,
190 ) -> ObsCoreTableManager:
191 # Docstring inherited from base class.
192 config_data = Config(config)
193 obscore_config = ObsCoreManagerConfig.model_validate(config_data)
195 # Instantiate all spatial plugins.
196 spatial_plugins = SpatialObsCorePlugin.load_plugins(obscore_config.spatial_plugins, db)
198 schema = ObsCoreSchema(config=obscore_config, spatial_plugins=spatial_plugins, datasets=datasets)
200 # Generate table specification for main obscore table.
201 table_spec = schema.table_spec
202 for plugin in spatial_plugins:
203 plugin.extend_table_spec(table_spec)
204 table = context.addTable(obscore_config.table_name, schema.table_spec)
206 return ObsCoreLiveTableManager(
207 db=db,
208 table=table,
209 schema=schema,
210 universe=universe,
211 config=obscore_config,
212 dimensions=dimensions,
213 spatial_plugins=spatial_plugins,
214 registry_schema_version=registry_schema_version,
215 )
217 def config_json(self) -> str:
218 """Dump configuration in JSON format.
220 Returns
221 -------
222 json : `str`
223 Configuration serialized in JSON format.
224 """
225 return self.config.model_dump_json()
227 @classmethod
228 def currentVersions(cls) -> list[VersionTuple]:
229 # Docstring inherited from base class.
230 return [_VERSION]
232 def add_datasets(self, refs: Iterable[DatasetRef], context: SqlQueryContext) -> int:
233 # Docstring inherited from base class.
235 # Only makes sense for RUN collection types
236 if self.config.collection_type is not ConfigCollectionType.RUN:
237 return 0
239 obscore_refs: Iterable[DatasetRef]
240 if self.run_patterns:
241 # Check each dataset run against configured run list. We want to
242 # reduce number of calls to _check_dataset_run, which may be
243 # expensive. Normally references are grouped by run, if there are
244 # multiple input references, they should have the same run.
245 # Instead of just checking that, we group them by run again.
246 refs_by_run: dict[str, list[DatasetRef]] = defaultdict(list)
247 for ref in refs:
248 # Record factory will filter dataset types, but to reduce
249 # collection checks we also pre-filter it here.
250 if ref.datasetType.name not in self.config.dataset_types:
251 continue
253 assert ref.run is not None, "Run cannot be None"
254 refs_by_run[ref.run].append(ref)
256 good_refs: list[DatasetRef] = []
257 for run, run_refs in refs_by_run.items():
258 if not self._check_dataset_run(run):
259 continue
260 good_refs.extend(run_refs)
261 obscore_refs = good_refs
263 else:
264 # Take all refs, no collection check.
265 obscore_refs = refs
267 return self._populate(obscore_refs, context)
269 def associate(
270 self, refs: Iterable[DatasetRef], collection: CollectionRecord, context: SqlQueryContext
271 ) -> int:
272 # Docstring inherited from base class.
274 # Only works when collection type is TAGGED
275 if self.tagged_collection is None:
276 return 0
278 if collection.name == self.tagged_collection:
279 return self._populate(refs, context)
280 else:
281 return 0
283 def disassociate(self, refs: Iterable[DatasetRef], collection: CollectionRecord) -> int:
284 # Docstring inherited from base class.
286 # Only works when collection type is TAGGED
287 if self.tagged_collection is None:
288 return 0
290 count = 0
291 if collection.name == self.tagged_collection:
292 # Sorting may improve performance
293 dataset_ids = sorted(ref.id for ref in refs)
294 if dataset_ids:
295 fk_field = self.schema.dataset_fk
296 assert fk_field is not None, "Cannot be None by construction"
297 # There may be too many of them, do it in chunks.
298 for ids in chunk_iterable(dataset_ids):
299 where = self.table.columns[fk_field.name].in_(ids)
300 count += self.db.deleteWhere(self.table, where)
301 return count
303 def _populate(self, refs: Iterable[DatasetRef], context: SqlQueryContext) -> int:
304 """Populate obscore table with the data from given datasets."""
305 records: list[Record] = []
306 for ref in refs:
307 record = self.record_factory(ref, context)
308 if record is not None:
309 records.append(record)
311 if records:
312 # Ignore potential conflicts with existing datasets.
313 return self.db.ensure(self.table, *records, primary_key_only=True)
314 else:
315 return 0
317 def _check_dataset_run(self, run: str) -> bool:
318 """Check that specified run collection matches know patterns."""
319 if not self.run_patterns:
320 # Empty list means take anything.
321 return True
323 # Try each pattern in turn.
324 return any(pattern.fullmatch(run) for pattern in self.run_patterns)
326 def update_exposure_regions(self, instrument: str, region_data: Iterable[tuple[int, int, Region]]) -> int:
327 # Docstring inherited from base class.
328 instrument_column = self.schema.dimension_column("instrument")
329 exposure_column = self.schema.dimension_column("exposure")
330 detector_column = self.schema.dimension_column("detector")
331 if instrument_column is None or exposure_column is None or detector_column is None:
332 # Not all needed columns are in the table.
333 return 0
335 update_rows: list[Record] = []
336 for exposure, detector, region in region_data:
337 try:
338 record = self.record_factory.make_spatial_records(region)
339 except RegionTypeError as exc:
340 warnings.warn(
341 f"Failed to convert region for exposure={exposure} detector={detector}: {exc}",
342 category=RegionTypeWarning,
343 stacklevel=find_outside_stacklevel("lsst.daf.butler"),
344 )
345 continue
347 record.update(
348 {
349 "instrument_column": instrument,
350 "exposure_column": exposure,
351 "detector_column": detector,
352 }
353 )
354 update_rows.append(record)
356 where_dict: dict[str, str] = {
357 instrument_column: "instrument_column",
358 exposure_column: "exposure_column",
359 detector_column: "detector_column",
360 }
362 count = self.db.update(self.table, where_dict, *update_rows)
363 return count
365 @contextmanager
366 def query(
367 self, columns: Iterable[str | sqlalchemy.sql.expression.ColumnElement] | None = None, /, **kwargs: Any
368 ) -> Iterator[sqlalchemy.engine.CursorResult]:
369 # Docstring inherited from base class.
370 if columns is not None:
371 column_elements: list[sqlalchemy.sql.ColumnElement] = []
372 for column in columns:
373 if isinstance(column, str):
374 column_elements.append(self.table.columns[column])
375 else:
376 column_elements.append(column)
377 query = sqlalchemy.sql.select(*column_elements).select_from(self.table)
378 else:
379 query = self.table.select()
381 if kwargs:
382 query = query.where(
383 sqlalchemy.sql.expression.and_(
384 *[self.table.columns[column] == value for column, value in kwargs.items()]
385 )
386 )
387 with self.db.query(query) as result:
388 yield result