Coverage for python/lsst/daf/butler/registry/obscore/_manager.py: 16%

173 statements  

« prev     ^ index     » next       coverage.py v7.2.4, created at 2023-04-29 02:58 -0700

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ["ObsCoreLiveTableManager"] 

25 

26import json 

27import re 

28import uuid 

29import warnings 

30from collections import defaultdict 

31from collections.abc import Collection, Iterable, Iterator, Mapping 

32from contextlib import contextmanager 

33from typing import TYPE_CHECKING, Any, Type, cast 

34 

35import sqlalchemy 

36from lsst.daf.butler import Config, DataCoordinate, DatasetRef, DimensionRecordColumnTag, DimensionUniverse 

37from lsst.daf.relation import Join 

38from lsst.sphgeom import Region 

39from lsst.utils.iteration import chunk_iterable 

40 

41from ..interfaces import ObsCoreTableManager, VersionTuple 

42from ._config import ConfigCollectionType, ObsCoreManagerConfig 

43from ._records import ExposureRegionFactory, Record, RecordFactory 

44from ._schema import ObsCoreSchema 

45from ._spatial import RegionTypeError, RegionTypeWarning, SpatialObsCorePlugin 

46 

47if TYPE_CHECKING: 

48 from ..interfaces import ( 

49 CollectionRecord, 

50 Database, 

51 DatasetRecordStorageManager, 

52 DimensionRecordStorageManager, 

53 StaticTablesContext, 

54 ) 

55 from ..queries import SqlQueryContext 

56 

# Schema version reported by this manager implementation.
_VERSION = VersionTuple(0, 0, 1)

58 

59 

class _ExposureRegionFactory(ExposureRegionFactory):
    """Find exposure region from a matching visit dimensions records.

    Parameters
    ----------
    dimensions : `DimensionRecordStorageManager`
        Manager used to look up storage for the visit-related dimension
        elements (``visit_definition``, ``visit_detector_region``, ``visit``).
    """

    def __init__(self, dimensions: DimensionRecordStorageManager):
        self.dimensions = dimensions
        self.universe = dimensions.universe
        # Dimension subsets used below to restrict a data ID to the keys that
        # constrain each kind of region query.
        self.exposure_dimensions = self.universe["exposure"].graph
        self.exposure_detector_dimensions = self.universe.extract(["exposure", "detector"])

    def exposure_region(self, dataId: DataCoordinate, context: SqlQueryContext) -> Region | None:
        # Docstring is inherited from a base class.
        # Make a relation that starts with visit_definition (mapping between
        # exposure and visit).
        relation = context.make_initial_relation()
        visit_definition_storage = self.dimensions.get(self.universe["visit_definition"])
        if visit_definition_storage is None:
            # Without visit_definition there is no exposure->visit mapping,
            # so no region can be found.
            return None
        relation = visit_definition_storage.join(relation, Join(), context)
        # Join in a table with either visit+detector regions or visit regions.
        if "detector" in dataId.names:
            visit_detector_region_storage = self.dimensions.get(self.universe["visit_detector_region"])
            if visit_detector_region_storage is None:
                return None
            relation = visit_detector_region_storage.join(relation, Join(), context)
            constraint_data_id = dataId.subset(self.exposure_detector_dimensions)
            region_tag = DimensionRecordColumnTag("visit_detector_region", "region")
        else:
            visit_storage = self.dimensions.get(self.universe["visit"])
            if visit_storage is None:
                return None
            relation = visit_storage.join(relation, Join(), context)
            constraint_data_id = dataId.subset(self.exposure_dimensions)
            region_tag = DimensionRecordColumnTag("visit", "region")
        # Constrain the relation to match the given exposure and (if present)
        # detector IDs.
        relation = relation.with_rows_satisfying(
            context.make_data_coordinate_predicate(constraint_data_id, full=False)
        )
        # If we get more than one result (because the exposure belongs to
        # multiple visits), just pick an arbitrary one.
        relation = relation[:1]
        # Run the query and extract the region, if the query has any results.
        for row in context.fetch_iterable(relation):
            return row[region_tag]
        return None

105 

106 

