Coverage for python/lsst/daf/butler/registry/obscore/_manager.py: 18%

143 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-01-07 10:08 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ["ObsCoreLiveTableManager"] 

25 

26import json 

27import re 

28import uuid 

29from collections import defaultdict 

30from collections.abc import Collection, Mapping 

31from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Type, cast 

32 

33import sqlalchemy 

34from lsst.daf.butler import Config, DataCoordinate, DatasetRef, DimensionRecordColumnTag, DimensionUniverse 

35from lsst.daf.relation import Join 

36from lsst.sphgeom import Region 

37from lsst.utils.iteration import chunk_iterable 

38 

39from ..interfaces import ObsCoreTableManager, VersionTuple 

40from ._config import ConfigCollectionType, ObsCoreManagerConfig 

41from ._records import ExposureRegionFactory, Record, RecordFactory 

42from ._schema import ObsCoreSchema 

43from ._spatial import SpatialObsCorePlugin 

44 

45if TYPE_CHECKING: 45 ↛ 46line 45 didn't jump to line 46, because the condition on line 45 was never true

46 from ..interfaces import ( 

47 CollectionRecord, 

48 Database, 

49 DatasetRecordStorageManager, 

50 DimensionRecordStorageManager, 

51 StaticTablesContext, 

52 ) 

53 from ..queries import SqlQueryContext 

54 

55_VERSION = VersionTuple(0, 0, 1) 

56 

57 

class _ExposureRegionFactory(ExposureRegionFactory):
    """Resolve the spatial region of an exposure via its associated visit."""

    def __init__(self, dimensions: DimensionRecordStorageManager):
        self.dimensions = dimensions
        self.universe = dimensions.universe
        self.exposure_dimensions = self.universe["exposure"].graph
        self.exposure_detector_dimensions = self.universe.extract(["exposure", "detector"])

    def exposure_region(self, dataId: DataCoordinate, context: SqlQueryContext) -> Optional[Region]:
        # Docstring is inherited from a base class.
        # Begin with visit_definition, the element mapping exposures to
        # visits; without its storage there is nothing to query.
        definition_storage = self.dimensions.get(self.universe["visit_definition"])
        if definition_storage is None:
            return None
        relation = definition_storage.join(context.make_initial_relation(), Join(), context)
        # Choose the region source: per-detector regions when the data ID
        # carries a detector, whole-visit regions otherwise.
        use_detector = "detector" in dataId.names
        element_name = "visit_detector_region" if use_detector else "visit"
        region_storage = self.dimensions.get(self.universe[element_name])
        if region_storage is None:
            return None
        relation = region_storage.join(relation, Join(), context)
        if use_detector:
            constraint_data_id = dataId.subset(self.exposure_detector_dimensions)
        else:
            constraint_data_id = dataId.subset(self.exposure_dimensions)
        region_tag = DimensionRecordColumnTag(element_name, "region")
        # Constrain to the given exposure (and detector, when present); an
        # exposure may belong to several visits, so keep just one arbitrary
        # row.
        relation = relation.with_rows_satisfying(
            context.make_data_coordinate_predicate(constraint_data_id, full=False)
        )[:1]
        # Execute the query; yield the region from the first row, if any.
        row = next(iter(context.fetch_iterable(relation)), None)
        return None if row is None else row[region_tag]

103 

104 

