Coverage for python / lsst / daf / butler / registry / obscore / _manager.py: 17%
177 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-22 08:55 +0000
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
28from __future__ import annotations
30__all__ = ["ObsCoreLiveTableManager"]
32import re
33import warnings
34from collections import defaultdict
35from collections.abc import Collection, Iterable, Iterator, Mapping
36from contextlib import AbstractContextManager, contextmanager
37from typing import TYPE_CHECKING, Any
39import sqlalchemy
41from lsst.daf.butler import Config, DataCoordinate, DatasetRef, DimensionUniverse
42from lsst.sphgeom import Region
43from lsst.utils.introspection import find_outside_stacklevel
44from lsst.utils.iteration import chunk_iterable
46from ...queries import Query, QueryFactoryFunction
47from ..interfaces import ObsCoreTableManager, VersionTuple
48from ._config import ConfigCollectionType, ObsCoreManagerConfig
49from ._records import DerivedRegionFactory, Record, RecordFactory
50from ._schema import ObsCoreSchema
51from ._spatial import RegionTypeError, RegionTypeWarning, SpatialObsCorePlugin
53if TYPE_CHECKING:
54 from ..interfaces import (
55 CollectionRecord,
56 Database,
57 DatasetRecordStorageManager,
58 DimensionRecordStorageManager,
59 StaticTablesContext,
60 )
# Schema version reported for this manager (see ``currentVersions``).
_VERSION = VersionTuple(0, 0, 1)
class _ExposureRegionFactory(DerivedRegionFactory):
    """Find exposure region from a matching visit dimensions records.

    Parameters
    ----------
    dimensions : `DimensionRecordStorageManager`
        The dimension records storage manager.
    query_func : `QueryFactoryFunction`
        Function returning a context manager that sets up a `Query` object
        for querying the registry. (That is, a function equivalent to
        ``Butler.query()``).
    """

    def __init__(
        self,
        dimensions: DimensionRecordStorageManager,
        query_func: QueryFactoryFunction,
    ):
        self.dimensions = dimensions
        self.universe = dimensions.universe
        # Dimension groups used to project an exposure data ID down to just
        # the keys needed for the visit-region lookup.
        self.exposure_dimensions = self.universe["exposure"].minimal_group
        self.exposure_detector_dimensions = self.universe.conform(["exposure", "detector"])
        self._query_func = query_func

    def derived_region(self, dataId: DataCoordinate) -> Region | None:
        # Docstring is inherited from a base class.
        universe = self.universe

        # Without the visit_definition element there is no way to map an
        # exposure to its visit(s), so there is nothing we can do.
        if "visit_definition" not in universe.elements.names:
            return None

        # Pick the lookup table: per-detector visit regions when a detector
        # is part of the data ID, whole-visit regions otherwise.
        if "detector" in dataId.dimensions:
            if "visit_detector_region" not in universe:
                return None
            lookup_element = "visit_detector_region"
            constraint = dataId.subset(self.exposure_detector_dimensions)
        else:
            if "visit" not in universe:
                return None
            lookup_element = "visit"
            constraint = dataId.subset(self.exposure_dimensions)

        with self._query_func() as query:
            # An exposure may belong to more than one visit; any single match
            # is acceptable, so ask for at most one record.
            records = list(query.dimension_records(lookup_element).where(constraint).limit(1))
        return records[0].region if records else None
