Coverage for python/lsst/daf/butler/registry/obscore/_manager.py: 16% (172 statements)
« prev ^ index » next — coverage.py v7.2.7, created at 2023-06-15 09:13 +0000
# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ["ObsCoreLiveTableManager"]
26import json
27import re
28import warnings
29from collections import defaultdict
30from collections.abc import Collection, Iterable, Iterator, Mapping
31from contextlib import contextmanager
32from typing import TYPE_CHECKING, Any
34import sqlalchemy
35from lsst.daf.butler import Config, DataCoordinate, DatasetRef, DimensionRecordColumnTag, DimensionUniverse
36from lsst.daf.relation import Join
37from lsst.sphgeom import Region
38from lsst.utils.iteration import chunk_iterable
40from ..interfaces import ObsCoreTableManager, VersionTuple
41from ._config import ConfigCollectionType, ObsCoreManagerConfig
42from ._records import ExposureRegionFactory, Record, RecordFactory
43from ._schema import ObsCoreSchema
44from ._spatial import RegionTypeError, RegionTypeWarning, SpatialObsCorePlugin
46if TYPE_CHECKING:
47 from ..interfaces import (
48 CollectionRecord,
49 Database,
50 DatasetRecordStorageManager,
51 DimensionRecordStorageManager,
52 StaticTablesContext,
53 )
54 from ..queries import SqlQueryContext
# Schema version advertised by this manager via `currentVersions`.
_VERSION = VersionTuple(0, 0, 1)
class _ExposureRegionFactory(ExposureRegionFactory):
    """Find exposure region from a matching visit dimensions records."""

    def __init__(self, dimensions: DimensionRecordStorageManager):
        self.dimensions = dimensions
        self.universe = dimensions.universe
        self.exposure_dimensions = self.universe["exposure"].graph
        self.exposure_detector_dimensions = self.universe.extract(["exposure", "detector"])

    def exposure_region(self, dataId: DataCoordinate, context: SqlQueryContext) -> Region | None:
        # Docstring is inherited from a base class.
        # Begin with the visit_definition element, which provides the
        # exposure <-> visit mapping.
        relation = context.make_initial_relation()
        definition_storage = self.dimensions.get(self.universe["visit_definition"])
        if definition_storage is None:
            return None
        relation = definition_storage.join(relation, Join(), context)

        # Pick the element carrying the region: per-detector when the data ID
        # has a detector, whole-visit otherwise.
        if "detector" in dataId.names:
            region_element = "visit_detector_region"
            constraint_data_id = dataId.subset(self.exposure_detector_dimensions)
        else:
            region_element = "visit"
            constraint_data_id = dataId.subset(self.exposure_dimensions)
        region_storage = self.dimensions.get(self.universe[region_element])
        if region_storage is None:
            return None
        relation = region_storage.join(relation, Join(), context)
        region_tag = DimensionRecordColumnTag(region_element, "region")

        # Restrict to the requested exposure (and detector, when present).
        predicate = context.make_data_coordinate_predicate(constraint_data_id, full=False)
        relation = relation.with_rows_satisfying(predicate)

        # An exposure can belong to several visits; any single match will do,
        # so keep at most one row and return its region if there is one.
        for row in context.fetch_iterable(relation[:1]):
            return row[region_tag]
        return None
