Coverage for python/lsst/daf/butler/registry/obscore/_manager.py: 18% (143 statements)
# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

__all__ = ["ObsCoreLiveTableManager"]

import json
import re
import uuid
from collections import defaultdict
from collections.abc import Collection, Mapping
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Type, cast

import sqlalchemy
from lsst.daf.butler import Config, DataCoordinate, DatasetRef, DimensionRecordColumnTag, DimensionUniverse
from lsst.daf.relation import Join
from lsst.sphgeom import Region
from lsst.utils.iteration import chunk_iterable

from ..interfaces import ObsCoreTableManager, VersionTuple
from ._config import ConfigCollectionType, ObsCoreManagerConfig
from ._records import ExposureRegionFactory, Record, RecordFactory
from ._schema import ObsCoreSchema
from ._spatial import SpatialObsCorePlugin

if TYPE_CHECKING:
    from ..interfaces import (
        CollectionRecord,
        Database,
        DatasetRecordStorageManager,
        DimensionRecordStorageManager,
        StaticTablesContext,
    )
    from ..queries import SqlQueryContext


_VERSION = VersionTuple(0, 0, 1)


class _ExposureRegionFactory(ExposureRegionFactory):
    """Find an exposure region from matching visit dimension records."""

    def __init__(self, dimensions: DimensionRecordStorageManager):
        self.dimensions = dimensions
        self.universe = dimensions.universe
        self.exposure_dimensions = self.universe["exposure"].graph
        self.exposure_detector_dimensions = self.universe.extract(["exposure", "detector"])

    def exposure_region(self, dataId: DataCoordinate, context: SqlQueryContext) -> Optional[Region]:
        # Docstring is inherited from a base class.
        # Make a relation that starts with visit_definition (the mapping
        # between exposure and visit).
        relation = context.make_initial_relation()
        visit_definition_storage = self.dimensions.get(self.universe["visit_definition"])
        if visit_definition_storage is None:
            return None
        relation = visit_definition_storage.join(relation, Join(), context)
        # Join in a table with either visit+detector regions or visit regions.
        if "detector" in dataId.names:
            visit_detector_region_storage = self.dimensions.get(self.universe["visit_detector_region"])
            if visit_detector_region_storage is None:
                return None
            relation = visit_detector_region_storage.join(relation, Join(), context)
            constraint_data_id = dataId.subset(self.exposure_detector_dimensions)
            region_tag = DimensionRecordColumnTag("visit_detector_region", "region")
        else:
            visit_storage = self.dimensions.get(self.universe["visit"])
            if visit_storage is None:
                return None
            relation = visit_storage.join(relation, Join(), context)
            constraint_data_id = dataId.subset(self.exposure_dimensions)
            region_tag = DimensionRecordColumnTag("visit", "region")
        # Constrain the relation to match the given exposure and (if present)
        # detector IDs.
        relation = relation.with_rows_satisfying(
            context.make_data_coordinate_predicate(constraint_data_id, full=False)
        )
        # If we get more than one result (because the exposure belongs to
        # multiple visits), just pick an arbitrary one.
        relation = relation[:1]
        # Run the query and extract the region, if the query has any results.
        for row in context.fetch_iterable(relation):
            return row[region_tag]
        return None
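
# A minimal usage sketch for the factory above; ``dimensions``, ``data_id``
# and ``query_context`` are hypothetical stand-ins, not objects defined in
# this module:
#
#     factory = _ExposureRegionFactory(dimensions)
#     region = factory.exposure_region(data_id, query_context)
#     if region is None:
#         ...  # no matching visit_definition / region record was found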


class ObsCoreLiveTableManager(ObsCoreTableManager):
    """A manager class for the ObsCore table; implements methods for updating
    the records in that table.
    """

    def __init__(
        self,
        *,
        db: Database,
        table: sqlalchemy.schema.Table,
        schema: ObsCoreSchema,
        universe: DimensionUniverse,
        config: ObsCoreManagerConfig,
        dimensions: DimensionRecordStorageManager,
        spatial_plugins: Collection[SpatialObsCorePlugin],
    ):
        self.db = db
        self.table = table
        self.schema = schema
        self.universe = universe
        self.config = config
        self.spatial_plugins = spatial_plugins
        exposure_region_factory = _ExposureRegionFactory(dimensions)
        self.record_factory = RecordFactory(
            config, schema, universe, spatial_plugins, exposure_region_factory
        )
        self.tagged_collection: Optional[str] = None
        self.run_patterns: list[re.Pattern] = []
        if config.collection_type is ConfigCollectionType.TAGGED:
            assert (
                config.collections is not None and len(config.collections) == 1
            ), "Exactly one collection name is required for the TAGGED collection type."
            self.tagged_collection = config.collections[0]
        elif config.collection_type is ConfigCollectionType.RUN:
            if config.collections:
                for coll in config.collections:
                    try:
                        self.run_patterns.append(re.compile(coll))
                    except re.error as exc:
                        raise ValueError(f"Failed to compile regex: {coll!r}") from exc
        else:
            raise ValueError(f"Unexpected value of collection_type: {config.collection_type}")
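
    # For reference, a sketch of the two collection_type modes this
    # constructor accepts (collection names and enum spellings here are
    # illustrative assumptions, not values taken from the config schema):
    #
    #     # TAGGED: exactly one collection name, matched literally.
    #     {"collection_type": "TAGGED", "collections": ["obscore/tagged"]}
    #
    #     # RUN: zero or more regexes; an empty list accepts any run.
    #     {"collection_type": "RUN", "collections": ["HSC/runs/.*"]}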

    @classmethod
    def initialize(
        cls,
        db: Database,
        context: StaticTablesContext,
        *,
        universe: DimensionUniverse,
        config: Mapping,
        datasets: Type[DatasetRecordStorageManager],
        dimensions: DimensionRecordStorageManager,
    ) -> ObsCoreTableManager:
        # Docstring inherited from base class.
        config_data = Config(config)
        obscore_config = ObsCoreManagerConfig.parse_obj(config_data)

        # Instantiate all spatial plugins.
        spatial_plugins = SpatialObsCorePlugin.load_plugins(obscore_config.spatial_plugins, db)

        schema = ObsCoreSchema(config=obscore_config, spatial_plugins=spatial_plugins, datasets=datasets)

        # Generate the table specification for the main obscore table,
        # allowing each spatial plugin to extend it.
        table_spec = schema.table_spec
        for plugin in spatial_plugins:
            plugin.extend_table_spec(table_spec)
        table = context.addTable(obscore_config.table_name, table_spec)

        return ObsCoreLiveTableManager(
            db=db,
            table=table,
            schema=schema,
            universe=universe,
            config=obscore_config,
            dimensions=dimensions,
            spatial_plugins=spatial_plugins,
        )

    def config_json(self) -> str:
        """Dump the configuration in JSON format.

        Returns
        -------
        json : `str`
            Configuration serialized in JSON format.
        """
        return json.dumps(self.config.dict())
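
    # The JSON shape mirrors the pydantic model; a hypothetical (not
    # verified) example using only keys referenced elsewhere in this module:
    #
    #     >>> manager.config_json()
    #     '{"table_name": "obscore", "collection_type": "RUN", "collections": ["HSC/runs/.*"], ...}'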

    @classmethod
    def currentVersion(cls) -> Optional[VersionTuple]:
        # Docstring inherited from base class.
        return _VERSION

    def schemaDigest(self) -> Optional[str]:
        # Docstring inherited from base class.
        return None

    def add_datasets(self, refs: Iterable[DatasetRef], context: SqlQueryContext) -> int:
        # Docstring inherited from base class.

        # Only makes sense for RUN collection types.
        if self.config.collection_type is not ConfigCollectionType.RUN:
            return 0

        obscore_refs: Iterable[DatasetRef]
        if self.run_patterns:
            # Check each dataset's run against the configured run list. We
            # want to reduce the number of calls to _check_dataset_run, which
            # may be expensive. Normally references are grouped by run: if
            # there are multiple input references, they should have the same
            # run. Instead of relying on that, we group them by run again.
            refs_by_run: Dict[str, List[DatasetRef]] = defaultdict(list)
            for ref in refs:
                # The record factory will filter dataset types, but to reduce
                # collection checks we also pre-filter here.
                if ref.datasetType.name not in self.config.dataset_types:
                    continue

                assert ref.run is not None, "Run cannot be None"
                refs_by_run[ref.run].append(ref)

            good_refs: List[DatasetRef] = []
            for run, run_refs in refs_by_run.items():
                if not self._check_dataset_run(run):
                    continue
                good_refs.extend(run_refs)
            obscore_refs = good_refs

        else:
            # Take all refs, no collection check.
            obscore_refs = refs

        return self._populate(obscore_refs, context)

    def associate(
        self, refs: Iterable[DatasetRef], collection: CollectionRecord, context: SqlQueryContext
    ) -> int:
        # Docstring inherited from base class.

        # Only works when the collection type is TAGGED.
        if self.tagged_collection is None:
            return 0

        if collection.name == self.tagged_collection:
            return self._populate(refs, context)
        else:
            return 0

    def disassociate(self, refs: Iterable[DatasetRef], collection: CollectionRecord) -> int:
        # Docstring inherited from base class.

        # Only works when the collection type is TAGGED.
        if self.tagged_collection is None:
            return 0

        count = 0
        if collection.name == self.tagged_collection:
            # Sorting may improve performance.
            dataset_ids = sorted(cast(uuid.UUID, ref.id) for ref in refs)
            if dataset_ids:
                fk_field = self.schema.dataset_fk
                assert fk_field is not None, "Cannot be None by construction"
                # There may be too many IDs for a single IN clause; delete
                # them in chunks.
                for ids in chunk_iterable(dataset_ids):
                    where = self.table.columns[fk_field.name].in_(ids)
                    count += self.db.deleteWhere(self.table, where)
        return count
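
    # The chunked delete above relies on lsst.utils.iteration.chunk_iterable,
    # which yields an iterable in fixed-size batches; an illustrative doctest
    # (chunk size 2 passed positionally):
    #
    #     >>> from lsst.utils.iteration import chunk_iterable
    #     >>> [list(chunk) for chunk in chunk_iterable(range(5), 2)]
    #     [[0, 1], [2, 3], [4]]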

    def _populate(self, refs: Iterable[DatasetRef], context: SqlQueryContext) -> int:
        """Populate the obscore table with the data from the given datasets."""
        records: List[Record] = []
        for ref in refs:
            record = self.record_factory(ref, context)
            if record is not None:
                records.append(record)

        if records:
            # Ignore potential conflicts with existing datasets.
            return self.db.ensure(self.table, *records, primary_key_only=True)
        else:
            return 0

    def _check_dataset_run(self, run: str) -> bool:
        """Check that the specified run collection matches known patterns."""
        if not self.run_patterns:
            # An empty list means accept anything.
            return True

        # Try each pattern in turn.
        return any(pattern.fullmatch(run) for pattern in self.run_patterns)
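
# Because _check_dataset_run uses re.fullmatch, a configured pattern must
# cover the entire run name. A quick self-contained illustration with
# hypothetical run names:
#
#     >>> import re
#     >>> pattern = re.compile("HSC/runs/.*")
#     >>> bool(pattern.fullmatch("HSC/runs/RC2/w_2023_01"))
#     True
#     >>> bool(pattern.fullmatch("prefix/HSC/runs/RC2"))
#     False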