class ObsCoreLiveTableManager(ObsCoreTableManager):
    """A manager class for ObsCore table, implements methods for updating the
    records in that table.

    Parameters
    ----------
    db : `Database`
        The active database.
    table : `sqlalchemy.schema.Table`
        The ObsCore table.
    schema : `ObsCoreSchema`
        The relevant schema.
    universe : `DimensionUniverse`
        The dimension universe.
    config : `ObsCoreManagerConfig`
        The config controlling the manager.
    dimensions : `DimensionRecordStorageManager`
        The storage manager for the dimension records.
    spatial_plugins : `~collections.abc.Collection` of `SpatialObsCorePlugin`
        Spatial plugins.
    registry_schema_version : `VersionTuple` or `None`, optional
        Version of registry schema.
    """

    def __init__(
        self,
        *,
        db: Database,
        table: sqlalchemy.schema.Table,
        schema: ObsCoreSchema,
        universe: DimensionUniverse,
        config: ObsCoreManagerConfig,
        dimensions: DimensionRecordStorageManager,
        spatial_plugins: Collection[SpatialObsCorePlugin],
        registry_schema_version: VersionTuple | None = None,
    ):
        super().__init__(registry_schema_version=registry_schema_version)
        self.db = db
        self.table = table
        self.schema = schema
        self.universe = universe
        self.config = config
        self.spatial_plugins = spatial_plugins
        # Set lazily via set_query_function(); must be provided before any
        # exposure-region lookup is attempted.
        self._query_func: QueryFactoryFunction | None = None
        exposure_region_factory = _ExposureRegionFactory(dimensions, self._get_query_object)
        self.record_factory = RecordFactory.get_record_type_from_universe(universe)(
            config, schema, universe, spatial_plugins, exposure_region_factory
        )
        # Exactly one of tagged_collection / run_patterns is used, depending
        # on the configured collection type.
        self.tagged_collection: str | None = None
        self.run_patterns: list[re.Pattern[str]] = []
        if config.collection_type is ConfigCollectionType.TAGGED:
            assert config.collections is not None and len(config.collections) == 1, (
                "Exactly one collection name required for tagged type."
            )
            self.tagged_collection = config.collections[0]
        elif config.collection_type is ConfigCollectionType.RUN:
            if config.collections:
                for coll in config.collections:
                    try:
                        self.run_patterns.append(re.compile(coll))
                    except re.error as exc:
                        raise ValueError(f"Failed to compile regex: {coll!r}") from exc
        else:
            raise ValueError(f"Unexpected value of collection_type: {config.collection_type}")

    def clone(self, *, db: Database, dimensions: DimensionRecordStorageManager) -> ObsCoreLiveTableManager:
        """Make a copy of this manager bound to a new `Database` and
        dimension-record manager, propagating the query function if one has
        been set.
        """
        manager = ObsCoreLiveTableManager(
            db=db,
            table=self.table,
            schema=self.schema,
            universe=self.universe,
            config=self.config,
            dimensions=dimensions,
            # Current spatial plugins are safe to share without cloning -- they
            # are immutable and do not use their Database object outside of
            # 'initialize'.
            spatial_plugins=self.spatial_plugins,
            registry_schema_version=self._registry_schema_version,
        )
        if self._query_func is not None:
            manager.set_query_function(self._query_func)
        return manager

    @classmethod
    def initialize(
        cls,
        db: Database,
        context: StaticTablesContext,
        *,
        universe: DimensionUniverse,
        config: Mapping,
        datasets: type[DatasetRecordStorageManager],
        dimensions: DimensionRecordStorageManager,
        registry_schema_version: VersionTuple | None = None,
    ) -> ObsCoreTableManager:
        # Docstring inherited from base class.
        config_data = Config(config)
        obscore_config = ObsCoreManagerConfig.model_validate(config_data)

        # Instantiate all spatial plugins.
        spatial_plugins = SpatialObsCorePlugin.load_plugins(obscore_config.spatial_plugins, db)

        schema = ObsCoreSchema(config=obscore_config, spatial_plugins=spatial_plugins, datasets=datasets)

        # Generate table specification for main obscore table and let each
        # spatial plugin extend it with its own columns.
        table_spec = schema.table_spec
        for plugin in spatial_plugins:
            plugin.extend_table_spec(table_spec)
        # Use the extended local spec here (not ``schema.table_spec`` again):
        # if that property ever returned a fresh copy, the plugin-added
        # columns would silently be lost.
        table = context.addTable(obscore_config.table_name, table_spec)

        return ObsCoreLiveTableManager(
            db=db,
            table=table,
            schema=schema,
            universe=universe,
            config=obscore_config,
            dimensions=dimensions,
            spatial_plugins=spatial_plugins,
            registry_schema_version=registry_schema_version,
        )

    def _get_query_object(self) -> AbstractContextManager[Query]:
        """Return a context manager wrapping a `Query`, failing loudly if the
        query function has not been configured yet.
        """
        if self._query_func is None:
            raise AssertionError("set_query_function should have been called prior to this.")
        return self._query_func()

    def set_query_function(self, query_func: QueryFactoryFunction) -> None:
        """Set the function used to create `Query` objects for region
        lookups.
        """
        self._query_func = query_func

    def config_json(self) -> str:
        """Dump configuration in JSON format.

        Returns
        -------
        json : `str`
            Configuration serialized in JSON format.
        """
        return self.config.model_dump_json()

    @classmethod
    def currentVersions(cls) -> list[VersionTuple]:
        # Docstring inherited from base class.
        return [_VERSION]

    def add_datasets(self, refs: Iterable[DatasetRef]) -> int:
        # Docstring inherited from base class.

        # Only makes sense for RUN collection types
        if self.config.collection_type is not ConfigCollectionType.RUN:
            return 0

        obscore_refs: Iterable[DatasetRef]
        if self.run_patterns:
            # Check each dataset run against configured run list. We want to
            # reduce number of calls to _check_dataset_run, which may be
            # expensive. Normally references are grouped by run, if there are
            # multiple input references, they should have the same run.
            # Instead of just checking that, we group them by run again.
            refs_by_run: dict[str, list[DatasetRef]] = defaultdict(list)
            for ref in refs:
                # Record factory will filter dataset types, but to reduce
                # collection checks we also pre-filter it here.
                if ref.datasetType.name not in self.config.dataset_types:
                    continue

                assert ref.run is not None, "Run cannot be None"
                refs_by_run[ref.run].append(ref)

            good_refs: list[DatasetRef] = []
            for run, run_refs in refs_by_run.items():
                if not self._check_dataset_run(run):
                    continue
                good_refs.extend(run_refs)
            obscore_refs = good_refs

        else:
            # Take all refs, no collection check.
            obscore_refs = refs

        return self._populate(obscore_refs)

    def associate(self, refs: Iterable[DatasetRef], collection: CollectionRecord) -> int:
        # Docstring inherited from base class.

        # Only works when collection type is TAGGED
        if self.tagged_collection is None:
            return 0

        if collection.name == self.tagged_collection:
            return self._populate(refs)
        else:
            return 0

    def disassociate(self, refs: Iterable[DatasetRef], collection: CollectionRecord) -> int:
        # Docstring inherited from base class.

        # Only works when collection type is TAGGED
        if self.tagged_collection is None:
            return 0

        count = 0
        if collection.name == self.tagged_collection:
            # Sorting may improve performance
            dataset_ids = sorted(ref.id for ref in refs)
            if dataset_ids:
                fk_field = self.schema.dataset_fk
                assert fk_field is not None, "Cannot be None by construction"
                # There may be too many of them, do it in chunks.
                for ids in chunk_iterable(dataset_ids):
                    where = self.table.columns[fk_field.name].in_(ids)
                    count += self.db.deleteWhere(self.table, where)
        return count

    def _populate(self, refs: Iterable[DatasetRef]) -> int:
        """Populate obscore table with the data from given datasets."""
        records: list[Record] = []
        for ref in refs:
            record = self.record_factory(ref)
            if record is not None:
                records.append(record)

        if records:
            # Ignore potential conflicts with existing datasets.
            return self.db.ensure(self.table, *records, primary_key_only=True)
        else:
            return 0

    def _check_dataset_run(self, run: str) -> bool:
        """Check that specified run collection matches known patterns."""
        if not self.run_patterns:
            # Empty list means take anything.
            return True

        # Try each pattern in turn.
        return any(pattern.fullmatch(run) for pattern in self.run_patterns)

    def update_exposure_regions(self, instrument: str, region_data: Iterable[tuple[int, int, Region]]) -> int:
        # Docstring inherited from base class.
        instrument_column = self.schema.dimension_column("instrument")
        exposure_column = self.schema.dimension_column("exposure")
        detector_column = self.schema.dimension_column("detector")
        if instrument_column is None or exposure_column is None or detector_column is None:
            # Not all needed columns are in the table.
            return 0

        update_rows: list[Record] = []
        for exposure, detector, region in region_data:
            try:
                record = self.record_factory.make_spatial_records(region)
            except RegionTypeError as exc:
                warnings.warn(
                    f"Failed to convert region for exposure={exposure} detector={detector}: {exc}",
                    category=RegionTypeWarning,
                    stacklevel=find_outside_stacklevel("lsst.daf.butler"),
                )
                continue

            # These keys are placeholder names referenced by where_dict below;
            # ``Database.update`` matches them against the WHERE columns.
            record.update(
                {
                    "instrument_column": instrument,
                    "exposure_column": exposure,
                    "detector_column": detector,
                }
            )
            update_rows.append(record)

        where_dict: dict[str, str] = {
            instrument_column: "instrument_column",
            exposure_column: "exposure_column",
            detector_column: "detector_column",
        }

        count = self.db.update(self.table, where_dict, *update_rows)
        return count

    @contextmanager
    def query(
        self, columns: Iterable[str | sqlalchemy.sql.expression.ColumnElement] | None = None, /, **kwargs: Any
    ) -> Iterator[sqlalchemy.engine.CursorResult]:
        # Docstring inherited from base class.
        if columns is not None:
            # Accept column names or ready-made column elements interchangeably.
            column_elements: list[sqlalchemy.sql.ColumnElement] = []
            for column in columns:
                if isinstance(column, str):
                    column_elements.append(self.table.columns[column])
                else:
                    column_elements.append(column)
            query = sqlalchemy.sql.select(*column_elements).select_from(self.table)
        else:
            query = self.table.select()

        if kwargs:
            # Each keyword argument becomes an equality constraint, combined
            # with AND.
            query = query.where(
                sqlalchemy.sql.expression.and_(
                    *[self.table.columns[column] == value for column, value in kwargs.items()]
                )
            )
        with self.db.query(query) as result:
            yield result