Coverage for python/lsst/daf/butler/registry/obscore/_manager.py: 17%

166 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-03-31 02:41 -0700

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ["ObsCoreLiveTableManager"] 

25 

26import json 

27import re 

28import uuid 

29import warnings 

30from collections import defaultdict 

31from collections.abc import Collection, Iterable, Iterator, Mapping 

32from contextlib import contextmanager 

33from typing import TYPE_CHECKING, Any, Type, cast 

34 

35import sqlalchemy 

36from lsst.daf.butler import Config, DataCoordinate, DatasetRef, DimensionRecordColumnTag, DimensionUniverse 

37from lsst.daf.relation import Join 

38from lsst.sphgeom import Region 

39from lsst.utils.iteration import chunk_iterable 

40 

41from ..interfaces import ObsCoreTableManager, VersionTuple 

42from ._config import ConfigCollectionType, ObsCoreManagerConfig 

43from ._records import ExposureRegionFactory, Record, RecordFactory 

44from ._schema import ObsCoreSchema 

45from ._spatial import RegionTypeError, RegionTypeWarning, SpatialObsCorePlugin 

46 

47if TYPE_CHECKING: 

48 from ..interfaces import ( 

49 CollectionRecord, 

50 Database, 

51 DatasetRecordStorageManager, 

52 DimensionRecordStorageManager, 

53 StaticTablesContext, 

54 ) 

55 from ..queries import SqlQueryContext 

56 

# Schema version reported for the obscore table managed by this module.
_VERSION = VersionTuple(0, 0, 1)

58 

59 

class _ExposureRegionFactory(ExposureRegionFactory):
    """Find exposure region from a matching visit dimensions records."""

    def __init__(self, dimensions: DimensionRecordStorageManager):
        # Keep the storage manager and a few pre-computed dimension sets
        # around so each lookup only has to build the query.
        self.dimensions = dimensions
        self.universe = dimensions.universe
        self.exposure_dimensions = self.universe["exposure"].graph
        self.exposure_detector_dimensions = self.universe.extract(["exposure", "detector"])

    def exposure_region(self, dataId: DataCoordinate, context: SqlQueryContext) -> Region | None:
        # Docstring is inherited from a base class.
        # Begin with visit_definition, the mapping between exposure and
        # visit; without it there is no way to find a region.
        visit_definition_storage = self.dimensions.get(self.universe["visit_definition"])
        if visit_definition_storage is None:
            return None
        relation = visit_definition_storage.join(context.make_initial_relation(), Join(), context)
        # Choose which region table to join: per-detector regions when the
        # data ID carries a detector, whole-visit regions otherwise.
        use_detector = "detector" in dataId.names
        element = "visit_detector_region" if use_detector else "visit"
        region_storage = self.dimensions.get(self.universe[element])
        if region_storage is None:
            return None
        relation = region_storage.join(relation, Join(), context)
        if use_detector:
            constraint_data_id = dataId.subset(self.exposure_detector_dimensions)
        else:
            constraint_data_id = dataId.subset(self.exposure_dimensions)
        region_tag = DimensionRecordColumnTag(element, "region")
        # Restrict rows to the given exposure (and detector when present).
        predicate = context.make_data_coordinate_predicate(constraint_data_id, full=False)
        # An exposure may belong to several visits; any single match is
        # acceptable, so keep just one row.
        limited = relation.with_rows_satisfying(predicate)[:1]
        for row in context.fetch_iterable(limited):
            return row[region_tag]
        return None

105 

106 