class ObsCoreLiveTableManager(ObsCoreTableManager):
    """A manager class for ObsCore table, implements methods for updating the
    records in that table.
    """

    def __init__(
        self,
        *,
        db: Database,
        table: sqlalchemy.schema.Table,
        schema: ObsCoreSchema,
        universe: DimensionUniverse,
        config: ObsCoreManagerConfig,
        dimensions: DimensionRecordStorageManager,
        spatial_plugins: Collection[SpatialObsCorePlugin],
    ):
        self.db = db
        self.table = table
        self.schema = schema
        self.universe = universe
        self.config = config
        self.spatial_plugins = spatial_plugins
        exposure_region_factory = _ExposureRegionFactory(dimensions)
        self.record_factory = RecordFactory(
            config, schema, universe, spatial_plugins, exposure_region_factory
        )
        # Exactly one of the two attributes below is populated, depending on
        # the configured collection type.
        self.tagged_collection: Optional[str] = None
        self.run_patterns: list[re.Pattern] = []
        if config.collection_type is ConfigCollectionType.TAGGED:
            assert (
                config.collections is not None and len(config.collections) == 1
            ), "Exactly one collection name required for tagged type."
            self.tagged_collection = config.collections[0]
        elif config.collection_type is ConfigCollectionType.RUN:
            if config.collections:
                for coll in config.collections:
                    try:
                        self.run_patterns.append(re.compile(coll))
                    except re.error as exc:
                        raise ValueError(f"Failed to compile regex: {coll!r}") from exc
        else:
            raise ValueError(f"Unexpected value of collection_type: {config.collection_type}")

    @classmethod
    def initialize(
        cls,
        db: Database,
        context: StaticTablesContext,
        *,
        universe: DimensionUniverse,
        config: Mapping,
        datasets: Type[DatasetRecordStorageManager],
        dimensions: DimensionRecordStorageManager,
    ) -> ObsCoreTableManager:
        # Docstring inherited from base class.
        config_data = Config(config)
        obscore_config = ObsCoreManagerConfig.parse_obj(config_data)

        # Instantiate all spatial plugins.
        spatial_plugins = SpatialObsCorePlugin.load_plugins(obscore_config.spatial_plugins, db)

        schema = ObsCoreSchema(config=obscore_config, spatial_plugins=spatial_plugins, datasets=datasets)

        # Generate table specification for main obscore table.
        table_spec = schema.table_spec
        for plugin in spatial_plugins:
            plugin.extend_table_spec(table_spec)
        # Fix: register the spec that the plugins extended, not the attribute
        # re-read from schema (they aliased the same object only by accident).
        table = context.addTable(obscore_config.table_name, table_spec)

        return ObsCoreLiveTableManager(
            db=db,
            table=table,
            schema=schema,
            universe=universe,
            config=obscore_config,
            dimensions=dimensions,
            spatial_plugins=spatial_plugins,
        )

    def config_json(self) -> str:
        """Dump configuration in JSON format.

        Returns
        -------
        json : `str`
            Configuration serialized in JSON format.
        """
        return json.dumps(self.config.dict())

    @classmethod
    def currentVersion(cls) -> Optional[VersionTuple]:
        # Docstring inherited from base class.
        return _VERSION

    def schemaDigest(self) -> Optional[str]:
        # Docstring inherited from base class.
        return None

    def add_datasets(self, refs: Iterable[DatasetRef], context: SqlQueryContext) -> int:
        # Docstring inherited from base class.

        # Only makes sense for RUN collection types
        if self.config.collection_type is not ConfigCollectionType.RUN:
            return 0

        obscore_refs: Iterable[DatasetRef]
        if self.run_patterns:
            # Check each dataset run against configured run list. We want to
            # reduce number of calls to _check_dataset_run, which may be
            # expensive. Normally references are grouped by run, if there are
            # multiple input references, they should have the same run.
            # Instead of just checking that, we group them by run again.
            refs_by_run: Dict[str, List[DatasetRef]] = defaultdict(list)
            for ref in refs:
                # Record factory will filter dataset types, but to reduce
                # collection checks we also pre-filter it here.
                if ref.datasetType.name not in self.config.dataset_types:
                    continue

                assert ref.run is not None, "Run cannot be None"
                refs_by_run[ref.run].append(ref)

            good_refs: List[DatasetRef] = []
            for run, run_refs in refs_by_run.items():
                if not self._check_dataset_run(run):
                    continue
                good_refs.extend(run_refs)
            obscore_refs = good_refs

        else:
            # Take all refs, no collection check.
            obscore_refs = refs

        return self._populate(obscore_refs, context)

    def associate(
        self, refs: Iterable[DatasetRef], collection: CollectionRecord, context: SqlQueryContext
    ) -> int:
        # Docstring inherited from base class.

        # Only works when collection type is TAGGED
        if self.tagged_collection is None:
            return 0

        if collection.name == self.tagged_collection:
            return self._populate(refs, context)
        else:
            return 0

    def disassociate(self, refs: Iterable[DatasetRef], collection: CollectionRecord) -> int:
        # Docstring inherited from base class.

        # Only works when collection type is TAGGED
        if self.tagged_collection is None:
            return 0

        count = 0
        if collection.name == self.tagged_collection:
            # Sorting may improve performance
            dataset_ids = sorted(cast(uuid.UUID, ref.id) for ref in refs)
            if dataset_ids:
                fk_field = self.schema.dataset_fk
                assert fk_field is not None, "Cannot be None by construction"
                # There may be too many of them, do it in chunks.
                for ids in chunk_iterable(dataset_ids):
                    where = self.table.columns[fk_field.name].in_(ids)
                    count += self.db.deleteWhere(self.table, where)
        return count

    def _populate(self, refs: Iterable[DatasetRef], context: SqlQueryContext) -> int:
        """Populate obscore table with the data from given datasets."""
        # Factory returns None for refs it does not handle; keep the rest.
        records: List[Record] = [
            record for ref in refs if (record := self.record_factory(ref, context)) is not None
        ]

        if records:
            # Ignore potential conflicts with existing datasets.
            return self.db.ensure(self.table, *records, primary_key_only=True)
        else:
            return 0

    def _check_dataset_run(self, run: str) -> bool:
        """Check that specified run collection matches known patterns."""
        if not self.run_patterns:
            # Empty list means take anything.
            return True

        # Try each pattern in turn.
        return any(pattern.fullmatch(run) for pattern in self.run_patterns)