Coverage for python/lsst/daf/butler/registry/obscore/_manager.py: 18%

142 statements  

# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

__all__ = ["ObsCoreLiveTableManager"]

import json
import re
import uuid
from collections import defaultdict
from collections.abc import Collection, Mapping
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Type, cast

import sqlalchemy
from lsst.daf.butler import (
    Config,
    DataCoordinate,
    DataCoordinateIterable,
    DatasetRef,
    Dimension,
    DimensionUniverse,
)
from lsst.sphgeom import Region
from lsst.utils.iteration import chunk_iterable

from ..interfaces import ObsCoreTableManager, VersionTuple
from ._config import ConfigCollectionType, ObsCoreManagerConfig
from ._records import ExposureRegionFactory, Record, RecordFactory
from ._schema import ObsCoreSchema
from ._spatial import SpatialObsCorePlugin

if TYPE_CHECKING:
    from ..interfaces import (
        CollectionRecord,
        Database,
        DatasetRecordStorageManager,
        DimensionRecordStorageManager,
        StaticTablesContext,
    )

_VERSION = VersionTuple(0, 0, 1)


class _ExposureRegionFactory(ExposureRegionFactory):
    """Find the region of an exposure from matching visit dimension
    records."""

    def __init__(self, dimensions: DimensionRecordStorageManager):
        self.dimensions = dimensions
        self.universe = dimensions.universe
        self.exposure = self.universe["exposure"]
        self.visit = self.universe["visit"]

    def exposure_region(self, dataId: DataCoordinate) -> Optional[Region]:
        # Docstring is inherited from a base class.
        visit_definition_storage = self.dimensions.get(self.universe["visit_definition"])
        if visit_definition_storage is None:
            return None
        exposureDataId = dataId.subset(self.exposure.graph)
        records = visit_definition_storage.fetch(DataCoordinateIterable.fromScalar(exposureDataId))
        # There may be more than one visit per exposure; they should have
        # the same region, so we use an arbitrary one.
        record = next(iter(records), None)
        if record is None:
            return None
        visit: int = record.visit

        detector = cast(Dimension, self.universe["detector"])
        if detector in dataId:
            visit_detector_region_storage = self.dimensions.get(self.universe["visit_detector_region"])
            if visit_detector_region_storage is None:
                return None
            visitDataId = DataCoordinate.standardize(
                {
                    "instrument": dataId["instrument"],
                    "visit": visit,
                    "detector": dataId["detector"],
                },
                universe=self.universe,
            )
            records = visit_detector_region_storage.fetch(DataCoordinateIterable.fromScalar(visitDataId))
            record = next(iter(records), None)
            if record is not None:
                return record.region
        else:
            visit_storage = self.dimensions.get(self.visit)
            if visit_storage is None:
                return None
            visitDataId = DataCoordinate.standardize(
                {
                    "instrument": dataId["instrument"],
                    "visit": visit,
                },
                universe=self.universe,
            )
            records = visit_storage.fetch(DataCoordinateIterable.fromScalar(visitDataId))
            record = next(iter(records), None)
            if record is not None:
                return record.region

        return None
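
# Illustrative sketch, not part of the original module: given a dimension
# record storage manager ``dimensions``, the factory above resolves an
# exposure data ID to a region through its matching visit (the name
# ``exposure_data_id`` is hypothetical):
#
#     factory = _ExposureRegionFactory(dimensions)
#     region = factory.exposure_region(exposure_data_id)
#
# ``region`` is None when no visit_definition record matches the exposure.
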

class ObsCoreLiveTableManager(ObsCoreTableManager):
    """A manager class for the ObsCore table; implements methods for
    updating the records in that table.
    """

    def __init__(
        self,
        *,
        db: Database,
        table: sqlalchemy.schema.Table,
        schema: ObsCoreSchema,
        universe: DimensionUniverse,
        config: ObsCoreManagerConfig,
        dimensions: DimensionRecordStorageManager,
        spatial_plugins: Collection[SpatialObsCorePlugin],
    ):
        self.db = db
        self.table = table
        self.schema = schema
        self.universe = universe
        self.config = config
        self.spatial_plugins = spatial_plugins
        exposure_region_factory = _ExposureRegionFactory(dimensions)
        self.record_factory = RecordFactory(
            config, schema, universe, spatial_plugins, exposure_region_factory
        )
        self.tagged_collection: Optional[str] = None
        self.run_patterns: list[re.Pattern] = []
        if config.collection_type is ConfigCollectionType.TAGGED:
            assert (
                config.collections is not None and len(config.collections) == 1
            ), "Exactly one collection name required for tagged type."
            self.tagged_collection = config.collections[0]
        elif config.collection_type is ConfigCollectionType.RUN:
            if config.collections:
                for coll in config.collections:
                    try:
                        self.run_patterns.append(re.compile(coll))
                    except re.error as exc:
                        raise ValueError(f"Failed to compile regex: {coll!r}") from exc
        else:
            raise ValueError(f"Unexpected value of collection_type: {config.collection_type}")
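
    # Illustrative, not from the original source: the two collection_type
    # modes handled above imply configurations along these lines (the key
    # names follow the attributes used in this class, but the exact enum
    # spelling is an assumption; the authoritative schema lives in ._config):
    #
    #     TAGGED: exactly one tagged collection name, e.g.
    #         {"collection_type": "TAGGED", "collections": ["obscore/tagged"]}
    #     RUN: an optional list of regular expressions matched against run
    #         names, e.g.
    #         {"collection_type": "RUN", "collections": ["HSC/runs/.*"]}
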

    @classmethod
    def initialize(
        cls,
        db: Database,
        context: StaticTablesContext,
        *,
        universe: DimensionUniverse,
        config: Mapping,
        datasets: Type[DatasetRecordStorageManager],
        dimensions: DimensionRecordStorageManager,
    ) -> ObsCoreTableManager:
        # Docstring inherited from base class.
        config_data = Config(config)
        obscore_config = ObsCoreManagerConfig.parse_obj(config_data)

        # Instantiate all spatial plugins.
        spatial_plugins = SpatialObsCorePlugin.load_plugins(obscore_config.spatial_plugins, db)

        schema = ObsCoreSchema(config=obscore_config, spatial_plugins=spatial_plugins, datasets=datasets)

        # Generate table specification for the main obscore table.
        table_spec = schema.table_spec
        for plugin in spatial_plugins:
            plugin.extend_table_spec(table_spec)
        table = context.addTable(obscore_config.table_name, table_spec)

        return ObsCoreLiveTableManager(
            db=db,
            table=table,
            schema=schema,
            universe=universe,
            config=obscore_config,
            dimensions=dimensions,
            spatial_plugins=spatial_plugins,
        )

    def config_json(self) -> str:
        """Dump configuration in JSON format.

        Returns
        -------
        json : `str`
            Configuration serialized in JSON format.
        """
        return json.dumps(self.config.dict())
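
    # Illustrative round trip, assuming the pydantic-style parse_obj/dict
    # methods used elsewhere in this module accept each other's output:
    #
    #     config = ObsCoreManagerConfig.parse_obj(json.loads(manager.config_json()))
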

    @classmethod
    def currentVersion(cls) -> Optional[VersionTuple]:
        # Docstring inherited from base class.
        return _VERSION

    def schemaDigest(self) -> Optional[str]:
        # Docstring inherited from base class.
        return None

    def add_datasets(self, refs: Iterable[DatasetRef]) -> None:
        # Docstring inherited from base class.

        # Only makes sense for RUN collection types.
        if self.config.collection_type is not ConfigCollectionType.RUN:
            return

        obscore_refs: Iterable[DatasetRef]
        if self.run_patterns:
            # Check each dataset run against the configured run list. We
            # want to reduce the number of calls to _check_dataset_run,
            # which may be expensive. Normally references are grouped by
            # run; if there are multiple input references, they should have
            # the same run. Instead of relying on that, we group them by
            # run again.
            refs_by_run: Dict[str, List[DatasetRef]] = defaultdict(list)
            for ref in refs:
                # The record factory will filter dataset types, but to
                # reduce collection checks we also pre-filter here.
                if ref.datasetType.name not in self.config.dataset_types:
                    continue

                assert ref.run is not None, "Run cannot be None"
                refs_by_run[ref.run].append(ref)

            good_refs: List[DatasetRef] = []
            for run, run_refs in refs_by_run.items():
                if not self._check_dataset_run(run):
                    continue
                good_refs.extend(run_refs)
            obscore_refs = good_refs
        else:
            # Take all refs, no collection check.
            obscore_refs = refs

        self._populate(obscore_refs)

    def associate(self, refs: Iterable[DatasetRef], collection: CollectionRecord) -> None:
        # Docstring inherited from base class.

        # Only works when collection type is TAGGED.
        if self.tagged_collection is None:
            return

        if collection.name == self.tagged_collection:
            self._populate(refs)

    def disassociate(self, refs: Iterable[DatasetRef], collection: CollectionRecord) -> None:
        # Docstring inherited from base class.

        # Only works when collection type is TAGGED.
        if self.tagged_collection is None:
            return

        if collection.name == self.tagged_collection:
            # Sorting may improve performance.
            dataset_ids = sorted(cast(uuid.UUID, ref.id) for ref in refs)
            if dataset_ids:
                fk_field = self.schema.dataset_fk
                assert fk_field is not None, "Cannot be None by construction"
                # There may be too many of them; do it in chunks.
                for ids in chunk_iterable(dataset_ids):
                    where = self.table.columns[fk_field.name].in_(ids)
                    self.db.deleteWhere(self.table, where)
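
    # Illustrative note: chunk_iterable above yields fixed-size chunks of
    # the sorted IDs, so each DELETE carries a bounded IN (...) clause
    # instead of one enormous list, keeping statements within database
    # limits. (The chunk size is whatever lsst.utils.iteration defines by
    # default.)
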

    def _populate(self, refs: Iterable[DatasetRef]) -> None:
        """Populate obscore table with the data from given datasets."""
        records: List[Record] = []
        for ref in refs:
            record = self.record_factory(ref)
            if record is not None:
                records.append(record)

        if records:
            # Ignore potential conflicts with existing datasets.
            self.db.ensure(self.table, *records, primary_key_only=True)

    def _check_dataset_run(self, run: str) -> bool:
        """Check whether the specified run collection matches known
        patterns."""
        if not self.run_patterns:
            # Empty list means take anything.
            return True

        # Try each pattern in turn.
        return any(pattern.fullmatch(run) for pattern in self.run_patterns)
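
    # Illustrative, with hypothetical run names: fullmatch requires the
    # pattern to cover the entire run name, unlike re.match:
    #
    #     re.compile("HSC/runs/.*").fullmatch("HSC/runs/RC2/w_2022_40")  # matches
    #     re.compile("HSC/runs").fullmatch("HSC/runs/RC2/w_2022_40")     # returns None
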