Coverage for python/lsst/daf/butler/registry/obscore/_manager.py: 16%

172 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-08 05:05 -0700

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ["ObsCoreLiveTableManager"] 

25 

26import json 

27import re 

28import warnings 

29from collections import defaultdict 

30from collections.abc import Collection, Iterable, Iterator, Mapping 

31from contextlib import contextmanager 

32from typing import TYPE_CHECKING, Any 

33 

34import sqlalchemy 

35from lsst.daf.butler import Config, DataCoordinate, DatasetRef, DimensionRecordColumnTag, DimensionUniverse 

36from lsst.daf.relation import Join 

37from lsst.sphgeom import Region 

38from lsst.utils.iteration import chunk_iterable 

39 

40from ..interfaces import ObsCoreTableManager, VersionTuple 

41from ._config import ConfigCollectionType, ObsCoreManagerConfig 

42from ._records import ExposureRegionFactory, Record, RecordFactory 

43from ._schema import ObsCoreSchema 

44from ._spatial import RegionTypeError, RegionTypeWarning, SpatialObsCorePlugin 

45 

46if TYPE_CHECKING: 

47 from ..interfaces import ( 

48 CollectionRecord, 

49 Database, 

50 DatasetRecordStorageManager, 

51 DimensionRecordStorageManager, 

52 StaticTablesContext, 

53 ) 

54 from ..queries import SqlQueryContext 

55 

56_VERSION = VersionTuple(0, 0, 1) 

57 

58 

class _ExposureRegionFactory(ExposureRegionFactory):
    """Resolve an exposure's region by joining through matching visit
    dimension records.
    """

    def __init__(self, dimensions: DimensionRecordStorageManager):
        self.dimensions = dimensions
        self.universe = dimensions.universe
        self.exposure_dimensions = self.universe["exposure"].graph
        self.exposure_detector_dimensions = self.universe.extract(["exposure", "detector"])

    def exposure_region(self, dataId: DataCoordinate, context: SqlQueryContext) -> Region | None:
        # Docstring is inherited from a base class.
        # Start from visit_definition, which maps exposure to visit.
        relation = context.make_initial_relation()
        definition_storage = self.dimensions.get(self.universe["visit_definition"])
        if definition_storage is None:
            return None
        relation = definition_storage.join(relation, Join(), context)
        # Pick the region-bearing element: per-detector regions when the
        # data ID has a detector, whole-visit regions otherwise.
        if "detector" in dataId.names:
            element_name = "visit_detector_region"
            constraint_dimensions = self.exposure_detector_dimensions
        else:
            element_name = "visit"
            constraint_dimensions = self.exposure_dimensions
        region_storage = self.dimensions.get(self.universe[element_name])
        if region_storage is None:
            return None
        relation = region_storage.join(relation, Join(), context)
        region_tag = DimensionRecordColumnTag(element_name, "region")
        # Constrain the relation to match the given exposure and (if present)
        # detector IDs.
        predicate = context.make_data_coordinate_predicate(
            dataId.subset(constraint_dimensions), full=False
        )
        relation = relation.with_rows_satisfying(predicate)
        # An exposure may belong to several visits; any single match will do.
        relation = relation[:1]
        # Execute the query; yield the region from the first (only) row,
        # or None when there are no matches.
        for row in context.fetch_iterable(relation):
            return row[region_tag]
        return None

104 

105 

