Coverage for python/lsst/daf/butler/registry/obscore/_manager.py: 18%

172 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-12-05 11:07 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27 

28from __future__ import annotations 

29 

30__all__ = ["ObsCoreLiveTableManager"] 

31 

32import re 

33import warnings 

34from collections import defaultdict 

35from collections.abc import Collection, Iterable, Iterator, Mapping 

36from contextlib import contextmanager 

37from typing import TYPE_CHECKING, Any 

38 

39import sqlalchemy 

40from lsst.daf.butler import Config, DataCoordinate, DatasetRef, DimensionRecordColumnTag, DimensionUniverse 

41from lsst.daf.relation import Join 

42from lsst.sphgeom import Region 

43from lsst.utils.introspection import find_outside_stacklevel 

44from lsst.utils.iteration import chunk_iterable 

45 

46from ..interfaces import ObsCoreTableManager, VersionTuple 

47from ._config import ConfigCollectionType, ObsCoreManagerConfig 

48from ._records import ExposureRegionFactory, Record, RecordFactory 

49from ._schema import ObsCoreSchema 

50from ._spatial import RegionTypeError, RegionTypeWarning, SpatialObsCorePlugin 

51 

52if TYPE_CHECKING: 

53 from ..interfaces import ( 

54 CollectionRecord, 

55 Database, 

56 DatasetRecordStorageManager, 

57 DimensionRecordStorageManager, 

58 StaticTablesContext, 

59 ) 

60 from ..queries import SqlQueryContext 

61 

62_VERSION = VersionTuple(0, 0, 1) 

63 

64 

class _ExposureRegionFactory(ExposureRegionFactory):
    """Find exposure region from a matching visit dimensions records."""

    def __init__(self, dimensions: DimensionRecordStorageManager):
        self.dimensions = dimensions
        self.universe = dimensions.universe
        self.exposure_dimensions = self.universe["exposure"].minimal_group
        self.exposure_detector_dimensions = self.universe.conform(["exposure", "detector"])

    def exposure_region(self, dataId: DataCoordinate, context: SqlQueryContext) -> Region | None:
        # Docstring is inherited from a base class.
        # Begin with the visit_definition element, which provides the
        # exposure-to-visit mapping.
        relation = context.make_initial_relation()
        definition_storage = self.dimensions.get(self.universe["visit_definition"])
        if definition_storage is None:
            return None
        relation = definition_storage.join(relation, Join(), context)
        # Decide which element holds the region we want: per-detector visit
        # regions when the data ID includes a detector, whole-visit regions
        # otherwise.
        has_detector = "detector" in dataId.dimensions
        element_name = "visit_detector_region" if has_detector else "visit"
        region_storage = self.dimensions.get(self.universe[element_name])
        if region_storage is None:
            return None
        relation = region_storage.join(relation, Join(), context)
        if has_detector:
            constraint_data_id = dataId.subset(self.exposure_detector_dimensions)
        else:
            constraint_data_id = dataId.subset(self.exposure_dimensions)
        region_tag = DimensionRecordColumnTag(element_name, "region")
        # Restrict rows to the given exposure and (when present) detector.
        relation = relation.with_rows_satisfying(
            context.make_data_coordinate_predicate(constraint_data_id, full=False)
        )
        # An exposure may belong to more than one visit; an arbitrary single
        # match is sufficient.
        relation = relation[:1]
        # Execute and pull out the region from the first row, if any.
        first_row = next(iter(context.fetch_iterable(relation)), None)
        return None if first_row is None else first_row[region_tag]

110 

111 

