Coverage for python/lsst/daf/butler/registry/obscore/_manager.py: 20%

144 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-10-07 02:47 -0700

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ["ObsCoreLiveTableManager"] 

25 

26import json 

27import re 

28import uuid 

29from collections import defaultdict 

30from collections.abc import Mapping 

31from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Type, cast 

32 

33import sqlalchemy 

34from lsst.daf.butler import ( 

35 Config, 

36 DataCoordinate, 

37 DataCoordinateIterable, 

38 DatasetRef, 

39 Dimension, 

40 DimensionUniverse, 

41) 

42from lsst.sphgeom import Region 

43from lsst.utils.iteration import chunk_iterable 

44 

45from ..interfaces import ObsCoreTableManager, VersionTuple 

46from ._config import ConfigCollectionType, ObsCoreManagerConfig 

47from ._records import ExposureRegionFactory, RecordFactory 

48from ._schema import ObsCoreSchema 

49 

50if TYPE_CHECKING: 50 ↛ 51line 50 didn't jump to line 51, because the condition on line 50 was never true

51 from ..interfaces import ( 

52 CollectionRecord, 

53 Database, 

54 DatasetRecordStorageManager, 

55 DimensionRecordStorageManager, 

56 StaticTablesContext, 

57 ) 

58 

# Schema version reported by this manager (see ``currentVersion``).
_VERSION = VersionTuple(0, 0, 1)

60 

61 

class _ExposureRegionFactory(ExposureRegionFactory):
    """Find exposure region from a matching visit dimensions records."""

    def __init__(self, dimensions: DimensionRecordStorageManager):
        self.dimensions = dimensions
        self.universe = dimensions.universe
        self.exposure = self.universe["exposure"]
        self.visit = self.universe["visit"]

    def exposure_region(self, dataId: DataCoordinate) -> Optional[Region]:
        # Docstring is inherited from a base class.
        definition_storage = self.dimensions.get(self.universe["visit_definition"])
        if definition_storage is None:
            return None
        exposure_data_id = dataId.subset(self.exposure.graph)
        definition_records = definition_storage.fetch(
            DataCoordinateIterable.fromScalar(exposure_data_id)
        )
        # An exposure may belong to more than one visit, but their regions
        # should be the same, so an arbitrary one is good enough.
        definition = next(iter(definition_records), None)
        if definition is None:
            return None
        visit_id: int = definition.visit

        detector = cast(Dimension, self.universe["detector"])
        if detector in dataId:
            # Per-detector region from visit_detector_region records.
            region_storage = self.dimensions.get(self.universe["visit_detector_region"])
            if region_storage is None:
                return None
            lookup_id = DataCoordinate.standardize(
                {
                    "instrument": dataId["instrument"],
                    "visit": visit_id,
                    "detector": dataId["detector"],
                },
                universe=self.universe,
            )
        else:
            # Whole-visit region from visit records.
            region_storage = self.dimensions.get(self.visit)
            if region_storage is None:
                return None
            lookup_id = DataCoordinate.standardize(
                {
                    "instrument": dataId["instrument"],
                    "visit": visit_id,
                },
                universe=self.universe,
            )

        region_records = region_storage.fetch(DataCoordinateIterable.fromScalar(lookup_id))
        region_record = next(iter(region_records), None)
        if region_record is not None:
            return region_record.region
        return None

121 

122 

class ObsCoreLiveTableManager(ObsCoreTableManager):
    """A manager class for ObsCore table, implements methods for updating the
    records in that table.
    """

    def __init__(
        self,
        *,
        db: Database,
        table: sqlalchemy.schema.Table,
        schema: ObsCoreSchema,
        universe: DimensionUniverse,
        config: ObsCoreManagerConfig,
        dimensions: DimensionRecordStorageManager,
    ):
        self.db = db
        self.table = table
        self.schema = schema
        self.universe = universe
        self.config = config
        region_factory = _ExposureRegionFactory(dimensions)
        self.record_factory = RecordFactory(config, schema, universe, region_factory)

    @classmethod
    def initialize(
        cls,
        db: Database,
        context: StaticTablesContext,
        *,
        universe: DimensionUniverse,
        config: Mapping,
        datasets: Type[DatasetRecordStorageManager],
        dimensions: DimensionRecordStorageManager,
    ) -> ObsCoreTableManager:
        # Docstring inherited from base class.
        obscore_config = ObsCoreManagerConfig.parse_obj(Config(config))

        schema = ObsCoreSchema(config=obscore_config, datasets=datasets)
        table = context.addTable(obscore_config.table_name, schema.table_spec)

        collection_type = obscore_config.collection_type
        if collection_type is ConfigCollectionType.TAGGED:
            # Configuration validation guarantees that there is exactly one
            # collection for TAGGED type.
            assert obscore_config.collections is not None, "Collections must be defined"
            return _TaggedObsCoreTableManager(
                db=db,
                table=table,
                schema=schema,
                universe=universe,
                config=obscore_config,
                tagged_collection=obscore_config.collections[0],
                dimensions=dimensions,
            )
        if collection_type is ConfigCollectionType.RUN:
            return _RunObsCoreTableManager(
                db=db,
                table=table,
                schema=schema,
                universe=universe,
                config=obscore_config,
                dimensions=dimensions,
            )
        raise ValueError(f"Unexpected value of collection_type: {obscore_config.collection_type}")

    def config_json(self) -> str:
        """Dump configuration in JSON format.

        Returns
        -------
        json : `str`
            Configuration serialized in JSON format.
        """
        return json.dumps(self.config.dict())

    @classmethod
    def currentVersion(cls) -> Optional[VersionTuple]:
        # Docstring inherited from base class.
        return _VERSION

    def schemaDigest(self) -> Optional[str]:
        # Docstring inherited from base class.
        # No digest is computed for the obscore table schema.
        return None

