Coverage for python/lsst/daf/butler/registry/obscore/_manager.py: 18% (171 statements)


# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

__all__ = ["ObsCoreLiveTableManager"]

import re
import warnings
from collections import defaultdict
from collections.abc import Collection, Iterable, Iterator, Mapping
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any

import sqlalchemy
from lsst.daf.butler import Config, DataCoordinate, DatasetRef, DimensionRecordColumnTag, DimensionUniverse
from lsst.daf.relation import Join
from lsst.sphgeom import Region
from lsst.utils.introspection import find_outside_stacklevel
from lsst.utils.iteration import chunk_iterable

from ..interfaces import ObsCoreTableManager, VersionTuple
from ._config import ConfigCollectionType, ObsCoreManagerConfig
from ._records import ExposureRegionFactory, Record, RecordFactory
from ._schema import ObsCoreSchema
from ._spatial import RegionTypeError, RegionTypeWarning, SpatialObsCorePlugin

if TYPE_CHECKING:
    from ..interfaces import (
        CollectionRecord,
        Database,
        DatasetRecordStorageManager,
        DimensionRecordStorageManager,
        StaticTablesContext,
    )
    from ..queries import SqlQueryContext

_VERSION = VersionTuple(0, 0, 1)


class _ExposureRegionFactory(ExposureRegionFactory):
    """Find the region of an exposure from matching visit dimension records.

    Parameters
    ----------
    dimensions : `DimensionRecordStorageManager`
        The dimension records storage manager.
    """

    def __init__(self, dimensions: DimensionRecordStorageManager):
        self.dimensions = dimensions
        self.universe = dimensions.universe
        self.exposure_dimensions = self.universe["exposure"].minimal_group
        self.exposure_detector_dimensions = self.universe.conform(["exposure", "detector"])

    def exposure_region(self, dataId: DataCoordinate, context: SqlQueryContext) -> Region | None:
        # Docstring is inherited from a base class.
        # Make a relation that starts with visit_definition (mapping between
        # exposure and visit).
        relation = context.make_initial_relation()
        if "visit_definition" not in self.universe.elements.names:
            return None
        relation = self.dimensions.join("visit_definition", relation, Join(), context)
        # Join in a table with either visit+detector regions or visit regions.
        if "detector" in dataId.dimensions:
            if "visit_detector_region" not in self.universe:
                return None
            relation = self.dimensions.join("visit_detector_region", relation, Join(), context)
            constraint_data_id = dataId.subset(self.exposure_detector_dimensions)
            region_tag = DimensionRecordColumnTag("visit_detector_region", "region")
        else:
            if "visit" not in self.universe:
                return None
            relation = self.dimensions.join("visit", relation, Join(), context)
            constraint_data_id = dataId.subset(self.exposure_dimensions)
            region_tag = DimensionRecordColumnTag("visit", "region")
        # Constrain the relation to match the given exposure and (if present)
        # detector IDs.
        relation = relation.with_rows_satisfying(
            context.make_data_coordinate_predicate(constraint_data_id, full=False)
        )
        # If we get more than one result (because the exposure belongs to
        # multiple visits), just pick an arbitrary one.
        relation = relation[:1]
        # Run the query and extract the region, if the query has any results.
        for row in context.fetch_iterable(relation):
            return row[region_tag]
        return None


class ObsCoreLiveTableManager(ObsCoreTableManager):
    """A manager class for the ObsCore table; implements methods for updating
    the records in that table.

    Parameters
    ----------
    db : `Database`
        The active database.
    table : `sqlalchemy.schema.Table`
        The ObsCore table.
    schema : `ObsCoreSchema`
        The relevant schema.
    universe : `DimensionUniverse`
        The dimension universe.
    config : `ObsCoreManagerConfig`
        The config controlling the manager.
    dimensions : `DimensionRecordStorageManager`
        The storage manager for the dimension records.
    spatial_plugins : `~collections.abc.Collection` of `SpatialObsCorePlugin`
        Spatial plugins.
    registry_schema_version : `VersionTuple` or `None`, optional
        Version of registry schema.
    """

    def __init__(
        self,
        *,
        db: Database,
        table: sqlalchemy.schema.Table,
        schema: ObsCoreSchema,
        universe: DimensionUniverse,
        config: ObsCoreManagerConfig,
        dimensions: DimensionRecordStorageManager,
        spatial_plugins: Collection[SpatialObsCorePlugin],
        registry_schema_version: VersionTuple | None = None,
    ):
        super().__init__(registry_schema_version=registry_schema_version)
        self.db = db
        self.table = table
        self.schema = schema
        self.universe = universe
        self.config = config
        self.spatial_plugins = spatial_plugins
        exposure_region_factory = _ExposureRegionFactory(dimensions)
        self.record_factory = RecordFactory(
            config, schema, universe, spatial_plugins, exposure_region_factory
        )
        self.tagged_collection: str | None = None
        self.run_patterns: list[re.Pattern] = []
        if config.collection_type is ConfigCollectionType.TAGGED:
            assert (
                config.collections is not None and len(config.collections) == 1
            ), "Exactly one collection name required for tagged type."
            self.tagged_collection = config.collections[0]
        elif config.collection_type is ConfigCollectionType.RUN:
            if config.collections:
                for coll in config.collections:
                    try:
                        self.run_patterns.append(re.compile(coll))
                    except re.error as exc:
                        raise ValueError(f"Failed to compile regex: {coll!r}") from exc
        else:
            raise ValueError(f"Unexpected value of collection_type: {config.collection_type}")
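
    # The two collection_type modes handled above, illustrated as config
    # fragments (a sketch; the collection names are hypothetical, other
    # required config fields are omitted, and the on-disk spelling of the
    # enum values may differ):
    #
    #     # TAGGED: exactly one collection name; records enter the table
    #     # through associate() and leave through disassociate().
    #     collection_type: TAGGED
    #     collections: ["ivoa_obscore"]
    #
    #     # RUN: optional regular expressions matched in full against run
    #     # names; records enter the table through add_datasets().
    #     collection_type: RUN
    #     collections: ["HSC/runs/.*"]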

    def clone(self, *, db: Database, dimensions: DimensionRecordStorageManager) -> ObsCoreLiveTableManager:
        return ObsCoreLiveTableManager(
            db=db,
            table=self.table,
            schema=self.schema,
            universe=self.universe,
            config=self.config,
            dimensions=dimensions,
            # Current spatial plugins are safe to share without cloning -- they
            # are immutable and do not use their Database object outside of
            # 'initialize'.
            spatial_plugins=self.spatial_plugins,
            registry_schema_version=self._registry_schema_version,
        )

    @classmethod
    def initialize(
        cls,
        db: Database,
        context: StaticTablesContext,
        *,
        universe: DimensionUniverse,
        config: Mapping,
        datasets: type[DatasetRecordStorageManager],
        dimensions: DimensionRecordStorageManager,
        registry_schema_version: VersionTuple | None = None,
    ) -> ObsCoreTableManager:
        # Docstring inherited from base class.
        config_data = Config(config)
        obscore_config = ObsCoreManagerConfig.model_validate(config_data)

        # Instantiate all spatial plugins.
        spatial_plugins = SpatialObsCorePlugin.load_plugins(obscore_config.spatial_plugins, db)

        schema = ObsCoreSchema(config=obscore_config, spatial_plugins=spatial_plugins, datasets=datasets)

        # Generate table specification for main obscore table.
        table_spec = schema.table_spec
        for plugin in spatial_plugins:
            plugin.extend_table_spec(table_spec)
        table = context.addTable(obscore_config.table_name, schema.table_spec)

        return ObsCoreLiveTableManager(
            db=db,
            table=table,
            schema=schema,
            universe=universe,
            config=obscore_config,
            dimensions=dimensions,
            spatial_plugins=spatial_plugins,
            registry_schema_version=registry_schema_version,
        )

    def config_json(self) -> str:
        """Dump configuration in JSON format.

        Returns
        -------
        json : `str`
            Configuration serialized in JSON format.
        """
        return self.config.model_dump_json()
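
    # The JSON produced above round-trips through pydantic validation, e.g.
    # (a sketch; `manager` is an assumed, already-constructed instance):
    #
    #     restored = ObsCoreManagerConfig.model_validate_json(manager.config_json())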

    @classmethod
    def currentVersions(cls) -> list[VersionTuple]:
        # Docstring inherited from base class.
        return [_VERSION]

    def add_datasets(self, refs: Iterable[DatasetRef], context: SqlQueryContext) -> int:
        # Docstring inherited from base class.

        # Only makes sense for RUN collection types.
        if self.config.collection_type is not ConfigCollectionType.RUN:
            return 0

        obscore_refs: Iterable[DatasetRef]
        if self.run_patterns:
            # Check each dataset run against the configured run list. We want
            # to reduce the number of calls to _check_dataset_run, which may
            # be expensive. Normally references are grouped by run; if there
            # are multiple input references, they should have the same run.
            # Instead of just checking that, we group them by run again.
            refs_by_run: dict[str, list[DatasetRef]] = defaultdict(list)
            for ref in refs:
                # Record factory will filter dataset types, but to reduce
                # collection checks we also pre-filter it here.
                if ref.datasetType.name not in self.config.dataset_types:
                    continue

                assert ref.run is not None, "Run cannot be None"
                refs_by_run[ref.run].append(ref)

            good_refs: list[DatasetRef] = []
            for run, run_refs in refs_by_run.items():
                if not self._check_dataset_run(run):
                    continue
                good_refs.extend(run_refs)
            obscore_refs = good_refs

        else:
            # Take all refs, no collection check.
            obscore_refs = refs

        return self._populate(obscore_refs, context)

    def associate(
        self, refs: Iterable[DatasetRef], collection: CollectionRecord, context: SqlQueryContext
    ) -> int:
        # Docstring inherited from base class.

        # Only works when collection type is TAGGED.
        if self.tagged_collection is None:
            return 0

        if collection.name == self.tagged_collection:
            return self._populate(refs, context)
        else:
            return 0

    def disassociate(self, refs: Iterable[DatasetRef], collection: CollectionRecord) -> int:
        # Docstring inherited from base class.

        # Only works when collection type is TAGGED.
        if self.tagged_collection is None:
            return 0

        count = 0
        if collection.name == self.tagged_collection:
            # Sorting may improve performance.
            dataset_ids = sorted(ref.id for ref in refs)
            if dataset_ids:
                fk_field = self.schema.dataset_fk
                assert fk_field is not None, "Cannot be None by construction"
                # There may be too many of them; do it in chunks.
                for ids in chunk_iterable(dataset_ids):
                    where = self.table.columns[fk_field.name].in_(ids)
                    count += self.db.deleteWhere(self.table, where)
        return count

    def _populate(self, refs: Iterable[DatasetRef], context: SqlQueryContext) -> int:
        """Populate the obscore table with the data from the given datasets."""
        records: list[Record] = []
        for ref in refs:
            record = self.record_factory(ref, context)
            if record is not None:
                records.append(record)

        if records:
            # Ignore potential conflicts with existing datasets.
            return self.db.ensure(self.table, *records, primary_key_only=True)
        else:
            return 0

    def _check_dataset_run(self, run: str) -> bool:
        """Check that the specified run collection matches known patterns."""
        if not self.run_patterns:
            # Empty list means take anything.
            return True

        # Try each pattern in turn.
        return any(pattern.fullmatch(run) for pattern in self.run_patterns)
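
    # fullmatch() is used above, so a configured pattern must cover the whole
    # run name, e.g. (a sketch; the run names are hypothetical):
    #
    #     re.compile("HSC/runs/.*").fullmatch("HSC/runs/RC2/w_2024_20")  # match
    #     re.compile("HSC/runs").fullmatch("HSC/runs/RC2/w_2024_20")     # None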

    def update_exposure_regions(self, instrument: str, region_data: Iterable[tuple[int, int, Region]]) -> int:
        # Docstring inherited from base class.
        instrument_column = self.schema.dimension_column("instrument")
        exposure_column = self.schema.dimension_column("exposure")
        detector_column = self.schema.dimension_column("detector")
        if instrument_column is None or exposure_column is None or detector_column is None:
            # Not all needed columns are in the table.
            return 0

        update_rows: list[Record] = []
        for exposure, detector, region in region_data:
            try:
                record = self.record_factory.make_spatial_records(region)
            except RegionTypeError as exc:
                warnings.warn(
                    f"Failed to convert region for exposure={exposure} detector={detector}: {exc}",
                    category=RegionTypeWarning,
                    stacklevel=find_outside_stacklevel("lsst.daf.butler"),
                )
                continue

            record.update(
                {
                    "instrument_column": instrument,
                    "exposure_column": exposure,
                    "detector_column": detector,
                }
            )
            update_rows.append(record)

        where_dict: dict[str, str] = {
            instrument_column: "instrument_column",
            exposure_column: "exposure_column",
            detector_column: "detector_column",
        }
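
        # Note: as used here, the `where` argument of Database.update() maps
        # actual table column names to the keys under which the matching
        # values appear in each row dict built above.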

        count = self.db.update(self.table, where_dict, *update_rows)
        return count

    @contextmanager
    def query(
        self, columns: Iterable[str | sqlalchemy.sql.expression.ColumnElement] | None = None, /, **kwargs: Any
    ) -> Iterator[sqlalchemy.engine.CursorResult]:
        # Docstring inherited from base class.
        if columns is not None:
            column_elements: list[sqlalchemy.sql.ColumnElement] = []
            for column in columns:
                if isinstance(column, str):
                    column_elements.append(self.table.columns[column])
                else:
                    column_elements.append(column)
            query = sqlalchemy.sql.select(*column_elements).select_from(self.table)
        else:
            query = self.table.select()

        if kwargs:
            query = query.where(
                sqlalchemy.sql.expression.and_(
                    *[self.table.columns[column] == value for column, value in kwargs.items()]
                )
            )
        with self.db.query(query) as result:
            yield result
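

# A minimal usage sketch for query() (assumptions: `manager` is an initialized
# ObsCoreLiveTableManager, and "s_ra", "s_dec", and "dataproduct_type" are
# ObsCore columns present in this particular schema):
#
#     with manager.query(["s_ra", "s_dec"], dataproduct_type="image") as result:
#         for row in result:
#             print(row.s_ra, row.s_dec)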