Coverage for python/lsst/daf/butler/registry/obscore/_manager.py: 17% (146 statements)
coverage.py v7.2.5, created at 2023-05-15 00:10 +0000

# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

__all__ = ["ObsCoreLiveTableManager"]

import json
import re
import uuid
from collections import defaultdict
from collections.abc import Collection, Mapping
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Type, cast

import sqlalchemy
from lsst.daf.butler import (
    Config,
    DataCoordinate,
    DataCoordinateIterable,
    DatasetRef,
    Dimension,
    DimensionUniverse,
)
from lsst.sphgeom import Region
from lsst.utils.iteration import chunk_iterable

from ..interfaces import ObsCoreTableManager, VersionTuple
from ._config import ConfigCollectionType, ObsCoreManagerConfig
from ._records import ExposureRegionFactory, Record, RecordFactory
from ._schema import ObsCoreSchema
from ._spatial import SpatialObsCorePlugin

if TYPE_CHECKING:
    from ..interfaces import (
        CollectionRecord,
        Database,
        DatasetRecordStorageManager,
        DimensionRecordStorageManager,
        StaticTablesContext,
    )

_VERSION = VersionTuple(0, 0, 1)


class _ExposureRegionFactory(ExposureRegionFactory):
    """Find an exposure region from matching visit dimension records."""

    def __init__(self, dimensions: DimensionRecordStorageManager):
        self.dimensions = dimensions
        self.universe = dimensions.universe
        self.exposure = self.universe["exposure"]
        self.visit = self.universe["visit"]

    def exposure_region(self, dataId: DataCoordinate) -> Optional[Region]:
        # Docstring is inherited from a base class.
        visit_definition_storage = self.dimensions.get(self.universe["visit_definition"])
        if visit_definition_storage is None:
            return None
        exposureDataId = dataId.subset(self.exposure.graph)
        records = visit_definition_storage.fetch(DataCoordinateIterable.fromScalar(exposureDataId))
        # There may be more than one visit per exposure; they should have the
        # same region, so we use an arbitrary one.
        record = next(iter(records), None)
        if record is None:
            return None
        visit: int = record.visit

        detector = cast(Dimension, self.universe["detector"])
        if detector in dataId:
            visit_detector_region_storage = self.dimensions.get(self.universe["visit_detector_region"])
            if visit_detector_region_storage is None:
                return None
            visitDataId = DataCoordinate.standardize(
                {
                    "instrument": dataId["instrument"],
                    "visit": visit,
                    "detector": dataId["detector"],
                },
                universe=self.universe,
            )
            records = visit_detector_region_storage.fetch(DataCoordinateIterable.fromScalar(visitDataId))
            record = next(iter(records), None)
            if record is not None:
                return record.region

        else:
            visit_storage = self.dimensions.get(self.visit)
            if visit_storage is None:
                return None
            visitDataId = DataCoordinate.standardize(
                {
                    "instrument": dataId["instrument"],
                    "visit": visit,
                },
                universe=self.universe,
            )
            records = visit_storage.fetch(DataCoordinateIterable.fromScalar(visitDataId))
            record = next(iter(records), None)
            if record is not None:
                return record.region

        return None


class ObsCoreLiveTableManager(ObsCoreTableManager):
    """Manager class for the ObsCore table; implements methods for updating
    the records in that table.
    """

    def __init__(
        self,
        *,
        db: Database,
        table: sqlalchemy.schema.Table,
        schema: ObsCoreSchema,
        universe: DimensionUniverse,
        config: ObsCoreManagerConfig,
        dimensions: DimensionRecordStorageManager,
        spatial_plugins: Collection[SpatialObsCorePlugin],
    ):
        self.db = db
        self.table = table
        self.schema = schema
        self.universe = universe
        self.config = config
        self.spatial_plugins = spatial_plugins
        exposure_region_factory = _ExposureRegionFactory(dimensions)
        self.record_factory = RecordFactory(
            config, schema, universe, spatial_plugins, exposure_region_factory
        )
        self.tagged_collection: Optional[str] = None
        self.run_patterns: list[re.Pattern] = []
        if config.collection_type is ConfigCollectionType.TAGGED:
            assert (
                config.collections is not None and len(config.collections) == 1
            ), "Exactly one collection name required for tagged type."
            self.tagged_collection = config.collections[0]
        elif config.collection_type is ConfigCollectionType.RUN:
            if config.collections:
                for coll in config.collections:
                    try:
                        self.run_patterns.append(re.compile(coll))
                    except re.error as exc:
                        raise ValueError(f"Failed to compile regex: {coll!r}") from exc
        else:
            raise ValueError(f"Unexpected value of collection_type: {config.collection_type}")
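
    # Illustrative (hypothetical) RUN-type configuration: with
    # ``collections = ["HSC/runs/.*"]`` the compiled pattern would accept a
    # run named "HSC/runs/RC2" in ``_check_dataset_run`` below and reject one
    # named "refcats".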

    @classmethod
    def initialize(
        cls,
        db: Database,
        context: StaticTablesContext,
        *,
        universe: DimensionUniverse,
        config: Mapping,
        datasets: Type[DatasetRecordStorageManager],
        dimensions: DimensionRecordStorageManager,
    ) -> ObsCoreTableManager:
        # Docstring inherited from base class.
        config_data = Config(config)
        obscore_config = ObsCoreManagerConfig.parse_obj(config_data)

        # Instantiate all spatial plugins.
        spatial_plugins = SpatialObsCorePlugin.load_plugins(obscore_config.spatial_plugins, db)

        schema = ObsCoreSchema(config=obscore_config, spatial_plugins=spatial_plugins, datasets=datasets)

        # Generate table specification for main obscore table.
        table_spec = schema.table_spec
        for plugin in spatial_plugins:
            plugin.extend_table_spec(table_spec)
        table = context.addTable(obscore_config.table_name, table_spec)

        return ObsCoreLiveTableManager(
            db=db,
            table=table,
            schema=schema,
            universe=universe,
            config=obscore_config,
            dimensions=dimensions,
            spatial_plugins=spatial_plugins,
        )

    def config_json(self) -> str:
        """Dump configuration in JSON format.

        Returns
        -------
        json : `str`
            Configuration serialized in JSON format.
        """
        return json.dumps(self.config.dict())
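
    # Round-trip sketch (assumption: ``ObsCoreManagerConfig`` is a pydantic
    # model, as the ``dict()``/``parse_obj()`` calls suggest); hypothetical
    # usage:
    #
    #     config = ObsCoreManagerConfig.parse_obj(json.loads(manager.config_json()))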

    @classmethod
    def currentVersion(cls) -> Optional[VersionTuple]:
        # Docstring inherited from base class.
        return _VERSION

    def schemaDigest(self) -> Optional[str]:
        # Docstring inherited from base class.
        return None

    def add_datasets(self, refs: Iterable[DatasetRef]) -> int:
        # Docstring inherited from base class.

        # Only makes sense for RUN collection types.
        if self.config.collection_type is not ConfigCollectionType.RUN:
            return 0

        obscore_refs: Iterable[DatasetRef]
        if self.run_patterns:
            # Check each dataset run against the configured run list. We want
            # to reduce the number of calls to _check_dataset_run, which may
            # be expensive. Normally references are grouped by run, so if
            # there are multiple input references they should share the same
            # run. Rather than assuming that, we group them by run again.
            refs_by_run: Dict[str, List[DatasetRef]] = defaultdict(list)
            for ref in refs:
                # The record factory will filter dataset types, but to reduce
                # collection checks we also pre-filter here.
                if ref.datasetType.name not in self.config.dataset_types:
                    continue

                assert ref.run is not None, "Run cannot be None"
                refs_by_run[ref.run].append(ref)

            good_refs: List[DatasetRef] = []
            for run, run_refs in refs_by_run.items():
                if not self._check_dataset_run(run):
                    continue
                good_refs.extend(run_refs)
            obscore_refs = good_refs

        else:
            # Take all refs, no collection check.
            obscore_refs = refs

        return self._populate(obscore_refs)
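
    # Note: the returned count reflects rows actually inserted by
    # ``_populate``; it can be smaller than the number of input refs because
    # the record factory filters dataset types and ``ensure`` skips rows that
    # already exist.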

    def associate(self, refs: Iterable[DatasetRef], collection: CollectionRecord) -> int:
        # Docstring inherited from base class.

        # Only works when collection type is TAGGED.
        if self.tagged_collection is None:
            return 0

        if collection.name == self.tagged_collection:
            return self._populate(refs)
        else:
            return 0

    def disassociate(self, refs: Iterable[DatasetRef], collection: CollectionRecord) -> int:
        # Docstring inherited from base class.

        # Only works when collection type is TAGGED.
        if self.tagged_collection is None:
            return 0

        count = 0
        if collection.name == self.tagged_collection:
            # Sorting may improve performance.
            dataset_ids = sorted(cast(uuid.UUID, ref.id) for ref in refs)
            if dataset_ids:
                fk_field = self.schema.dataset_fk
                assert fk_field is not None, "Cannot be None by construction"
                # There may be too many of them; do it in chunks.
                for ids in chunk_iterable(dataset_ids):
                    where = self.table.columns[fk_field.name].in_(ids)
                    count += self.db.deleteWhere(self.table, where)

        return count
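
    # ``chunk_iterable`` above yields fixed-size batches of IDs, so the
    # generated SQL ``IN`` clause stays a manageable size for the backend.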

    def _populate(self, refs: Iterable[DatasetRef]) -> int:
        """Populate the obscore table with data from the given datasets."""
        records: List[Record] = []
        for ref in refs:
            record = self.record_factory(ref)
            if record is not None:
                records.append(record)

        if records:
            # Ignore potential conflicts with existing datasets.
            return self.db.ensure(self.table, *records, primary_key_only=True)
        else:
            return 0

    def _check_dataset_run(self, run: str) -> bool:
        """Check whether the specified run collection matches known patterns."""

        if not self.run_patterns:
            # Empty list means take anything.
            return True

        # Try each pattern in turn.
        return any(pattern.fullmatch(run) for pattern in self.run_patterns)
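
    # Note: ``re.fullmatch`` anchors the pattern to the entire run name, so a
    # (hypothetical) pattern "HSC/runs/.*" matches "HSC/runs/RC2" while the
    # unanchored fragment "runs" would not.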