class ObsCoreLiveTableManager(ObsCoreTableManager):
    """A manager class for ObsCore table, implements methods for updating the
    records in that table.
    """

    def __init__(
        self,
        *,
        db: Database,
        table: sqlalchemy.schema.Table,
        schema: ObsCoreSchema,
        universe: DimensionUniverse,
        config: ObsCoreManagerConfig,
        dimensions: DimensionRecordStorageManager,
        spatial_plugins: Collection[SpatialObsCorePlugin],
        registry_schema_version: VersionTuple | None = None,
    ):
        """Construct the manager.

        Parameters
        ----------
        db : `Database`
            Database interface used for all reads and writes.
        table : `sqlalchemy.schema.Table`
            The obscore table managed by this instance.
        schema : `ObsCoreSchema`
            Schema description for the obscore table.
        universe : `DimensionUniverse`
            Registry dimension universe.
        config : `ObsCoreManagerConfig`
            Validated manager configuration.
        dimensions : `DimensionRecordStorageManager`
            Manager used to look up dimension records (exposure regions).
        spatial_plugins : `~collections.abc.Collection` of `SpatialObsCorePlugin`
            Plugins contributing spatial columns/records.
        registry_schema_version : `VersionTuple` or `None`, optional
            Registry schema version, forwarded to the base class.

        Raises
        ------
        ValueError
            Raised when ``collection_type`` has an unexpected value, or when a
            configured run-collection pattern is not a valid regex.
        """
        super().__init__(registry_schema_version=registry_schema_version)
        self.db = db
        self.table = table
        self.schema = schema
        self.universe = universe
        self.config = config
        self.spatial_plugins = spatial_plugins
        exposure_region_factory = _ExposureRegionFactory(dimensions)
        self.record_factory = RecordFactory(
            config, schema, universe, spatial_plugins, exposure_region_factory
        )
        # TAGGED mode: the single tagged collection mirrored into obscore.
        self.tagged_collection: str | None = None
        # RUN mode: compiled patterns selecting contributing run collections;
        # an empty list means every run is accepted (see _check_dataset_run).
        self.run_patterns: list[re.Pattern] = []
        if config.collection_type is ConfigCollectionType.TAGGED:
            assert (
                config.collections is not None and len(config.collections) == 1
            ), "Exactly one collection name required for tagged type."
            self.tagged_collection = config.collections[0]
        elif config.collection_type is ConfigCollectionType.RUN:
            if config.collections:
                for coll in config.collections:
                    try:
                        self.run_patterns.append(re.compile(coll))
                    except re.error as exc:
                        # Surface the offending pattern; chain the regex error.
                        raise ValueError(f"Failed to compile regex: {coll!r}") from exc
        else:
            raise ValueError(f"Unexpected value of collection_type: {config.collection_type}")

    @classmethod
    def initialize(
        cls,
        db: Database,
        context: StaticTablesContext,
        *,
        universe: DimensionUniverse,
        config: Mapping,
        datasets: type[DatasetRecordStorageManager],
        dimensions: DimensionRecordStorageManager,
        registry_schema_version: VersionTuple | None = None,
    ) -> ObsCoreTableManager:
        # Docstring inherited from base class.
        config_data = Config(config)
        obscore_config = ObsCoreManagerConfig.parse_obj(config_data)

        # Instantiate all spatial plugins.
        spatial_plugins = SpatialObsCorePlugin.load_plugins(obscore_config.spatial_plugins, db)

        schema = ObsCoreSchema(config=obscore_config, spatial_plugins=spatial_plugins, datasets=datasets)

        # Generate table specification for main obscore table.
        table_spec = schema.table_spec
        for plugin in spatial_plugins:
            plugin.extend_table_spec(table_spec)
        # NOTE(review): the loop above extends ``table_spec`` but
        # ``schema.table_spec`` is passed to addTable — presumably both refer
        # to the same underlying object; confirm, otherwise plugin-added
        # columns would be dropped here.
        table = context.addTable(obscore_config.table_name, schema.table_spec)

        return ObsCoreLiveTableManager(
            db=db,
            table=table,
            schema=schema,
            universe=universe,
            config=obscore_config,
            dimensions=dimensions,
            spatial_plugins=spatial_plugins,
            registry_schema_version=registry_schema_version,
        )

    def config_json(self) -> str:
        """Dump configuration in JSON format.

        Returns
        -------
        json : `str`
            Configuration serialized in JSON format.
        """
        return json.dumps(self.config.dict())

    @classmethod
    def currentVersions(cls) -> list[VersionTuple]:
        # Docstring inherited from base class.
        return [_VERSION]

    def add_datasets(self, refs: Iterable[DatasetRef], context: SqlQueryContext) -> int:
        # Docstring inherited from base class.

        # Only makes sense for RUN collection types
        if self.config.collection_type is not ConfigCollectionType.RUN:
            return 0

        obscore_refs: Iterable[DatasetRef]
        if self.run_patterns:
            # Check each dataset run against configured run list. We want to
            # reduce number of calls to _check_dataset_run, which may be
            # expensive. Normally references are grouped by run, if there are
            # multiple input references, they should have the same run.
            # Instead of just checking that, we group them by run again.
            refs_by_run: dict[str, list[DatasetRef]] = defaultdict(list)
            for ref in refs:
                # Record factory will filter dataset types, but to reduce
                # collection checks we also pre-filter it here.
                if ref.datasetType.name not in self.config.dataset_types:
                    continue

                assert ref.run is not None, "Run cannot be None"
                refs_by_run[ref.run].append(ref)

            good_refs: list[DatasetRef] = []
            for run, run_refs in refs_by_run.items():
                if not self._check_dataset_run(run):
                    continue
                good_refs.extend(run_refs)
            obscore_refs = good_refs

        else:
            # Take all refs, no collection check.
            obscore_refs = refs

        return self._populate(obscore_refs, context)

    def associate(
        self, refs: Iterable[DatasetRef], collection: CollectionRecord, context: SqlQueryContext
    ) -> int:
        # Docstring inherited from base class.

        # Only works when collection type is TAGGED
        if self.tagged_collection is None:
            return 0

        # Only the configured tagged collection is mirrored into obscore.
        if collection.name == self.tagged_collection:
            return self._populate(refs, context)
        else:
            return 0

    def disassociate(self, refs: Iterable[DatasetRef], collection: CollectionRecord) -> int:
        # Docstring inherited from base class.

        # Only works when collection type is TAGGED
        if self.tagged_collection is None:
            return 0

        count = 0
        if collection.name == self.tagged_collection:
            # Sorting may improve performance
            dataset_ids = sorted(ref.id for ref in refs)
            if dataset_ids:
                fk_field = self.schema.dataset_fk
                assert fk_field is not None, "Cannot be None by construction"
                # There may be too many of them, do it in chunks.
                for ids in chunk_iterable(dataset_ids):
                    where = self.table.columns[fk_field.name].in_(ids)
                    count += self.db.deleteWhere(self.table, where)
        return count

    def _populate(self, refs: Iterable[DatasetRef], context: SqlQueryContext) -> int:
        """Populate obscore table with the data from given datasets."""
        # Build one record per ref; the factory may return None for refs that
        # do not belong in the obscore table.
        records: list[Record] = []
        for ref in refs:
            record = self.record_factory(ref, context)
            if record is not None:
                records.append(record)

        if records:
            # Ignore potential conflicts with existing datasets.
            return self.db.ensure(self.table, *records, primary_key_only=True)
        else:
            return 0

    def _check_dataset_run(self, run: str) -> bool:
        """Check that specified run collection matches know patterns."""

        if not self.run_patterns:
            # Empty list means take anything.
            return True

        # Try each pattern in turn.
        return any(pattern.fullmatch(run) for pattern in self.run_patterns)

    def update_exposure_regions(self, instrument: str, region_data: Iterable[tuple[int, int, Region]]) -> int:
        # Docstring inherited from base class.
        instrument_column = self.schema.dimension_column("instrument")
        exposure_column = self.schema.dimension_column("exposure")
        detector_column = self.schema.dimension_column("detector")
        if instrument_column is None or exposure_column is None or detector_column is None:
            # Not all needed columns are in the table.
            return 0

        update_rows: list[Record] = []
        for exposure, detector, region in region_data:
            try:
                record = self.record_factory.make_spatial_records(region)
            except RegionTypeError as exc:
                # Region conversion is best-effort: warn and skip the row
                # rather than failing the whole update.
                warnings.warn(
                    f"Failed to convert region for exposure={exposure} detector={detector}: {exc}",
                    category=RegionTypeWarning,
                )
                continue

            record.update(
                {
                    "instrument_column": instrument,
                    "exposure_column": exposure,
                    "detector_column": detector,
                }
            )
            update_rows.append(record)

        # NOTE(review): the "*_column" strings in the rows above appear to be
        # placeholder keys that ``where_dict`` maps real column names onto,
        # per ``Database.update``'s calling convention — confirm against the
        # ``Database.update`` documentation.
        where_dict: dict[str, str] = {
            instrument_column: "instrument_column",
            exposure_column: "exposure_column",
            detector_column: "detector_column",
        }

        count = self.db.update(self.table, where_dict, *update_rows)
        return count

    @contextmanager
    def query(
        self, columns: Iterable[str | sqlalchemy.sql.expression.ColumnElement] | None = None, /, **kwargs: Any
    ) -> Iterator[sqlalchemy.engine.CursorResult]:
        # Docstring inherited from base class.
        if columns is not None:
            # Accept plain column names or ready-made column elements.
            column_elements: list[sqlalchemy.sql.ColumnElement] = []
            for column in columns:
                if isinstance(column, str):
                    column_elements.append(self.table.columns[column])
                else:
                    column_elements.append(column)
            query = sqlalchemy.sql.select(*column_elements).select_from(self.table)
        else:
            # No explicit columns: select everything.
            query = self.table.select()

        if kwargs:
            # Each keyword argument becomes an equality constraint, ANDed.
            query = query.where(
                sqlalchemy.sql.expression.and_(
                    *[self.table.columns[column] == value for column, value in kwargs.items()]
                )
            )
        with self.db.query(query) as result:
            yield result