class ObsCoreLiveTableManager(ObsCoreTableManager):
    """A manager class for ObsCore table, implements methods for updating the
    records in that table.

    Notes
    -----
    The manager operates in one of two modes set by configuration: with a
    ``RUN`` collection type it records datasets whose run collection matches
    the configured regular expressions (``add_datasets``), and with a
    ``TAGGED`` collection type it records datasets associated with a single
    tagged collection (``associate``/``disassociate``).
    """

    def __init__(
        self,
        *,
        db: Database,
        table: sqlalchemy.schema.Table,
        schema: ObsCoreSchema,
        universe: DimensionUniverse,
        config: ObsCoreManagerConfig,
        dimensions: DimensionRecordStorageManager,
        spatial_plugins: Collection[SpatialObsCorePlugin],
        registry_schema_version: VersionTuple | None = None,
    ):
        super().__init__(registry_schema_version=registry_schema_version)
        self.db = db
        self.table = table
        self.schema = schema
        self.universe = universe
        self.config = config
        self.spatial_plugins = spatial_plugins
        exposure_region_factory = _ExposureRegionFactory(dimensions)
        self.record_factory = RecordFactory(
            config, schema, universe, spatial_plugins, exposure_region_factory
        )
        # Exactly one of the two attributes below is populated, depending on
        # the configured collection type.
        self.tagged_collection: str | None = None
        self.run_patterns: list[re.Pattern] = []
        if config.collection_type is ConfigCollectionType.TAGGED:
            assert (
                config.collections is not None and len(config.collections) == 1
            ), "Exactly one collection name required for tagged type."
            self.tagged_collection = config.collections[0]
        elif config.collection_type is ConfigCollectionType.RUN:
            # Collection names are treated as regular expressions; an empty
            # or missing list means every run collection is accepted.
            if config.collections:
                for coll in config.collections:
                    try:
                        self.run_patterns.append(re.compile(coll))
                    except re.error as exc:
                        raise ValueError(f"Failed to compile regex: {coll!r}") from exc
        else:
            raise ValueError(f"Unexpected value of collection_type: {config.collection_type}")

    @classmethod
    def initialize(
        cls,
        db: Database,
        context: StaticTablesContext,
        *,
        universe: DimensionUniverse,
        config: Mapping,
        datasets: type[DatasetRecordStorageManager],
        dimensions: DimensionRecordStorageManager,
        registry_schema_version: VersionTuple | None = None,
    ) -> ObsCoreTableManager:
        # Docstring inherited from base class.
        config_data = Config(config)
        obscore_config = ObsCoreManagerConfig.parse_obj(config_data)

        # Instantiate all spatial plugins.
        spatial_plugins = SpatialObsCorePlugin.load_plugins(obscore_config.spatial_plugins, db)

        schema = ObsCoreSchema(config=obscore_config, spatial_plugins=spatial_plugins, datasets=datasets)

        # Generate table specification for main obscore table.
        table_spec = schema.table_spec
        for plugin in spatial_plugins:
            plugin.extend_table_spec(table_spec)
        # NOTE(review): plugins above extend the local ``table_spec``, but
        # ``schema.table_spec`` is passed here — presumably both name the
        # same object, so the extensions are picked up; confirm they alias
        # (passing ``table_spec`` would make that explicit).
        table = context.addTable(obscore_config.table_name, schema.table_spec)

        return ObsCoreLiveTableManager(
            db=db,
            table=table,
            schema=schema,
            universe=universe,
            config=obscore_config,
            dimensions=dimensions,
            spatial_plugins=spatial_plugins,
            registry_schema_version=registry_schema_version,
        )

    def config_json(self) -> str:
        """Dump configuration in JSON format.

        Returns
        -------
        json : `str`
            Configuration serialized in JSON format.
        """
        return json.dumps(self.config.dict())

    @classmethod
    def currentVersions(cls) -> list[VersionTuple]:
        # Docstring inherited from base class.
        return [_VERSION]

    def add_datasets(self, refs: Iterable[DatasetRef], context: SqlQueryContext) -> int:
        # Docstring inherited from base class.

        # Only makes sense for RUN collection types
        if self.config.collection_type is not ConfigCollectionType.RUN:
            return 0

        obscore_refs: Iterable[DatasetRef]
        if self.run_patterns:
            # Check each dataset run against configured run list. We want to
            # reduce number of calls to _check_dataset_run, which may be
            # expensive. Normally references are grouped by run, if there are
            # multiple input references, they should have the same run.
            # Instead of just checking that, we group them by run again.
            refs_by_run: dict[str, list[DatasetRef]] = defaultdict(list)
            for ref in refs:
                # Record factory will filter dataset types, but to reduce
                # collection checks we also pre-filter it here.
                if ref.datasetType.name not in self.config.dataset_types:
                    continue

                assert ref.run is not None, "Run cannot be None"
                refs_by_run[ref.run].append(ref)

            # Keep only the refs whose run matches one of the configured
            # patterns; the single pattern check covers the whole group.
            good_refs: list[DatasetRef] = []
            for run, run_refs in refs_by_run.items():
                if not self._check_dataset_run(run):
                    continue
                good_refs.extend(run_refs)
            obscore_refs = good_refs

        else:
            # Take all refs, no collection check.
            obscore_refs = refs

        return self._populate(obscore_refs, context)

    def associate(
        self, refs: Iterable[DatasetRef], collection: CollectionRecord, context: SqlQueryContext
    ) -> int:
        # Docstring inherited from base class.

        # Only works when collection type is TAGGED
        if self.tagged_collection is None:
            return 0

        # Only datasets tagged into the configured collection are recorded.
        if collection.name == self.tagged_collection:
            return self._populate(refs, context)
        else:
            return 0

    def disassociate(self, refs: Iterable[DatasetRef], collection: CollectionRecord) -> int:
        # Docstring inherited from base class.

        # Only works when collection type is TAGGED
        if self.tagged_collection is None:
            return 0

        count = 0
        if collection.name == self.tagged_collection:
            # Sorting may improve performance
            dataset_ids = sorted(ref.id for ref in refs)
            if dataset_ids:
                fk_field = self.schema.dataset_fk
                assert fk_field is not None, "Cannot be None by construction"
                # There may be too many of them, do it in chunks.
                for ids in chunk_iterable(dataset_ids):
                    where = self.table.columns[fk_field.name].in_(ids)
                    count += self.db.deleteWhere(self.table, where)
        # Total number of rows removed across all chunks.
        return count

    def _populate(self, refs: Iterable[DatasetRef], context: SqlQueryContext) -> int:
        """Populate obscore table with the data from given datasets."""
        # The factory returns None for refs that do not produce an obscore
        # record (e.g. filtered-out dataset types); those are skipped.
        records: list[Record] = []
        for ref in refs:
            record = self.record_factory(ref, context)
            if record is not None:
                records.append(record)

        if records:
            # Ignore potential conflicts with existing datasets.
            return self.db.ensure(self.table, *records, primary_key_only=True)
        else:
            return 0

    def _check_dataset_run(self, run: str) -> bool:
        """Check that specified run collection matches know patterns."""

        if not self.run_patterns:
            # Empty list means take anything.
            return True

        # Try each pattern in turn; the whole run name must match.
        return any(pattern.fullmatch(run) for pattern in self.run_patterns)

    def update_exposure_regions(self, instrument: str, region_data: Iterable[tuple[int, int, Region]]) -> int:
        # Docstring inherited from base class.
        instrument_column = self.schema.dimension_column("instrument")
        exposure_column = self.schema.dimension_column("exposure")
        detector_column = self.schema.dimension_column("detector")
        if instrument_column is None or exposure_column is None or detector_column is None:
            # Not all needed columns are in the table.
            return 0

        update_rows: list[Record] = []
        for exposure, detector, region in region_data:
            try:
                record = self.record_factory.make_spatial_records(region)
            except RegionTypeError as exc:
                # Best effort: skip regions the factory cannot convert, but
                # warn so the caller can investigate.
                warnings.warn(
                    f"Failed to convert region for exposure={exposure} detector={detector}: {exc}",
                    category=RegionTypeWarning,
                )
                continue

            # Placeholder keys below are matched by ``where_dict`` values, so
            # each row carries its own WHERE parameters.
            record.update(
                {
                    "instrument_column": instrument,
                    "exposure_column": exposure,
                    "detector_column": detector,
                }
            )
            update_rows.append(record)

        # Map actual table column names to the per-row placeholder keys —
        # presumably the contract of ``Database.update``; verify against its
        # documentation.
        where_dict: dict[str, str] = {
            instrument_column: "instrument_column",
            exposure_column: "exposure_column",
            detector_column: "detector_column",
        }

        count = self.db.update(self.table, where_dict, *update_rows)
        return count

    @contextmanager
    def query(
        self, columns: Iterable[str | sqlalchemy.sql.expression.ColumnElement] | None = None, /, **kwargs: Any
    ) -> Iterator[sqlalchemy.engine.CursorResult]:
        # Docstring inherited from base class.
        if columns is not None:
            # Column names are resolved against the obscore table; SQLAlchemy
            # column elements are passed through as-is.
            column_elements: list[sqlalchemy.sql.ColumnElement] = []
            for column in columns:
                if isinstance(column, str):
                    column_elements.append(self.table.columns[column])
                else:
                    column_elements.append(column)
            query = sqlalchemy.sql.select(*column_elements).select_from(self.table)
        else:
            query = self.table.select()

        # Each keyword argument becomes an equality constraint; multiple
        # constraints are ANDed together.
        if kwargs:
            query = query.where(
                sqlalchemy.sql.expression.and_(
                    *[self.table.columns[column] == value for column, value in kwargs.items()]
                )
            )
        with self.db.query(query) as result:
            yield result