Coverage for python/lsst/daf/butler/registry/obscore/_manager.py: 17%

177 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-30 08:41 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27 

28from __future__ import annotations 

29 

30__all__ = ["ObsCoreLiveTableManager"] 

31 

32import re 

33import warnings 

34from collections import defaultdict 

35from collections.abc import Collection, Iterable, Iterator, Mapping 

36from contextlib import AbstractContextManager, contextmanager 

37from typing import TYPE_CHECKING, Any 

38 

39import sqlalchemy 

40 

41from lsst.daf.butler import Config, DataCoordinate, DatasetRef, DimensionUniverse 

42from lsst.sphgeom import Region 

43from lsst.utils.introspection import find_outside_stacklevel 

44from lsst.utils.iteration import chunk_iterable 

45 

46from ...queries import Query, QueryFactoryFunction 

47from ..interfaces import ObsCoreTableManager, VersionTuple 

48from ._config import ConfigCollectionType, ObsCoreManagerConfig 

49from ._records import DerivedRegionFactory, Record, RecordFactory 

50from ._schema import ObsCoreSchema 

51from ._spatial import RegionTypeError, RegionTypeWarning, SpatialObsCorePlugin 

52 

53if TYPE_CHECKING: 

54 from ..interfaces import ( 

55 CollectionRecord, 

56 Database, 

57 DatasetRecordStorageManager, 

58 DimensionRecordStorageManager, 

59 StaticTablesContext, 

60 ) 

61 

62_VERSION = VersionTuple(0, 0, 1) 

63 

64 

class _ExposureRegionFactory(DerivedRegionFactory):
    """Derive an exposure's region by looking up matching visit records.

    Parameters
    ----------
    dimensions : `DimensionRecordStorageManager`
        The dimension records storage manager.
    query_func : `QueryFactoryFunction`
        Function returning a context manager that sets up a `Query` object
        for querying the registry (i.e. a function equivalent to
        ``Butler.query()``).
    """

    def __init__(
        self,
        dimensions: DimensionRecordStorageManager,
        query_func: QueryFactoryFunction,
    ):
        self.dimensions = dimensions
        self.universe = dimensions.universe
        # Dimension groups used to subset an exposure data ID down to the
        # keys needed to join against visit (or visit+detector) records.
        self.exposure_dimensions = self.universe["exposure"].minimal_group
        self.exposure_detector_dimensions = self.universe.conform(["exposure", "detector"])
        self._query_func = query_func

    def derived_region(self, dataId: DataCoordinate) -> Region | None:
        # Docstring is inherited from a base class.

        # Without the visit_definition element there is no way to map an
        # exposure onto visits, so no region can be derived.
        if "visit_definition" not in self.universe.elements.names:
            return None

        # Pick which element's records carry the region: per-detector visit
        # regions when the data ID has a detector, whole-visit regions
        # otherwise.
        if "detector" in dataId.dimensions:
            element_name = "visit_detector_region"
            subset_group = self.exposure_detector_dimensions
        else:
            element_name = "visit"
            subset_group = self.exposure_dimensions
        if element_name not in self.universe:
            # This universe does not define the table we need.
            return None
        constraint = dataId.subset(subset_group)

        with self._query_func() as query:
            # Constrain to the given exposure (and detector, if present).
            # An exposure may belong to more than one visit; any single
            # match is acceptable, hence the limit of one.
            records = list(query.dimension_records(element_name).where(constraint).limit(1))
            return records[0].region if records else None

123 

124 

