Coverage for python/lsst/daf/butler/registry/obscore/_manager.py: 18%

169 statements  

coverage.py v7.4.0, created at 2024-01-16 10:44 +0000

# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

__all__ = ["ObsCoreLiveTableManager"]

import re
import warnings
from collections import defaultdict
from collections.abc import Collection, Iterable, Iterator, Mapping
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any

import sqlalchemy
from lsst.daf.butler import Config, DataCoordinate, DatasetRef, DimensionRecordColumnTag, DimensionUniverse
from lsst.daf.relation import Join
from lsst.sphgeom import Region
from lsst.utils.introspection import find_outside_stacklevel
from lsst.utils.iteration import chunk_iterable

from ..interfaces import ObsCoreTableManager, VersionTuple
from ._config import ConfigCollectionType, ObsCoreManagerConfig
from ._records import ExposureRegionFactory, Record, RecordFactory
from ._schema import ObsCoreSchema
from ._spatial import RegionTypeError, RegionTypeWarning, SpatialObsCorePlugin

if TYPE_CHECKING:
    from ..interfaces import (
        CollectionRecord,
        Database,
        DatasetRecordStorageManager,
        DimensionRecordStorageManager,
        StaticTablesContext,
    )
    from ..queries import SqlQueryContext

_VERSION = VersionTuple(0, 0, 1)


class _ExposureRegionFactory(ExposureRegionFactory):
    """Find the region of an exposure from matching visit dimension records.

    Parameters
    ----------
    dimensions : `DimensionRecordStorageManager`
        The dimension records storage manager.
    """

    def __init__(self, dimensions: DimensionRecordStorageManager):
        self.dimensions = dimensions
        self.universe = dimensions.universe
        self.exposure_dimensions = self.universe["exposure"].minimal_group
        self.exposure_detector_dimensions = self.universe.conform(["exposure", "detector"])

    def exposure_region(self, dataId: DataCoordinate, context: SqlQueryContext) -> Region | None:
        # Docstring is inherited from a base class.
        # Make a relation that starts with visit_definition (the mapping
        # between exposure and visit).
        relation = context.make_initial_relation()
        if "visit_definition" not in self.universe.elements.names:
            return None
        relation = self.dimensions.join("visit_definition", relation, Join(), context)
        # Join in a table with either visit+detector regions or visit regions.
        if "detector" in dataId.dimensions:
            if "visit_detector_region" not in self.universe:
                return None
            relation = self.dimensions.join("visit_detector_region", relation, Join(), context)
            constraint_data_id = dataId.subset(self.exposure_detector_dimensions)
            region_tag = DimensionRecordColumnTag("visit_detector_region", "region")
        else:
            if "visit" not in self.universe:
                return None
            relation = self.dimensions.join("visit", relation, Join(), context)
            constraint_data_id = dataId.subset(self.exposure_dimensions)
            region_tag = DimensionRecordColumnTag("visit", "region")
        # Constrain the relation to match the given exposure and (if present)
        # detector IDs.
        relation = relation.with_rows_satisfying(
            context.make_data_coordinate_predicate(constraint_data_id, full=False)
        )
        # If we get more than one result (because the exposure belongs to
        # multiple visits), just pick an arbitrary one.
        relation = relation[:1]
        # Run the query and extract the region, if the query has any results.
        for row in context.fetch_iterable(relation):
            return row[region_tag]
        return None
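
    # For intuition, the relation built above corresponds roughly to SQL of
    # the following shape in the detector case (an illustrative sketch only;
    # the real table and column names are produced by the relation machinery):
    #
    #   SELECT vdr.region
    #   FROM visit_definition AS vd
    #     JOIN visit_detector_region AS vdr
    #       ON vd.instrument = vdr.instrument AND vd.visit = vdr.visit
    #   WHERE vd.instrument = :instrument
    #     AND vd.exposure = :exposure
    #     AND vdr.detector = :detector
    #   LIMIT 1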


class ObsCoreLiveTableManager(ObsCoreTableManager):
    """A manager class for the ObsCore table that implements methods for
    updating the records in that table.

    Parameters
    ----------
    db : `Database`
        The active database.
    table : `sqlalchemy.schema.Table`
        The ObsCore table.
    schema : `ObsCoreSchema`
        The relevant schema.
    universe : `DimensionUniverse`
        The dimension universe.
    config : `ObsCoreManagerConfig`
        The config controlling the manager.
    dimensions : `DimensionRecordStorageManager`
        The storage manager for the dimension records.
    spatial_plugins : `~collections.abc.Collection` of `SpatialObsCorePlugin`
        Spatial plugins.
    registry_schema_version : `VersionTuple` or `None`, optional
        Version of the registry schema.
    """

    def __init__(
        self,
        *,
        db: Database,
        table: sqlalchemy.schema.Table,
        schema: ObsCoreSchema,
        universe: DimensionUniverse,
        config: ObsCoreManagerConfig,
        dimensions: DimensionRecordStorageManager,
        spatial_plugins: Collection[SpatialObsCorePlugin],
        registry_schema_version: VersionTuple | None = None,
    ):
        super().__init__(registry_schema_version=registry_schema_version)
        self.db = db
        self.table = table
        self.schema = schema
        self.universe = universe
        self.config = config
        self.spatial_plugins = spatial_plugins
        exposure_region_factory = _ExposureRegionFactory(dimensions)
        self.record_factory = RecordFactory(
            config, schema, universe, spatial_plugins, exposure_region_factory
        )
        self.tagged_collection: str | None = None
        self.run_patterns: list[re.Pattern] = []
        if config.collection_type is ConfigCollectionType.TAGGED:
            assert (
                config.collections is not None and len(config.collections) == 1
            ), "Exactly one collection name is required for the TAGGED collection type."
            self.tagged_collection = config.collections[0]
        elif config.collection_type is ConfigCollectionType.RUN:
            if config.collections:
                for coll in config.collections:
                    try:
                        self.run_patterns.append(re.compile(coll))
                    except re.error as exc:
                        raise ValueError(f"Failed to compile regex: {coll!r}") from exc
        else:
            raise ValueError(f"Unexpected value of collection_type: {config.collection_type}")
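
    # A hypothetical configuration sketch for the two collection types handled
    # above (field names follow ObsCoreManagerConfig; the values are
    # illustrative only):
    #
    #   collection_type: TAGGED
    #   collections: ["obscore-tagged"]     # exactly one collection name
    #
    #   collection_type: RUN
    #   collections: ["HSC/runs/.*"]        # regexes matched against run names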


    @classmethod
    def initialize(
        cls,
        db: Database,
        context: StaticTablesContext,
        *,
        universe: DimensionUniverse,
        config: Mapping,
        datasets: type[DatasetRecordStorageManager],
        dimensions: DimensionRecordStorageManager,
        registry_schema_version: VersionTuple | None = None,
    ) -> ObsCoreTableManager:
        # Docstring inherited from base class.
        config_data = Config(config)
        obscore_config = ObsCoreManagerConfig.model_validate(config_data)

        # Instantiate all spatial plugins.
        spatial_plugins = SpatialObsCorePlugin.load_plugins(obscore_config.spatial_plugins, db)

        schema = ObsCoreSchema(config=obscore_config, spatial_plugins=spatial_plugins, datasets=datasets)

        # Generate the table specification for the main obscore table, letting
        # each spatial plugin extend it.
        table_spec = schema.table_spec
        for plugin in spatial_plugins:
            plugin.extend_table_spec(table_spec)
        table = context.addTable(obscore_config.table_name, table_spec)

        return ObsCoreLiveTableManager(
            db=db,
            table=table,
            schema=schema,
            universe=universe,
            config=obscore_config,
            dimensions=dimensions,
            spatial_plugins=spatial_plugins,
            registry_schema_version=registry_schema_version,
        )


    def config_json(self) -> str:
        """Dump configuration in JSON format.

        Returns
        -------
        json : `str`
            Configuration serialized in JSON format.
        """
        return self.config.model_dump_json()

    @classmethod
    def currentVersions(cls) -> list[VersionTuple]:
        # Docstring inherited from base class.
        return [_VERSION]


    def add_datasets(self, refs: Iterable[DatasetRef], context: SqlQueryContext) -> int:
        # Docstring inherited from base class.

        # Only makes sense for the RUN collection type.
        if self.config.collection_type is not ConfigCollectionType.RUN:
            return 0

        obscore_refs: Iterable[DatasetRef]
        if self.run_patterns:
            # Check each dataset's run against the configured run patterns.
            # We want to reduce the number of calls to _check_dataset_run,
            # which may be expensive. Normally references are grouped by run;
            # if there are multiple input references they should all have the
            # same run. Instead of just checking that, we group them by run
            # again.
            refs_by_run: dict[str, list[DatasetRef]] = defaultdict(list)
            for ref in refs:
                # The record factory will filter dataset types, but to reduce
                # collection checks we also pre-filter here.
                if ref.datasetType.name not in self.config.dataset_types:
                    continue

                assert ref.run is not None, "Run cannot be None"
                refs_by_run[ref.run].append(ref)

            good_refs: list[DatasetRef] = []
            for run, run_refs in refs_by_run.items():
                if not self._check_dataset_run(run):
                    continue
                good_refs.extend(run_refs)
            obscore_refs = good_refs

        else:
            # Take all refs, no collection check.
            obscore_refs = refs

        return self._populate(obscore_refs, context)
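
    # For example (with hypothetical run names): given run_patterns compiled
    # from ["HSC/runs/.*"], refs from the runs "HSC/runs/RC2" and
    # "u/someone/test" are grouped per run, and only the "HSC/runs/RC2" group
    # passes the _check_dataset_run filter.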


    def associate(
        self, refs: Iterable[DatasetRef], collection: CollectionRecord, context: SqlQueryContext
    ) -> int:
        # Docstring inherited from base class.

        # Only works when the collection type is TAGGED.
        if self.tagged_collection is None:
            return 0

        if collection.name == self.tagged_collection:
            return self._populate(refs, context)
        else:
            return 0

    def disassociate(self, refs: Iterable[DatasetRef], collection: CollectionRecord) -> int:
        # Docstring inherited from base class.

        # Only works when the collection type is TAGGED.
        if self.tagged_collection is None:
            return 0

        count = 0
        if collection.name == self.tagged_collection:
            # Sorting may improve performance.
            dataset_ids = sorted(ref.id for ref in refs)
            if dataset_ids:
                fk_field = self.schema.dataset_fk
                assert fk_field is not None, "Cannot be None by construction"
                # There may be too many IDs for a single IN clause; delete in
                # chunks.
                for ids in chunk_iterable(dataset_ids):
                    where = self.table.columns[fk_field.name].in_(ids)
                    count += self.db.deleteWhere(self.table, where)
        return count


    def _populate(self, refs: Iterable[DatasetRef], context: SqlQueryContext) -> int:
        """Populate the obscore table with the data from the given datasets."""
        records: list[Record] = []
        for ref in refs:
            record = self.record_factory(ref, context)
            if record is not None:
                records.append(record)

        if records:
            # Ignore potential conflicts with existing datasets.
            return self.db.ensure(self.table, *records, primary_key_only=True)
        else:
            return 0

    def _check_dataset_run(self, run: str) -> bool:
        """Check that the specified run collection matches the known patterns."""
        if not self.run_patterns:
            # An empty list means accept anything.
            return True

        # Try each pattern in turn.
        return any(pattern.fullmatch(run) for pattern in self.run_patterns)
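
    # Note that fullmatch anchors the pattern at both ends: a hypothetical
    # pattern "HSC/runs/.*" accepts "HSC/runs/RC2" but rejects
    # "other/HSC/runs/RC2", which re.search would have accepted.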


    def update_exposure_regions(self, instrument: str, region_data: Iterable[tuple[int, int, Region]]) -> int:
        # Docstring inherited from base class.
        instrument_column = self.schema.dimension_column("instrument")
        exposure_column = self.schema.dimension_column("exposure")
        detector_column = self.schema.dimension_column("detector")
        if instrument_column is None or exposure_column is None or detector_column is None:
            # Not all of the needed columns are in the table.
            return 0

        update_rows: list[Record] = []
        for exposure, detector, region in region_data:
            try:
                record = self.record_factory.make_spatial_records(region)
            except RegionTypeError as exc:
                warnings.warn(
                    f"Failed to convert region for exposure={exposure} detector={detector}: {exc}",
                    category=RegionTypeWarning,
                    stacklevel=find_outside_stacklevel("lsst.daf.butler"),
                )
                continue

            record.update(
                {
                    "instrument_column": instrument,
                    "exposure_column": exposure,
                    "detector_column": detector,
                }
            )
            update_rows.append(record)

        # Map the table column names to the keys used in the update rows
        # above; db.update matches rows on these columns.
        where_dict: dict[str, str] = {
            instrument_column: "instrument_column",
            exposure_column: "exposure_column",
            detector_column: "detector_column",
        }

        count = self.db.update(self.table, where_dict, *update_rows)
        return count
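
    # For intuition, the db.update call above corresponds roughly to one SQL
    # statement per row, of the following shape (illustrative; the actual
    # column names come from the configured schema):
    #
    #   UPDATE <obscore_table>
    #   SET <spatial columns computed by make_spatial_records> = ...
    #   WHERE <instrument column> = :instrument_column
    #     AND <exposure column> = :exposure_column
    #     AND <detector column> = :detector_column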


    @contextmanager
    def query(
        self, columns: Iterable[str | sqlalchemy.sql.expression.ColumnElement] | None = None, /, **kwargs: Any
    ) -> Iterator[sqlalchemy.engine.CursorResult]:
        # Docstring inherited from base class.
        if columns is not None:
            column_elements: list[sqlalchemy.sql.ColumnElement] = []
            for column in columns:
                if isinstance(column, str):
                    column_elements.append(self.table.columns[column])
                else:
                    column_elements.append(column)
            query = sqlalchemy.sql.select(*column_elements).select_from(self.table)
        else:
            query = self.table.select()

        if kwargs:
            query = query.where(
                sqlalchemy.sql.expression.and_(
                    *[self.table.columns[column] == value for column, value in kwargs.items()]
                )
            )
        with self.db.query(query) as result:
            yield result
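
    # A hypothetical usage sketch for query() (the column names and the
    # keyword value are illustrative, not taken from a real obscore schema):
    #
    #   with manager.query(["dataproduct_type", "s_ra", "s_dec"], instrument="HSC") as result:
    #       for row in result:
    #           ...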