class ObsCoreLiveTableManager(ObsCoreTableManager):
    """A manager class for ObsCore table, implements methods for updating the
    records in that table.
    """

    def __init__(
        self,
        *,
        db: Database,
        table: sqlalchemy.schema.Table,
        schema: ObsCoreSchema,
        universe: DimensionUniverse,
        config: ObsCoreManagerConfig,
        dimensions: DimensionRecordStorageManager,
        spatial_plugins: Collection[SpatialObsCorePlugin],
        registry_schema_version: VersionTuple | None = None,
    ):
        super().__init__(registry_schema_version=registry_schema_version)
        self.db = db
        self.table = table
        self.schema = schema
        self.universe = universe
        self.config = config
        self.spatial_plugins = spatial_plugins
        exposure_region_factory = _ExposureRegionFactory(dimensions)
        self.record_factory = RecordFactory(
            config, schema, universe, spatial_plugins, exposure_region_factory
        )
        # For TAGGED configurations, the single collection whose membership
        # drives the obscore contents; None for RUN configurations.
        self.tagged_collection: str | None = None
        # For RUN configurations, compiled patterns that select matching run
        # collections; an empty list means "accept every run".
        self.run_patterns: list[re.Pattern] = []
        if config.collection_type is ConfigCollectionType.TAGGED:
            assert (
                config.collections is not None and len(config.collections) == 1
            ), "Exactly one collection name required for tagged type."
            self.tagged_collection = config.collections[0]
        elif config.collection_type is ConfigCollectionType.RUN:
            if config.collections:
                for coll in config.collections:
                    try:
                        self.run_patterns.append(re.compile(coll))
                    except re.error as exc:
                        raise ValueError(f"Failed to compile regex: {coll!r}") from exc
        else:
            raise ValueError(f"Unexpected value of collection_type: {config.collection_type}")

    @classmethod
    def initialize(
        cls,
        db: Database,
        context: StaticTablesContext,
        *,
        universe: DimensionUniverse,
        config: Mapping,
        datasets: type[DatasetRecordStorageManager],
        dimensions: DimensionRecordStorageManager,
        registry_schema_version: VersionTuple | None = None,
    ) -> ObsCoreTableManager:
        # Docstring inherited from base class.
        config_data = Config(config)
        obscore_config = ObsCoreManagerConfig.parse_obj(config_data)

        # Instantiate all spatial plugins.
        spatial_plugins = SpatialObsCorePlugin.load_plugins(obscore_config.spatial_plugins, db)

        schema = ObsCoreSchema(config=obscore_config, spatial_plugins=spatial_plugins, datasets=datasets)

        # Generate table specification for main obscore table.
        table_spec = schema.table_spec
        for plugin in spatial_plugins:
            plugin.extend_table_spec(table_spec)
        # Use the locally extended spec so that any columns added by the
        # spatial plugins above are guaranteed to be part of the table.
        table = context.addTable(obscore_config.table_name, table_spec)

        return ObsCoreLiveTableManager(
            db=db,
            table=table,
            schema=schema,
            universe=universe,
            config=obscore_config,
            dimensions=dimensions,
            spatial_plugins=spatial_plugins,
            registry_schema_version=registry_schema_version,
        )

    def config_json(self) -> str:
        """Dump configuration in JSON format.

        Returns
        -------
        json : `str`
            Configuration serialized in JSON format.
        """
        return json.dumps(self.config.dict())

    @classmethod
    def currentVersions(cls) -> list[VersionTuple]:
        # Docstring inherited from base class.
        return [_VERSION]

    def add_datasets(self, refs: Iterable[DatasetRef], context: SqlQueryContext) -> int:
        # Docstring inherited from base class.

        # Only makes sense for RUN collection types
        if self.config.collection_type is not ConfigCollectionType.RUN:
            return 0

        obscore_refs: Iterable[DatasetRef]
        if self.run_patterns:
            # Check each dataset run against configured run list. We want to
            # reduce number of calls to _check_dataset_run, which may be
            # expensive. Normally references are grouped by run, if there are
            # multiple input references, they should have the same run.
            # Instead of just checking that, we group them by run again.
            refs_by_run: dict[str, list[DatasetRef]] = defaultdict(list)
            for ref in refs:
                # Record factory will filter dataset types, but to reduce
                # collection checks we also pre-filter it here.
                if ref.datasetType.name not in self.config.dataset_types:
                    continue

                assert ref.run is not None, "Run cannot be None"
                refs_by_run[ref.run].append(ref)

            # Keep only the refs whose run matches a configured pattern.
            good_refs: list[DatasetRef] = []
            for run, run_refs in refs_by_run.items():
                if not self._check_dataset_run(run):
                    continue
                good_refs.extend(run_refs)
            obscore_refs = good_refs

        else:
            # Take all refs, no collection check.
            obscore_refs = refs

        return self._populate(obscore_refs, context)

    def associate(
        self, refs: Iterable[DatasetRef], collection: CollectionRecord, context: SqlQueryContext
    ) -> int:
        # Docstring inherited from base class.

        # Only works when collection type is TAGGED
        if self.tagged_collection is None:
            return 0

        if collection.name == self.tagged_collection:
            return self._populate(refs, context)
        else:
            return 0

    def disassociate(self, refs: Iterable[DatasetRef], collection: CollectionRecord) -> int:
        # Docstring inherited from base class.

        # Only works when collection type is TAGGED
        if self.tagged_collection is None:
            return 0

        count = 0
        if collection.name == self.tagged_collection:
            # Sorting may improve performance
            dataset_ids = sorted(cast(uuid.UUID, ref.id) for ref in refs)
            if dataset_ids:
                fk_field = self.schema.dataset_fk
                assert fk_field is not None, "Cannot be None by construction"
                # There may be too many of them, do it in chunks.
                for ids in chunk_iterable(dataset_ids):
                    where = self.table.columns[fk_field.name].in_(ids)
                    count += self.db.deleteWhere(self.table, where)
        return count

    def _populate(self, refs: Iterable[DatasetRef], context: SqlQueryContext) -> int:
        """Populate obscore table with the data from given datasets.

        Returns the number of records actually inserted; refs that the
        record factory rejects (e.g. unconfigured dataset types) are
        skipped silently.
        """
        records: list[Record] = []
        for ref in refs:
            record = self.record_factory(ref, context)
            if record is not None:
                records.append(record)

        if records:
            # Ignore potential conflicts with existing datasets.
            return self.db.ensure(self.table, *records, primary_key_only=True)
        else:
            return 0

    def _check_dataset_run(self, run: str) -> bool:
        """Check that specified run collection matches known patterns."""

        if not self.run_patterns:
            # Empty list means take anything.
            return True

        # Try each pattern in turn.
        return any(pattern.fullmatch(run) for pattern in self.run_patterns)

    def update_exposure_regions(self, instrument: str, region_data: Iterable[tuple[int, int, Region]]) -> int:
        # Docstring inherited from base class.
        instrument_column = self.schema.dimension_column("instrument")
        exposure_column = self.schema.dimension_column("exposure")
        detector_column = self.schema.dimension_column("detector")
        if instrument_column is None or exposure_column is None or detector_column is None:
            # Not all needed columns are in the table.
            return 0

        update_rows: list[Record] = []
        for exposure, detector, region in region_data:
            try:
                record = self.record_factory.make_spatial_records(region)
            except RegionTypeError as exc:
                # A region we cannot represent is not fatal; warn and move on.
                warnings.warn(
                    f"Failed to convert region for exposure={exposure} detector={detector}: {exc}",
                    category=RegionTypeWarning,
                )
                continue

            # Placeholder keys below pair with where_dict so that db.update
            # matches each row on instrument/exposure/detector.
            record.update(
                {
                    "instrument_column": instrument,
                    "exposure_column": exposure,
                    "detector_column": detector,
                }
            )
            update_rows.append(record)

        # Nothing to do if every region failed to convert (or the input was
        # empty); avoid calling db.update with an empty row list.
        if not update_rows:
            return 0

        where_dict: dict[str, str] = {
            instrument_column: "instrument_column",
            exposure_column: "exposure_column",
            detector_column: "detector_column",
        }

        count = self.db.update(self.table, where_dict, *update_rows)
        return count

    @contextmanager
    def query(self, **kwargs: Any) -> Iterator[sqlalchemy.engine.CursorResult]:
        """Run a SELECT query against obscore table and return result rows.

        Parameters
        ----------
        **kwargs
            Restriction on values of individual obscore columns. Key is the
            column name, value is the required value of the column. Multiple
            restrictions are ANDed together.
        """
        query = self.table.select()
        if kwargs:
            query = query.where(
                sqlalchemy.sql.expression.and_(
                    *[self.table.columns[column] == value for column, value in kwargs.items()]
                )
            )
        with self.db.query(query) as result:
            yield result

357 yield result