Coverage for python/lsst/daf/butler/registry/obscore/_manager.py: 17%

168 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-03-28 04:40 -0700

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

from __future__ import annotations

__all__ = ["ObsCoreLiveTableManager"]

import json
import re
import uuid
import warnings
from collections import defaultdict
from collections.abc import Collection, Iterable, Iterator, Mapping
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Type, cast

import sqlalchemy
from lsst.daf.butler import Config, DataCoordinate, DatasetRef, DimensionRecordColumnTag, DimensionUniverse
from lsst.daf.relation import Join
from lsst.sphgeom import Region
from lsst.utils.iteration import chunk_iterable

from ..interfaces import ObsCoreTableManager, VersionTuple
from ._config import ConfigCollectionType, ObsCoreManagerConfig
from ._records import ExposureRegionFactory, Record, RecordFactory
from ._schema import ObsCoreSchema
from ._spatial import RegionTypeError, RegionTypeWarning, SpatialObsCorePlugin

if TYPE_CHECKING:
    # These names are needed only for type annotations; importing them lazily
    # here avoids circular imports at runtime.
    from ..interfaces import (
        CollectionRecord,
        Database,
        DatasetRecordStorageManager,
        DimensionRecordStorageManager,
        StaticTablesContext,
    )
    from ..queries import SqlQueryContext

# Version of the obscore-table schema managed by this module.
_VERSION = VersionTuple(0, 0, 1)

class _ExposureRegionFactory(ExposureRegionFactory):
    """Find exposure region from a matching visit dimensions records."""

    def __init__(self, dimensions: DimensionRecordStorageManager):
        self.dimensions = dimensions
        self.universe = dimensions.universe
        self.exposure_dimensions = self.universe["exposure"].graph
        self.exposure_detector_dimensions = self.universe.extract(["exposure", "detector"])

    def exposure_region(self, dataId: DataCoordinate, context: SqlQueryContext) -> Region | None:
        # Docstring is inherited from a base class.
        # Begin with visit_definition, which provides the mapping between
        # exposure and visit; without it no region lookup is possible.
        rel = context.make_initial_relation()
        visit_definition = self.dimensions.get(self.universe["visit_definition"])
        if visit_definition is None:
            return None
        rel = visit_definition.join(rel, Join(), context)
        # Join in the most specific region source available: per-detector
        # visit regions when the data ID carries a detector, whole-visit
        # regions otherwise.
        if "detector" in dataId.names:
            storage = self.dimensions.get(self.universe["visit_detector_region"])
            if storage is None:
                return None
            rel = storage.join(rel, Join(), context)
            constraint = dataId.subset(self.exposure_detector_dimensions)
            region_tag = DimensionRecordColumnTag("visit_detector_region", "region")
        else:
            storage = self.dimensions.get(self.universe["visit"])
            if storage is None:
                return None
            rel = storage.join(rel, Join(), context)
            constraint = dataId.subset(self.exposure_dimensions)
            region_tag = DimensionRecordColumnTag("visit", "region")
        # Restrict rows to the requested exposure (and detector, if present).
        rel = rel.with_rows_satisfying(
            context.make_data_coordinate_predicate(constraint, full=False)
        )
        # An exposure can belong to multiple visits; any single match will do.
        rel = rel[:1]
        # Execute and return the region from the first row, if any.
        for row in context.fetch_iterable(rel):
            return row[region_tag]
        return None

105 

106 