class ObsCoreLiveTableManager(ObsCoreTableManager):
    """A manager class for ObsCore table, implements methods for updating the
    records in that table.

    Parameters
    ----------
    db : `Database`
        Interface to the underlying database engine.
    table : `sqlalchemy.schema.Table`
        The live obscore table managed by this instance.
    schema : `ObsCoreSchema`
        Schema description for the obscore table.
    universe : `DimensionUniverse`
        All dimensions known to the registry.
    config : `ObsCoreManagerConfig`
        Validated obscore manager configuration.
    dimensions : `DimensionRecordStorageManager`
        Dimension record manager, used to look up exposure regions.
    spatial_plugins : `~collections.abc.Collection` [ `SpatialObsCorePlugin` ]
        Spatial plugins that contribute extra columns to the table.
    registry_schema_version : `VersionTuple` or `None`, optional
        Schema version of the registry this manager belongs to.
    """

    def __init__(
        self,
        *,
        db: Database,
        table: sqlalchemy.schema.Table,
        schema: ObsCoreSchema,
        universe: DimensionUniverse,
        config: ObsCoreManagerConfig,
        dimensions: DimensionRecordStorageManager,
        spatial_plugins: Collection[SpatialObsCorePlugin],
        registry_schema_version: VersionTuple | None = None,
    ):
        super().__init__(registry_schema_version=registry_schema_version)
        self.db = db
        self.table = table
        self.schema = schema
        self.universe = universe
        self.config = config
        self.spatial_plugins = spatial_plugins
        exposure_region_factory = _ExposureRegionFactory(dimensions)
        self.record_factory = RecordFactory(
            config, schema, universe, spatial_plugins, exposure_region_factory
        )
        # Name of the single TAGGED collection to monitor; None when the
        # manager is configured for RUN collections instead.
        self.tagged_collection: str | None = None
        # Compiled regexes matching RUN collection names; an empty list means
        # every run is accepted (see _check_dataset_run).
        self.run_patterns: list[re.Pattern] = []
        if config.collection_type is ConfigCollectionType.TAGGED:
            assert (
                config.collections is not None and len(config.collections) == 1
            ), "Exactly one collection name required for tagged type."
            self.tagged_collection = config.collections[0]
        elif config.collection_type is ConfigCollectionType.RUN:
            if config.collections:
                for coll in config.collections:
                    try:
                        self.run_patterns.append(re.compile(coll))
                    except re.error as exc:
                        raise ValueError(f"Failed to compile regex: {coll!r}") from exc
        else:
            raise ValueError(f"Unexpected value of collection_type: {config.collection_type}")

    @classmethod
    def initialize(
        cls,
        db: Database,
        context: StaticTablesContext,
        *,
        universe: DimensionUniverse,
        config: Mapping,
        datasets: type[DatasetRecordStorageManager],
        dimensions: DimensionRecordStorageManager,
        registry_schema_version: VersionTuple | None = None,
    ) -> ObsCoreTableManager:
        # Docstring inherited from base class.
        config_data = Config(config)
        obscore_config = ObsCoreManagerConfig.model_validate(config_data)

        # Instantiate all spatial plugins.
        spatial_plugins = SpatialObsCorePlugin.load_plugins(obscore_config.spatial_plugins, db)

        schema = ObsCoreSchema(config=obscore_config, spatial_plugins=spatial_plugins, datasets=datasets)

        # Generate table specification for main obscore table.
        table_spec = schema.table_spec
        for plugin in spatial_plugins:
            plugin.extend_table_spec(table_spec)
        # NOTE(review): `schema.table_spec` here appears to be the same object
        # as the local `table_spec` extended above (it is not re-created), so
        # the plugin extensions are included; using the local name would be
        # clearer — confirm before changing.
        table = context.addTable(obscore_config.table_name, schema.table_spec)

        return ObsCoreLiveTableManager(
            db=db,
            table=table,
            schema=schema,
            universe=universe,
            config=obscore_config,
            dimensions=dimensions,
            spatial_plugins=spatial_plugins,
            registry_schema_version=registry_schema_version,
        )

    def config_json(self) -> str:
        """Dump configuration in JSON format.

        Returns
        -------
        json : `str`
            Configuration serialized in JSON format.
        """
        return self.config.model_dump_json()

    @classmethod
    def currentVersions(cls) -> list[VersionTuple]:
        # Docstring inherited from base class.
        return [_VERSION]

    def add_datasets(self, refs: Iterable[DatasetRef], context: SqlQueryContext) -> int:
        # Docstring inherited from base class.

        # Only makes sense for RUN collection types
        if self.config.collection_type is not ConfigCollectionType.RUN:
            return 0

        obscore_refs: Iterable[DatasetRef]
        if self.run_patterns:
            # Check each dataset run against configured run list. We want to
            # reduce number of calls to _check_dataset_run, which may be
            # expensive. Normally references are grouped by run, if there are
            # multiple input references, they should have the same run.
            # Instead of just checking that, we group them by run again.
            refs_by_run: dict[str, list[DatasetRef]] = defaultdict(list)
            for ref in refs:
                # Record factory will filter dataset types, but to reduce
                # collection checks we also pre-filter it here.
                if ref.datasetType.name not in self.config.dataset_types:
                    continue

                assert ref.run is not None, "Run cannot be None"
                refs_by_run[ref.run].append(ref)

            good_refs: list[DatasetRef] = []
            for run, run_refs in refs_by_run.items():
                if not self._check_dataset_run(run):
                    continue
                good_refs.extend(run_refs)
            obscore_refs = good_refs

        else:
            # Take all refs, no collection check.
            obscore_refs = refs

        return self._populate(obscore_refs, context)

    def associate(
        self, refs: Iterable[DatasetRef], collection: CollectionRecord, context: SqlQueryContext
    ) -> int:
        # Docstring inherited from base class.

        # Only works when collection type is TAGGED
        if self.tagged_collection is None:
            return 0

        if collection.name == self.tagged_collection:
            return self._populate(refs, context)
        else:
            return 0

    def disassociate(self, refs: Iterable[DatasetRef], collection: CollectionRecord) -> int:
        # Docstring inherited from base class.

        # Only works when collection type is TAGGED
        if self.tagged_collection is None:
            return 0

        count = 0
        if collection.name == self.tagged_collection:
            # Sorting may improve performance
            dataset_ids = sorted(ref.id for ref in refs)
            if dataset_ids:
                fk_field = self.schema.dataset_fk
                assert fk_field is not None, "Cannot be None by construction"
                # There may be too many of them, do it in chunks.
                for ids in chunk_iterable(dataset_ids):
                    where = self.table.columns[fk_field.name].in_(ids)
                    count += self.db.deleteWhere(self.table, where)
        return count

    def _populate(self, refs: Iterable[DatasetRef], context: SqlQueryContext) -> int:
        """Populate obscore table with the data from given datasets.

        Parameters
        ----------
        refs : `~collections.abc.Iterable` [ `DatasetRef` ]
            Dataset references to convert into obscore records; references
            rejected by the record factory are silently skipped.
        context : `SqlQueryContext`
            Context used to run any needed follow-up queries.

        Returns
        -------
        count : `int`
            Number of rows returned by the database insert.
        """
        records: list[Record] = []
        for ref in refs:
            record = self.record_factory(ref, context)
            if record is not None:
                records.append(record)

        if records:
            # Ignore potential conflicts with existing datasets.
            return self.db.ensure(self.table, *records, primary_key_only=True)
        else:
            return 0

    def _check_dataset_run(self, run: str) -> bool:
        """Check that specified run collection matches known patterns."""
        if not self.run_patterns:
            # Empty list means take anything.
            return True

        # Try each pattern in turn.
        return any(pattern.fullmatch(run) for pattern in self.run_patterns)

    def update_exposure_regions(self, instrument: str, region_data: Iterable[tuple[int, int, Region]]) -> int:
        # Docstring inherited from base class.
        instrument_column = self.schema.dimension_column("instrument")
        exposure_column = self.schema.dimension_column("exposure")
        detector_column = self.schema.dimension_column("detector")
        if instrument_column is None or exposure_column is None or detector_column is None:
            # Not all needed columns are in the table.
            return 0

        update_rows: list[Record] = []
        for exposure, detector, region in region_data:
            try:
                record = self.record_factory.make_spatial_records(region)
            except RegionTypeError as exc:
                # Best-effort: an unconvertible region is reported as a
                # warning and that exposure/detector is skipped.
                warnings.warn(
                    f"Failed to convert region for exposure={exposure} detector={detector}: {exc}",
                    category=RegionTypeWarning,
                    stacklevel=find_outside_stacklevel("lsst.daf.butler"),
                )
                continue

            # The "*_column" keys are placeholder names referenced by
            # where_dict below, not actual table column names.
            record.update(
                {
                    "instrument_column": instrument,
                    "exposure_column": exposure,
                    "detector_column": detector,
                }
            )
            update_rows.append(record)

        # Maps table column name to the key in each row dict that supplies
        # the value for the WHERE clause.
        where_dict: dict[str, str] = {
            instrument_column: "instrument_column",
            exposure_column: "exposure_column",
            detector_column: "detector_column",
        }

        count = self.db.update(self.table, where_dict, *update_rows)
        return count

    @contextmanager
    def query(
        self, columns: Iterable[str | sqlalchemy.sql.expression.ColumnElement] | None = None, /, **kwargs: Any
    ) -> Iterator[sqlalchemy.engine.CursorResult]:
        # Docstring inherited from base class.
        if columns is not None:
            # Resolve string column names against the table; pass SQLAlchemy
            # column elements through unchanged.
            column_elements: list[sqlalchemy.sql.ColumnElement] = []
            for column in columns:
                if isinstance(column, str):
                    column_elements.append(self.table.columns[column])
                else:
                    column_elements.append(column)
            query = sqlalchemy.sql.select(*column_elements).select_from(self.table)
        else:
            query = self.table.select()

        # Each keyword argument becomes an equality constraint, ANDed
        # together.
        if kwargs:
            query = query.where(
                sqlalchemy.sql.expression.and_(
                    *[self.table.columns[column] == value for column, value in kwargs.items()]
                )
            )
        with self.db.query(query) as result:
            yield result