class ObsCoreLiveTableManager(ObsCoreTableManager):
    """A manager class for ObsCore table, implements methods for updating the
    records in that table.

    Parameters
    ----------
    db : `Database`
        The active database.
    table : `sqlalchemy.schema.Table`
        The ObsCore table.
    schema : `ObsCoreSchema`
        The relevant schema.
    universe : `DimensionUniverse`
        The dimension universe.
    config : `ObsCoreManagerConfig`
        The config controlling the manager.
    dimensions : `DimensionRecordStorageManager`
        The storage manager for the dimension records.
    spatial_plugins : `~collections.abc.Collection` of `SpatialObsCorePlugin`
        Spatial plugins.
    registry_schema_version : `VersionTuple` or `None`, optional
        Version of registry schema.
    """

    def __init__(
        self,
        *,
        db: Database,
        table: sqlalchemy.schema.Table,
        schema: ObsCoreSchema,
        universe: DimensionUniverse,
        config: ObsCoreManagerConfig,
        dimensions: DimensionRecordStorageManager,
        spatial_plugins: Collection[SpatialObsCorePlugin],
        registry_schema_version: VersionTuple | None = None,
    ):
        super().__init__(registry_schema_version=registry_schema_version)
        self.db = db
        self.table = table
        self.schema = schema
        self.universe = universe
        self.config = config
        self.spatial_plugins = spatial_plugins
        # Set lazily via set_query_function(); accessed indirectly through
        # _get_query_object by the exposure region factory below.
        self._query_func: QueryFactoryFunction | None = None
        exposure_region_factory = _ExposureRegionFactory(dimensions, self._get_query_object)
        self.record_factory = RecordFactory.get_record_type_from_universe(universe)(
            config, schema, universe, spatial_plugins, exposure_region_factory
        )
        # Exactly one of tagged_collection / run_patterns is populated,
        # depending on the configured collection type.
        self.tagged_collection: str | None = None
        self.run_patterns: list[re.Pattern] = []
        if config.collection_type is ConfigCollectionType.TAGGED:
            assert config.collections is not None and len(config.collections) == 1, (
                "Exactly one collection name required for tagged type."
            )
            self.tagged_collection = config.collections[0]
        elif config.collection_type is ConfigCollectionType.RUN:
            if config.collections:
                for coll in config.collections:
                    try:
                        self.run_patterns.append(re.compile(coll))
                    except re.error as exc:
                        raise ValueError(f"Failed to compile regex: {coll!r}") from exc
        else:
            raise ValueError(f"Unexpected value of collection_type: {config.collection_type}")

    def clone(self, *, db: Database, dimensions: DimensionRecordStorageManager) -> ObsCoreLiveTableManager:
        """Make an independent copy of this manager bound to a different
        `Database` and dimension-record storage manager.
        """
        manager = ObsCoreLiveTableManager(
            db=db,
            table=self.table,
            schema=self.schema,
            universe=self.universe,
            config=self.config,
            dimensions=dimensions,
            # Current spatial plugins are safe to share without cloning -- they
            # are immutable and do not use their Database object outside of
            # 'initialize'.
            spatial_plugins=self.spatial_plugins,
            registry_schema_version=self._registry_schema_version,
        )
        if self._query_func is not None:
            manager.set_query_function(self._query_func)
        return manager

    @classmethod
    def initialize(
        cls,
        db: Database,
        context: StaticTablesContext,
        *,
        universe: DimensionUniverse,
        config: Mapping,
        datasets: type[DatasetRecordStorageManager],
        dimensions: DimensionRecordStorageManager,
        registry_schema_version: VersionTuple | None = None,
    ) -> ObsCoreTableManager:
        # Docstring inherited from base class.
        config_data = Config(config)
        obscore_config = ObsCoreManagerConfig.model_validate(config_data)

        # Instantiate all spatial plugins.
        spatial_plugins = SpatialObsCorePlugin.load_plugins(obscore_config.spatial_plugins, db)

        schema = ObsCoreSchema(config=obscore_config, spatial_plugins=spatial_plugins, datasets=datasets)

        # Generate table specification for main obscore table.
        table_spec = schema.table_spec
        for plugin in spatial_plugins:
            plugin.extend_table_spec(table_spec)
        # Use the locally extended spec here; re-reading ``schema.table_spec``
        # could drop the columns the spatial plugins just added if that
        # attribute returns a fresh specification.
        table = context.addTable(obscore_config.table_name, table_spec)

        return ObsCoreLiveTableManager(
            db=db,
            table=table,
            schema=schema,
            universe=universe,
            config=obscore_config,
            dimensions=dimensions,
            spatial_plugins=spatial_plugins,
            registry_schema_version=registry_schema_version,
        )

    def _get_query_object(self) -> AbstractContextManager[Query]:
        """Return the context manager produced by the registered query
        function, failing if `set_query_function` was never called.
        """
        if self._query_func is None:
            raise AssertionError("set_query_function should have been called prior to this.")

        return self._query_func()

    def set_query_function(self, query_func: QueryFactoryFunction) -> None:
        """Register the function used to create `Query` objects for
        region lookups.
        """
        self._query_func = query_func

    def config_json(self) -> str:
        """Dump configuration in JSON format.

        Returns
        -------
        json : `str`
            Configuration serialized in JSON format.
        """
        return self.config.model_dump_json()

    @classmethod
    def currentVersions(cls) -> list[VersionTuple]:
        # Docstring inherited from base class.
        return [_VERSION]

    def add_datasets(self, refs: Iterable[DatasetRef]) -> int:
        # Docstring inherited from base class.

        # Only makes sense for RUN collection types
        if self.config.collection_type is not ConfigCollectionType.RUN:
            return 0

        obscore_refs: Iterable[DatasetRef]
        if self.run_patterns:
            # Check each dataset run against configured run list. We want to
            # reduce number of calls to _check_dataset_run, which may be
            # expensive. Normally references are grouped by run, if there are
            # multiple input references, they should have the same run.
            # Instead of just checking that, we group them by run again.
            refs_by_run: dict[str, list[DatasetRef]] = defaultdict(list)
            for ref in refs:
                # Record factory will filter dataset types, but to reduce
                # collection checks we also pre-filter it here.
                if ref.datasetType.name not in self.config.dataset_types:
                    continue

                assert ref.run is not None, "Run cannot be None"
                refs_by_run[ref.run].append(ref)

            good_refs: list[DatasetRef] = []
            for run, run_refs in refs_by_run.items():
                if not self._check_dataset_run(run):
                    continue
                good_refs.extend(run_refs)
            obscore_refs = good_refs

        else:
            # Take all refs, no collection check.
            obscore_refs = refs

        return self._populate(obscore_refs)

    def associate(self, refs: Iterable[DatasetRef], collection: CollectionRecord) -> int:
        # Docstring inherited from base class.

        # Only works when collection type is TAGGED
        if self.tagged_collection is None:
            return 0

        if collection.name == self.tagged_collection:
            return self._populate(refs)
        else:
            return 0

    def disassociate(self, refs: Iterable[DatasetRef], collection: CollectionRecord) -> int:
        # Docstring inherited from base class.

        # Only works when collection type is TAGGED
        if self.tagged_collection is None:
            return 0

        count = 0
        if collection.name == self.tagged_collection:
            # Sorting may improve performance
            dataset_ids = sorted(ref.id for ref in refs)
            if dataset_ids:
                fk_field = self.schema.dataset_fk
                assert fk_field is not None, "Cannot be None by construction"
                # There may be too many of them, do it in chunks.
                for ids in chunk_iterable(dataset_ids):
                    where = self.table.columns[fk_field.name].in_(ids)
                    count += self.db.deleteWhere(self.table, where)
        return count

    def _populate(self, refs: Iterable[DatasetRef]) -> int:
        """Populate obscore table with the data from given datasets.

        Returns the number of rows actually inserted; refs the record
        factory rejects (e.g. unconfigured dataset types) are skipped.
        """
        records: list[Record] = []
        for ref in refs:
            record = self.record_factory(ref)
            if record is not None:
                records.append(record)

        if records:
            # Ignore potential conflicts with existing datasets.
            return self.db.ensure(self.table, *records, primary_key_only=True)
        else:
            return 0

    def _check_dataset_run(self, run: str) -> bool:
        """Check that specified run collection matches known patterns."""
        if not self.run_patterns:
            # Empty list means take anything.
            return True

        # Try each pattern in turn.
        return any(pattern.fullmatch(run) for pattern in self.run_patterns)

    def update_exposure_regions(self, instrument: str, region_data: Iterable[tuple[int, int, Region]]) -> int:
        # Docstring inherited from base class.
        instrument_column = self.schema.dimension_column("instrument")
        exposure_column = self.schema.dimension_column("exposure")
        detector_column = self.schema.dimension_column("detector")
        if instrument_column is None or exposure_column is None or detector_column is None:
            # Not all needed columns are in the table.
            return 0

        update_rows: list[Record] = []
        for exposure, detector, region in region_data:
            try:
                record = self.record_factory.make_spatial_records(region)
            except RegionTypeError as exc:
                # Unconvertible regions are skipped with a warning rather
                # than failing the whole update.
                warnings.warn(
                    f"Failed to convert region for exposure={exposure} detector={detector}: {exc}",
                    category=RegionTypeWarning,
                    stacklevel=find_outside_stacklevel("lsst.daf.butler"),
                )
                continue

            # The "*_column" keys are placeholders that ``where_dict`` below
            # maps to the actual table columns for the UPDATE's WHERE clause.
            record.update(
                {
                    "instrument_column": instrument,
                    "exposure_column": exposure,
                    "detector_column": detector,
                }
            )
            update_rows.append(record)

        where_dict: dict[str, str] = {
            instrument_column: "instrument_column",
            exposure_column: "exposure_column",
            detector_column: "detector_column",
        }

        count = self.db.update(self.table, where_dict, *update_rows)
        return count

    @contextmanager
    def query(
        self, columns: Iterable[str | sqlalchemy.sql.expression.ColumnElement] | None = None, /, **kwargs: Any
    ) -> Iterator[sqlalchemy.engine.CursorResult]:
        # Docstring inherited from base class.
        if columns is not None:
            # Resolve string column names against the obscore table; pass
            # ColumnElement instances through unchanged.
            column_elements: list[sqlalchemy.sql.ColumnElement] = []
            for column in columns:
                if isinstance(column, str):
                    column_elements.append(self.table.columns[column])
                else:
                    column_elements.append(column)
            query = sqlalchemy.sql.select(*column_elements).select_from(self.table)
        else:
            query = self.table.select()

        if kwargs:
            # Each keyword argument becomes an equality constraint, ANDed
            # together.
            query = query.where(
                sqlalchemy.sql.expression.and_(
                    *[self.table.columns[column] == value for column, value in kwargs.items()]
                )
            )
        with self.db.query(query) as result:
            yield result