Coverage for python/lsst/daf/butler/registry/obscore/_manager.py: 18%
171 statements

# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

__all__ = ["ObsCoreLiveTableManager"]

import re
import warnings
from collections import defaultdict
from collections.abc import Collection, Iterable, Iterator, Mapping
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any

import sqlalchemy
from lsst.daf.butler import Config, DataCoordinate, DatasetRef, DimensionRecordColumnTag, DimensionUniverse
from lsst.daf.relation import Join
from lsst.sphgeom import Region
from lsst.utils.introspection import find_outside_stacklevel
from lsst.utils.iteration import chunk_iterable

from ..interfaces import ObsCoreTableManager, VersionTuple
from ._config import ConfigCollectionType, ObsCoreManagerConfig
from ._records import ExposureRegionFactory, Record, RecordFactory
from ._schema import ObsCoreSchema
from ._spatial import RegionTypeError, RegionTypeWarning, SpatialObsCorePlugin

if TYPE_CHECKING:
    from ..interfaces import (
        CollectionRecord,
        Database,
        DatasetRecordStorageManager,
        DimensionRecordStorageManager,
        StaticTablesContext,
    )
    from ..queries import SqlQueryContext

# Version of the obscore manager schema, returned by `currentVersions`.
_VERSION = VersionTuple(0, 0, 1)


class _ExposureRegionFactory(ExposureRegionFactory):
    """Find an exposure region from matching visit dimension records.

    Parameters
    ----------
    dimensions : `DimensionRecordStorageManager`
        The dimension records storage manager.
    """

    def __init__(self, dimensions: DimensionRecordStorageManager):
        self.dimensions = dimensions
        self.universe = dimensions.universe
        self.exposure_dimensions = self.universe["exposure"].minimal_group
        self.exposure_detector_dimensions = self.universe.conform(["exposure", "detector"])

    def exposure_region(self, dataId: DataCoordinate, context: SqlQueryContext) -> Region | None:
        # Docstring is inherited from a base class.
        # Make a relation that starts with visit_definition (mapping between
        # exposure and visit).
        relation = context.make_initial_relation()
        if "visit_definition" not in self.universe.elements.names:
            return None
        relation = self.dimensions.join("visit_definition", relation, Join(), context)
        # Join in a table with either visit+detector regions or visit regions.
        if "detector" in dataId.dimensions:
            if "visit_detector_region" not in self.universe:
                return None
            relation = self.dimensions.join("visit_detector_region", relation, Join(), context)
            constraint_data_id = dataId.subset(self.exposure_detector_dimensions)
            region_tag = DimensionRecordColumnTag("visit_detector_region", "region")
        else:
            if "visit" not in self.universe:
                return None
            relation = self.dimensions.join("visit", relation, Join(), context)
            constraint_data_id = dataId.subset(self.exposure_dimensions)
            region_tag = DimensionRecordColumnTag("visit", "region")
        # Constrain the relation to match the given exposure and (if present)
        # detector IDs.
        relation = relation.with_rows_satisfying(
            context.make_data_coordinate_predicate(constraint_data_id, full=False)
        )
        # If we get more than one result (because the exposure belongs to
        # multiple visits), just pick an arbitrary one.
        relation = relation[:1]
        # Run the query and extract the region, if the query has any results.
        for row in context.fetch_iterable(relation):
            return row[region_tag]
        return None
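
# A minimal usage sketch for `_ExposureRegionFactory` (assumes an existing
# `DimensionRecordStorageManager` named ``dimensions`` and a `SqlQueryContext`
# named ``context``; the data ID values are illustrative only):
#
#     factory = _ExposureRegionFactory(dimensions)
#     data_id = DataCoordinate.standardize(
#         instrument="LSSTCam", exposure=903342, detector=10, universe=factory.universe
#     )
#     region = factory.exposure_region(data_id, context)  # `None` if no visit matches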


class ObsCoreLiveTableManager(ObsCoreTableManager):
    """A manager class for the ObsCore table that implements methods for
    updating the records in that table.

    Parameters
    ----------
    db : `Database`
        The active database.
    table : `sqlalchemy.schema.Table`
        The ObsCore table.
    schema : `ObsCoreSchema`
        The relevant schema.
    universe : `DimensionUniverse`
        The dimension universe.
    config : `ObsCoreManagerConfig`
        The config controlling the manager.
    dimensions : `DimensionRecordStorageManager`
        The storage manager for the dimension records.
    spatial_plugins : `~collections.abc.Collection` of `SpatialObsCorePlugin`
        Spatial plugins.
    registry_schema_version : `VersionTuple` or `None`, optional
        Version of the registry schema.
    """

    def __init__(
        self,
        *,
        db: Database,
        table: sqlalchemy.schema.Table,
        schema: ObsCoreSchema,
        universe: DimensionUniverse,
        config: ObsCoreManagerConfig,
        dimensions: DimensionRecordStorageManager,
        spatial_plugins: Collection[SpatialObsCorePlugin],
        registry_schema_version: VersionTuple | None = None,
    ):
        super().__init__(registry_schema_version=registry_schema_version)
        self.db = db
        self.table = table
        self.schema = schema
        self.universe = universe
        self.config = config
        self.spatial_plugins = spatial_plugins
        exposure_region_factory = _ExposureRegionFactory(dimensions)
        self.record_factory = RecordFactory(
            config, schema, universe, spatial_plugins, exposure_region_factory
        )
        self.tagged_collection: str | None = None
        self.run_patterns: list[re.Pattern] = []
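        # The manager operates in one of two collection modes: TAGGED, where a
        # single tagged collection controls table membership (see `associate`
        # and `disassociate`), or RUN, where run collection names are matched
        # against the configured regular expressions (see `add_datasets`).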
        if config.collection_type is ConfigCollectionType.TAGGED:
            assert (
                config.collections is not None and len(config.collections) == 1
            ), "Exactly one collection name required for tagged type."
            self.tagged_collection = config.collections[0]
        elif config.collection_type is ConfigCollectionType.RUN:
            if config.collections:
                for coll in config.collections:
                    try:
                        self.run_patterns.append(re.compile(coll))
                    except re.error as exc:
                        raise ValueError(f"Failed to compile regex: {coll!r}") from exc
        else:
            raise ValueError(f"Unexpected value of collection_type: {config.collection_type}")
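
    # A partial configuration sketch for RUN mode, limited to the keys this
    # module actually reads (a full `ObsCoreManagerConfig` requires more
    # fields; names and values below are illustrative only):
    #
    #     collection_type: RUN
    #     collections: ["LSSTCam/runs/.*"]  # regular expressions, full-matched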

    def clone(self, *, db: Database, dimensions: DimensionRecordStorageManager) -> ObsCoreLiveTableManager:
        return ObsCoreLiveTableManager(
            db=db,
            table=self.table,
            schema=self.schema,
            universe=self.universe,
            config=self.config,
            dimensions=dimensions,
            # Current spatial plugins are safe to share without cloning -- they
            # are immutable and do not use their Database object outside of
            # 'initialize'.
            spatial_plugins=self.spatial_plugins,
            registry_schema_version=self._registry_schema_version,
        )

    @classmethod
    def initialize(
        cls,
        db: Database,
        context: StaticTablesContext,
        *,
        universe: DimensionUniverse,
        config: Mapping,
        datasets: type[DatasetRecordStorageManager],
        dimensions: DimensionRecordStorageManager,
        registry_schema_version: VersionTuple | None = None,
    ) -> ObsCoreTableManager:
        # Docstring inherited from base class.
        config_data = Config(config)
        obscore_config = ObsCoreManagerConfig.model_validate(config_data)

        # Instantiate all spatial plugins.
        spatial_plugins = SpatialObsCorePlugin.load_plugins(obscore_config.spatial_plugins, db)

        schema = ObsCoreSchema(config=obscore_config, spatial_plugins=spatial_plugins, datasets=datasets)

        # Generate table specification for main obscore table.
        table_spec = schema.table_spec
        for plugin in spatial_plugins:
            plugin.extend_table_spec(table_spec)
        table = context.addTable(obscore_config.table_name, schema.table_spec)

        return ObsCoreLiveTableManager(
            db=db,
            table=table,
            schema=schema,
            universe=universe,
            config=obscore_config,
            dimensions=dimensions,
            spatial_plugins=spatial_plugins,
            registry_schema_version=registry_schema_version,
        )

    def config_json(self) -> str:
        """Dump configuration in JSON format.

        Returns
        -------
        json : `str`
            Configuration serialized in JSON format.
        """
        return self.config.model_dump_json()

    @classmethod
    def currentVersions(cls) -> list[VersionTuple]:
        # Docstring inherited from base class.
        return [_VERSION]

    def add_datasets(self, refs: Iterable[DatasetRef], context: SqlQueryContext) -> int:
        # Docstring inherited from base class.

        # Only makes sense for RUN collection types.
        if self.config.collection_type is not ConfigCollectionType.RUN:
            return 0

        obscore_refs: Iterable[DatasetRef]
        if self.run_patterns:
            # Check each dataset run against the configured run list. We want
            # to reduce the number of calls to _check_dataset_run, which may
            # be expensive. Normally references are grouped by run; if there
            # are multiple input references, they should have the same run.
            # Instead of just checking that, we group them by run again.
            refs_by_run: dict[str, list[DatasetRef]] = defaultdict(list)
            for ref in refs:
                # The record factory will filter dataset types, but to reduce
                # collection checks we also pre-filter them here.
                if ref.datasetType.name not in self.config.dataset_types:
                    continue

                assert ref.run is not None, "Run cannot be None"
                refs_by_run[ref.run].append(ref)

            good_refs: list[DatasetRef] = []
            for run, run_refs in refs_by_run.items():
                if not self._check_dataset_run(run):
                    continue
                good_refs.extend(run_refs)
            obscore_refs = good_refs

        else:
            # Take all refs, no collection check.
            obscore_refs = refs

        return self._populate(obscore_refs, context)

    def associate(
        self, refs: Iterable[DatasetRef], collection: CollectionRecord, context: SqlQueryContext
    ) -> int:
        # Docstring inherited from base class.

        # Only works when collection type is TAGGED.
        if self.tagged_collection is None:
            return 0

        if collection.name == self.tagged_collection:
            return self._populate(refs, context)
        else:
            return 0

    def disassociate(self, refs: Iterable[DatasetRef], collection: CollectionRecord) -> int:
        # Docstring inherited from base class.

        # Only works when collection type is TAGGED.
        if self.tagged_collection is None:
            return 0

        count = 0
        if collection.name == self.tagged_collection:
            # Sorting may improve performance.
            dataset_ids = sorted(ref.id for ref in refs)
            if dataset_ids:
                fk_field = self.schema.dataset_fk
                assert fk_field is not None, "Cannot be None by construction"
                # There may be too many of them; do it in chunks.
                for ids in chunk_iterable(dataset_ids):
                    where = self.table.columns[fk_field.name].in_(ids)
                    count += self.db.deleteWhere(self.table, where)
        return count

    def _populate(self, refs: Iterable[DatasetRef], context: SqlQueryContext) -> int:
        """Populate the obscore table with data from the given datasets."""
        records: list[Record] = []
        for ref in refs:
            record = self.record_factory(ref, context)
            if record is not None:
                records.append(record)

        if records:
            # Ignore potential conflicts with existing datasets.
            return self.db.ensure(self.table, *records, primary_key_only=True)
        else:
            return 0

    def _check_dataset_run(self, run: str) -> bool:
        """Check that the specified run collection matches known patterns."""
        if not self.run_patterns:
            # Empty list means take anything.
            return True

        # Try each pattern in turn.
        return any(pattern.fullmatch(run) for pattern in self.run_patterns)
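
    # `re.Pattern.fullmatch` anchors a pattern at both ends, so each configured
    # pattern must cover the entire run collection name: for example,
    # "LSSTCam/runs/.*" matches "LSSTCam/runs/run1", while a bare "LSSTCam"
    # does not (the run names here are illustrative only).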

    def update_exposure_regions(self, instrument: str, region_data: Iterable[tuple[int, int, Region]]) -> int:
        # Docstring inherited from base class.
        instrument_column = self.schema.dimension_column("instrument")
        exposure_column = self.schema.dimension_column("exposure")
        detector_column = self.schema.dimension_column("detector")
        if instrument_column is None or exposure_column is None or detector_column is None:
            # Not all needed columns are in the table.
            return 0

        update_rows: list[Record] = []
        for exposure, detector, region in region_data:
            try:
                record = self.record_factory.make_spatial_records(region)
            except RegionTypeError as exc:
                warnings.warn(
                    f"Failed to convert region for exposure={exposure} detector={detector}: {exc}",
                    category=RegionTypeWarning,
                    stacklevel=find_outside_stacklevel("lsst.daf.butler"),
                )
                continue

            record.update(
                {
                    "instrument_column": instrument,
                    "exposure_column": exposure,
                    "detector_column": detector,
                }
            )
            update_rows.append(record)
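
        # Each update row carries its WHERE values under the placeholder keys
        # set above; `where_dict` maps the actual table column names to those
        # keys, which is the form `Database.update` expects for its WHERE
        # clause.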
        where_dict: dict[str, str] = {
            instrument_column: "instrument_column",
            exposure_column: "exposure_column",
            detector_column: "detector_column",
        }

        count = self.db.update(self.table, where_dict, *update_rows)
        return count

    @contextmanager
    def query(
        self, columns: Iterable[str | sqlalchemy.sql.expression.ColumnElement] | None = None, /, **kwargs: Any
    ) -> Iterator[sqlalchemy.engine.CursorResult]:
        # Docstring inherited from base class.
        if columns is not None:
            column_elements: list[sqlalchemy.sql.ColumnElement] = []
            for column in columns:
                if isinstance(column, str):
                    column_elements.append(self.table.columns[column])
                else:
                    column_elements.append(column)
            query = sqlalchemy.sql.select(*column_elements).select_from(self.table)
        else:
            query = self.table.select()

        if kwargs:
            query = query.where(
                sqlalchemy.sql.expression.and_(
                    *[self.table.columns[column] == value for column, value in kwargs.items()]
                )
            )
        with self.db.query(query) as result:
            yield result
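

# A minimal usage sketch for `query`, assuming an `ObsCoreLiveTableManager`
# instance named ``manager`` (the column names and instrument value are
# illustrative; keyword arguments become equality constraints on the
# corresponding table columns):
#
#     with manager.query(["dataproduct_type", "s_ra", "s_dec"], instrument="LSSTCam") as result:
#         for row in result:
#             print(row.dataproduct_type, row.s_ra, row.s_dec)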