class ObsCoreLiveTableManager(ObsCoreTableManager):
    """A manager class for ObsCore table, implements methods for updating the
    records in that table.

    Parameters
    ----------
    db : `Database`
        Interface to the underlying database engine.
    table : `sqlalchemy.schema.Table`
        The live obscore table managed by this instance.
    schema : `ObsCoreSchema`
        Schema description for the obscore table.
    universe : `DimensionUniverse`
        All dimensions known to the registry.
    config : `ObsCoreManagerConfig`
        Validated configuration for this manager.
    dimensions : `DimensionRecordStorageManager`
        Manager for dimension records, used to find exposure regions.
    spatial_plugins : `~collections.abc.Collection` [ `SpatialObsCorePlugin` ]
        Spatial plugins contributing extra columns/records.
    registry_schema_version : `VersionTuple` or `None`, optional
        Version of the registry schema.
    """

    def __init__(
        self,
        *,
        db: Database,
        table: sqlalchemy.schema.Table,
        schema: ObsCoreSchema,
        universe: DimensionUniverse,
        config: ObsCoreManagerConfig,
        dimensions: DimensionRecordStorageManager,
        spatial_plugins: Collection[SpatialObsCorePlugin],
        registry_schema_version: VersionTuple | None = None,
    ):
        super().__init__(registry_schema_version=registry_schema_version)
        self.db = db
        self.table = table
        self.schema = schema
        self.universe = universe
        self.config = config
        self.spatial_plugins = spatial_plugins
        exposure_region_factory = _ExposureRegionFactory(dimensions)
        self.record_factory = RecordFactory(
            config, schema, universe, spatial_plugins, exposure_region_factory
        )
        # TAGGED configuration tracks exactly one collection name; RUN
        # configuration optionally restricts contributing RUN collections
        # via a list of regular expressions.
        self.tagged_collection: str | None = None
        self.run_patterns: list[re.Pattern] = []
        if config.collection_type is ConfigCollectionType.TAGGED:
            assert (
                config.collections is not None and len(config.collections) == 1
            ), "Exactly one collection name required for tagged type."
            self.tagged_collection = config.collections[0]
        elif config.collection_type is ConfigCollectionType.RUN:
            if config.collections:
                for coll in config.collections:
                    try:
                        self.run_patterns.append(re.compile(coll))
                    except re.error as exc:
                        raise ValueError(f"Failed to compile regex: {coll!r}") from exc
        else:
            raise ValueError(f"Unexpected value of collection_type: {config.collection_type}")

    @classmethod
    def initialize(
        cls,
        db: Database,
        context: StaticTablesContext,
        *,
        universe: DimensionUniverse,
        config: Mapping,
        datasets: Type[DatasetRecordStorageManager],
        dimensions: DimensionRecordStorageManager,
        registry_schema_version: VersionTuple | None = None,
    ) -> ObsCoreTableManager:
        # Docstring inherited from base class.
        config_data = Config(config)
        obscore_config = ObsCoreManagerConfig.parse_obj(config_data)

        # Instantiate all spatial plugins.
        spatial_plugins = SpatialObsCorePlugin.load_plugins(obscore_config.spatial_plugins, db)

        schema = ObsCoreSchema(config=obscore_config, spatial_plugins=spatial_plugins, datasets=datasets)

        # Generate table specification for main obscore table.
        table_spec = schema.table_spec
        for plugin in spatial_plugins:
            plugin.extend_table_spec(table_spec)
        # Pass the plugin-extended spec; previously this re-read
        # schema.table_spec, which only worked because both names aliased
        # the same object.
        table = context.addTable(obscore_config.table_name, table_spec)

        return ObsCoreLiveTableManager(
            db=db,
            table=table,
            schema=schema,
            universe=universe,
            config=obscore_config,
            dimensions=dimensions,
            spatial_plugins=spatial_plugins,
            registry_schema_version=registry_schema_version,
        )

    def config_json(self) -> str:
        """Dump configuration in JSON format.

        Returns
        -------
        json : `str`
            Configuration serialized in JSON format.
        """
        return json.dumps(self.config.dict())

    @classmethod
    def currentVersions(cls) -> list[VersionTuple]:
        # Docstring inherited from base class.
        return [_VERSION]

    def add_datasets(self, refs: Iterable[DatasetRef], context: SqlQueryContext) -> int:
        # Docstring inherited from base class.

        # Only makes sense for RUN collection types
        if self.config.collection_type is not ConfigCollectionType.RUN:
            return 0

        obscore_refs: Iterable[DatasetRef]
        if self.run_patterns:
            # Check each dataset run against configured run list. We want to
            # reduce number of calls to _check_dataset_run, which may be
            # expensive. Normally references are grouped by run, if there are
            # multiple input references, they should have the same run.
            # Instead of just checking that, we group them by run again.
            refs_by_run: dict[str, list[DatasetRef]] = defaultdict(list)
            for ref in refs:
                # Record factory will filter dataset types, but to reduce
                # collection checks we also pre-filter it here.
                if ref.datasetType.name not in self.config.dataset_types:
                    continue

                assert ref.run is not None, "Run cannot be None"
                refs_by_run[ref.run].append(ref)

            # Keep only refs whose run matches one of the configured patterns.
            good_refs: list[DatasetRef] = []
            for run, run_refs in refs_by_run.items():
                if not self._check_dataset_run(run):
                    continue
                good_refs.extend(run_refs)
            obscore_refs = good_refs

        else:
            # Take all refs, no collection check.
            obscore_refs = refs

        return self._populate(obscore_refs, context)

    def associate(
        self, refs: Iterable[DatasetRef], collection: CollectionRecord, context: SqlQueryContext
    ) -> int:
        # Docstring inherited from base class.

        # Only works when collection type is TAGGED
        if self.tagged_collection is None:
            return 0

        if collection.name == self.tagged_collection:
            return self._populate(refs, context)
        else:
            return 0

    def disassociate(self, refs: Iterable[DatasetRef], collection: CollectionRecord) -> int:
        # Docstring inherited from base class.

        # Only works when collection type is TAGGED
        if self.tagged_collection is None:
            return 0

        count = 0
        if collection.name == self.tagged_collection:
            # Sorting may improve performance
            dataset_ids = sorted(cast(uuid.UUID, ref.id) for ref in refs)
            if dataset_ids:
                fk_field = self.schema.dataset_fk
                assert fk_field is not None, "Cannot be None by construction"
                # There may be too many of them, do it in chunks.
                for ids in chunk_iterable(dataset_ids):
                    where = self.table.columns[fk_field.name].in_(ids)
                    count += self.db.deleteWhere(self.table, where)
        return count

    def _populate(self, refs: Iterable[DatasetRef], context: SqlQueryContext) -> int:
        """Populate obscore table with the data from given datasets.

        Returns the number of rows actually inserted; refs rejected by the
        record factory (e.g. unconfigured dataset types) produce no rows.
        """
        records: list[Record] = [
            record for ref in refs if (record := self.record_factory(ref, context)) is not None
        ]

        if records:
            # Ignore potential conflicts with existing datasets.
            return self.db.ensure(self.table, *records, primary_key_only=True)
        else:
            return 0

    def _check_dataset_run(self, run: str) -> bool:
        """Check that specified run collection matches known patterns."""

        if not self.run_patterns:
            # Empty list means take anything.
            return True

        # Try each pattern in turn.
        return any(pattern.fullmatch(run) for pattern in self.run_patterns)

    def update_exposure_regions(self, instrument: str, region_data: Iterable[tuple[int, int, Region]]) -> int:
        # Docstring inherited from base class.
        instrument_column = self.schema.dimension_column("instrument")
        exposure_column = self.schema.dimension_column("exposure")
        detector_column = self.schema.dimension_column("detector")
        if instrument_column is None or exposure_column is None or detector_column is None:
            # Not all needed columns are in the table.
            return 0

        update_rows: list[Record] = []
        for exposure, detector, region in region_data:
            try:
                record = self.record_factory.make_spatial_records(region)
            except RegionTypeError as exc:
                # Best-effort: skip regions we cannot convert, but warn so
                # the caller can investigate.
                warnings.warn(
                    f"Failed to convert region for exposure={exposure} detector={detector}: {exc}",
                    category=RegionTypeWarning,
                )
                continue

            record.update(
                {
                    "instrument_column": instrument,
                    "exposure_column": exposure,
                    "detector_column": detector,
                }
            )
            update_rows.append(record)

        if not update_rows:
            # Nothing to update (empty input or all regions rejected);
            # avoid calling db.update with zero rows.
            return 0

        # Map table column names to the placeholder keys used in the rows.
        where_dict: dict[str, str] = {
            instrument_column: "instrument_column",
            exposure_column: "exposure_column",
            detector_column: "detector_column",
        }

        count = self.db.update(self.table, where_dict, *update_rows)
        return count

    @contextmanager
    def query(
        self, columns: Iterable[str | sqlalchemy.sql.expression.ColumnElement] | None = None, /, **kwargs: Any
    ) -> Iterator[sqlalchemy.engine.CursorResult]:
        # Docstring inherited from base class.
        if columns is not None:
            # Accept either column names or ready-made column elements.
            column_elements: list[sqlalchemy.sql.ColumnElement] = []
            for column in columns:
                if isinstance(column, str):
                    column_elements.append(self.table.columns[column])
                else:
                    column_elements.append(column)
            query = sqlalchemy.sql.select(*column_elements).select_from(self.table)
        else:
            query = self.table.select()

        if kwargs:
            # Each keyword argument becomes an equality constraint, ANDed.
            query = query.where(
                sqlalchemy.sql.expression.and_(
                    *[self.table.columns[column] == value for column, value in kwargs.items()]
                )
            )
        with self.db.query(query) as result:
            yield result