206 

207 

class _TaggedObsCoreTableManager(ObsCoreLiveTableManager):
    """Implementation of ObsCoreTableManager which is used for
    ``collection_type=TAGGED``.
    """

    def __init__(
        self,
        *,
        db: Database,
        table: sqlalchemy.schema.Table,
        schema: ObsCoreSchema,
        universe: DimensionUniverse,
        config: ObsCoreManagerConfig,
        tagged_collection: str,
        dimensions: DimensionRecordStorageManager,
    ):
        super().__init__(
            db=db,
            table=table,
            schema=schema,
            universe=universe,
            config=config,
            dimensions=dimensions,
        )
        self.tagged_collection = tagged_collection

    def add_datasets(self, refs: Iterable[DatasetRef]) -> None:
        # Docstring inherited from base class.
        # For TAGGED collections nothing happens at dataset-insertion time;
        # records are only added when datasets are tagged (see ``associate``).
        return

    def associate(self, refs: Iterable[DatasetRef], collection: CollectionRecord) -> None:
        # Docstring inherited from base class.
        if collection.name != self.tagged_collection:
            return

        rows = [row for ref in refs if (row := self.record_factory(ref)) is not None]
        if rows:
            # Ignore potential conflicts with existing datasets.
            self.db.ensure(self.table, *rows, primary_key_only=True)

    def disassociate(self, refs: Iterable[DatasetRef], collection: CollectionRecord) -> None:
        # Docstring inherited from base class.
        if collection.name != self.tagged_collection:
            return

        # Sorting may improve performance
        dataset_ids = sorted(cast(uuid.UUID, ref.id) for ref in refs)
        if not dataset_ids:
            return
        fk_field = self.schema.dataset_fk
        assert fk_field is not None, "Cannot be None by construction"
        fk_column = self.table.columns[fk_field.name]
        # There may be too many of them, do it in chunks.
        for id_chunk in chunk_iterable(dataset_ids):
            self.db.deleteWhere(self.table, fk_column.in_(id_chunk))

266 

267 

class _RunObsCoreTableManager(ObsCoreLiveTableManager):
    """Implementation of ObsCoreTableManager which is used for
    ``collection_type=RUN``.

    Datasets are added to the obscore table when they are stored in a RUN
    collection matching the configured patterns; tagging and untagging are
    no-ops for this manager.
    """

    def __init__(
        self,
        *,
        db: Database,
        table: sqlalchemy.schema.Table,
        schema: ObsCoreSchema,
        universe: DimensionUniverse,
        config: ObsCoreManagerConfig,
        dimensions: DimensionRecordStorageManager,
    ):
        super().__init__(
            db=db,
            table=table,
            schema=schema,
            universe=universe,
            config=config,
            dimensions=dimensions,
        )

        # Pre-compiled patterns for matching run names; an empty list means
        # that every run is accepted.
        self.run_patterns: List[re.Pattern] = []
        if config.collections:
            for coll in config.collections:
                try:
                    self.run_patterns.append(re.compile(coll))
                except re.error as exc:
                    raise ValueError(f"Failed to compile regex: {coll!r}") from exc

    def add_datasets(self, refs: Iterable[DatasetRef]) -> None:
        # Docstring inherited from base class.

        obscore_refs: Iterable[DatasetRef]
        if self.run_patterns:
            # Check each dataset run against configured run list. We want to
            # reduce number of calls to _check_dataset_run, which may be
            # expensive. Normally references are grouped by run, if there are
            # multiple input references, they should have the same run.
            # Instead of just checking that, we group them by run again.
            refs_by_run: Dict[str, List[DatasetRef]] = defaultdict(list)
            for ref in refs:

                # Record factory will filter dataset types, but to reduce
                # collection checks we also pre-filter it here.
                if ref.datasetType.name not in self.config.dataset_types:
                    continue

                assert ref.run is not None, "Run cannot be None"
                refs_by_run[ref.run].append(ref)

            good_refs: List[DatasetRef] = []
            for run, run_refs in refs_by_run.items():
                if not self._check_dataset_run(run):
                    continue
                good_refs.extend(run_refs)
            obscore_refs = good_refs

        else:

            # Take all refs, no collection check.
            obscore_refs = refs

        # Convert them all to records.
        records: List[dict] = []
        for ref in obscore_refs:
            if (record := self.record_factory(ref)) is not None:
                records.append(record)

        if records:
            # Ignore potential conflicts with existing datasets.
            self.db.ensure(self.table, *records, primary_key_only=True)

    def _check_dataset_run(self, run: str) -> bool:
        """Check that specified run collection matches known patterns.

        Parameters
        ----------
        run : `str`
            Name of the run collection.

        Returns
        -------
        match : `bool`
            `True` if the run matches at least one configured pattern or if
            no patterns are configured.
        """
        if not self.run_patterns:
            # Empty list means take anything.
            return True

        # Try each pattern in turn.
        return any(pattern.fullmatch(run) for pattern in self.run_patterns)

    def associate(self, refs: Iterable[DatasetRef], collection: CollectionRecord) -> None:
        # Docstring inherited from base class.
        # Tagging does not affect RUN-based obscore contents.
        return

    def disassociate(self, refs: Iterable[DatasetRef], collection: CollectionRecord) -> None:
        # Docstring inherited from base class.
        # Untagging does not affect RUN-based obscore contents.
        return