class ObsCoreLiveTableManager(ObsCoreTableManager):
    """A manager class for ObsCore table, implements methods for updating the
    records in that table.

    Parameters
    ----------
    db : `Database`
        Interface to the underlying database engine and namespace.
    table : `sqlalchemy.schema.Table`
        The obscore table managed by this instance.
    schema : `ObsCoreSchema`
        Description of the obscore table schema.
    universe : `DimensionUniverse`
        All known dimensions.
    config : `ObsCoreManagerConfig`
        Configuration of this manager.
    dimensions : `DimensionRecordStorageManager`
        Manager for dimension records, used to look up exposure regions.
    spatial_plugins : `~collections.abc.Collection` [ `SpatialObsCorePlugin` ]
        Plugins that handle spatial columns of the table.

    Raises
    ------
    ValueError
        Raised if a configured run-collection regex does not compile or the
        configured collection type is not recognized.
    """

    def __init__(
        self,
        *,
        db: Database,
        table: sqlalchemy.schema.Table,
        schema: ObsCoreSchema,
        universe: DimensionUniverse,
        config: ObsCoreManagerConfig,
        dimensions: DimensionRecordStorageManager,
        spatial_plugins: Collection[SpatialObsCorePlugin],
    ):
        self.db = db
        self.table = table
        self.schema = schema
        self.universe = universe
        self.config = config
        self.spatial_plugins = spatial_plugins
        exposure_region_factory = _ExposureRegionFactory(dimensions)
        self.record_factory = RecordFactory(
            config, schema, universe, spatial_plugins, exposure_region_factory
        )
        # Exactly one of these is populated, depending on collection_type:
        # a single tagged collection name, or a list of run-name regexes.
        self.tagged_collection: str | None = None
        self.run_patterns: list[re.Pattern] = []
        if config.collection_type is ConfigCollectionType.TAGGED:
            assert (
                config.collections is not None and len(config.collections) == 1
            ), "Exactly one collection name required for tagged type."
            self.tagged_collection = config.collections[0]
        elif config.collection_type is ConfigCollectionType.RUN:
            if config.collections:
                for coll in config.collections:
                    try:
                        self.run_patterns.append(re.compile(coll))
                    except re.error as exc:
                        raise ValueError(f"Failed to compile regex: {coll!r}") from exc
        else:
            raise ValueError(f"Unexpected value of collection_type: {config.collection_type}")

    @classmethod
    def initialize(
        cls,
        db: Database,
        context: StaticTablesContext,
        *,
        universe: DimensionUniverse,
        config: Mapping,
        datasets: Type[DatasetRecordStorageManager],
        dimensions: DimensionRecordStorageManager,
    ) -> ObsCoreTableManager:
        # Docstring inherited from base class.
        config_data = Config(config)
        obscore_config = ObsCoreManagerConfig.parse_obj(config_data)

        # Instantiate all spatial plugins.
        spatial_plugins = SpatialObsCorePlugin.load_plugins(obscore_config.spatial_plugins, db)

        schema = ObsCoreSchema(config=obscore_config, spatial_plugins=spatial_plugins, datasets=datasets)

        # Generate table specification for main obscore table, letting each
        # plugin add its own columns.
        table_spec = schema.table_spec
        for plugin in spatial_plugins:
            plugin.extend_table_spec(table_spec)
        # Use the locally extended spec so plugin columns are guaranteed to
        # be part of the created table.
        table = context.addTable(obscore_config.table_name, table_spec)

        return ObsCoreLiveTableManager(
            db=db,
            table=table,
            schema=schema,
            universe=universe,
            config=obscore_config,
            dimensions=dimensions,
            spatial_plugins=spatial_plugins,
        )

    def config_json(self) -> str:
        """Dump configuration in JSON format.

        Returns
        -------
        json : `str`
            Configuration serialized in JSON format.
        """
        return json.dumps(self.config.dict())

    @classmethod
    def currentVersion(cls) -> VersionTuple | None:
        # Docstring inherited from base class.
        return _VERSION

    def add_datasets(self, refs: Iterable[DatasetRef], context: SqlQueryContext) -> int:
        # Docstring inherited from base class.

        # Only makes sense for RUN collection types
        if self.config.collection_type is not ConfigCollectionType.RUN:
            return 0

        obscore_refs: Iterable[DatasetRef]
        if self.run_patterns:
            # Check each dataset run against configured run list. We want to
            # reduce number of calls to _check_dataset_run, which may be
            # expensive. Normally references are grouped by run, if there are
            # multiple input references, they should have the same run.
            # Instead of just checking that, we group them by run again.
            refs_by_run: dict[str, list[DatasetRef]] = defaultdict(list)
            for ref in refs:
                # Record factory will filter dataset types, but to reduce
                # collection checks we also pre-filter it here.
                if ref.datasetType.name not in self.config.dataset_types:
                    continue

                assert ref.run is not None, "Run cannot be None"
                refs_by_run[ref.run].append(ref)

            good_refs: list[DatasetRef] = []
            for run, run_refs in refs_by_run.items():
                if not self._check_dataset_run(run):
                    continue
                good_refs.extend(run_refs)
            obscore_refs = good_refs

        else:
            # Take all refs, no collection check.
            obscore_refs = refs

        return self._populate(obscore_refs, context)

    def associate(
        self, refs: Iterable[DatasetRef], collection: CollectionRecord, context: SqlQueryContext
    ) -> int:
        # Docstring inherited from base class.

        # Only works when collection type is TAGGED
        if self.tagged_collection is None:
            return 0

        if collection.name == self.tagged_collection:
            return self._populate(refs, context)
        else:
            return 0

    def disassociate(self, refs: Iterable[DatasetRef], collection: CollectionRecord) -> int:
        # Docstring inherited from base class.

        # Only works when collection type is TAGGED
        if self.tagged_collection is None:
            return 0

        count = 0
        if collection.name == self.tagged_collection:
            # Sorting may improve performance
            dataset_ids = sorted(cast(uuid.UUID, ref.id) for ref in refs)
            if dataset_ids:
                fk_field = self.schema.dataset_fk
                assert fk_field is not None, "Cannot be None by construction"
                # There may be too many of them, do it in chunks.
                for ids in chunk_iterable(dataset_ids):
                    where = self.table.columns[fk_field.name].in_(ids)
                    count += self.db.deleteWhere(self.table, where)
        return count

    def _populate(self, refs: Iterable[DatasetRef], context: SqlQueryContext) -> int:
        """Populate obscore table with the data from given datasets.

        Parameters
        ----------
        refs : `~collections.abc.Iterable` [ `DatasetRef` ]
            Datasets to convert into obscore records; refs for which the
            record factory returns `None` are skipped.
        context : `SqlQueryContext`
            Context used by the record factory for region lookups.

        Returns
        -------
        count : `int`
            Number of rows inserted (conflicts with existing rows ignored).
        """
        records: list[Record] = []
        for ref in refs:
            record = self.record_factory(ref, context)
            if record is not None:
                records.append(record)

        if records:
            # Ignore potential conflicts with existing datasets.
            return self.db.ensure(self.table, *records, primary_key_only=True)
        else:
            return 0

    def _check_dataset_run(self, run: str) -> bool:
        """Check that specified run collection matches known patterns."""

        if not self.run_patterns:
            # Empty list means take anything.
            return True

        # Try each pattern in turn; patterns must match the full run name.
        return any(pattern.fullmatch(run) for pattern in self.run_patterns)

    def update_exposure_regions(self, instrument: str, region_data: Iterable[tuple[int, int, Region]]) -> int:
        # Docstring inherited from base class.
        instrument_column = self.schema.dimension_column("instrument")
        exposure_column = self.schema.dimension_column("exposure")
        detector_column = self.schema.dimension_column("detector")
        if instrument_column is None or exposure_column is None or detector_column is None:
            # Not all needed columns are in the table.
            return 0

        update_rows: list[Record] = []
        for exposure, detector, region in region_data:
            try:
                record = self.record_factory.make_spatial_records(region)
            except RegionTypeError as exc:
                # Best-effort update: skip regions we cannot convert rather
                # than failing the whole batch.
                warnings.warn(
                    f"Failed to convert region for exposure={exposure} detector={detector}: {exc}",
                    category=RegionTypeWarning,
                )
                continue

            # The "*_column" keys are placeholders referenced by where_dict
            # below, per the Database.update calling convention.
            record.update(
                {
                    "instrument_column": instrument,
                    "exposure_column": exposure,
                    "detector_column": detector,
                }
            )
            update_rows.append(record)

        # Maps actual table column name to the placeholder key in each row
        # that supplies the WHERE-clause value.
        where_dict: dict[str, str] = {
            instrument_column: "instrument_column",
            exposure_column: "exposure_column",
            detector_column: "detector_column",
        }

        count = self.db.update(self.table, where_dict, *update_rows)
        return count

    @contextmanager
    def query(self, **kwargs: Any) -> Iterator[sqlalchemy.engine.CursorResult]:
        """Run a SELECT query against obscore table and return result rows.

        Parameters
        ----------
        **kwargs
            Restriction on values of individual obscore columns. Key is the
            column name, value is the required value of the column. Multiple
            restrictions are ANDed together.

        Yields
        ------
        result : `sqlalchemy.engine.CursorResult`
            Cursor over the selected rows; valid only inside the ``with``
            block.
        """
        query = self.table.select()
        if kwargs:
            query = query.where(
                sqlalchemy.sql.expression.and_(
                    *[self.table.columns[column] == value for column, value in kwargs.items()]
                )
            )
        with self.db.query(query) as result:
            yield result