Coverage for python/lsst/daf/butler/registry/queries/_sql_query_backend.py: 16% (103 statements)

coverage.py v7.2.7, created at 2023-08-12 09:20 +0000

# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

__all__ = ("SqlQueryBackend",)

from collections.abc import Iterable, Mapping, Sequence, Set
from typing import TYPE_CHECKING, Any, cast

from lsst.daf.relation import ColumnError, ColumnExpression, ColumnTag, Join, Predicate, Relation

from ...core import (
    ColumnCategorization,
    DataCoordinate,
    DatasetType,
    DimensionGraph,
    DimensionKeyColumnTag,
    DimensionRecord,
    DimensionRecordColumnTag,
    DimensionUniverse,
    SkyPixDimension,
)
from .._collectionType import CollectionType
from .._exceptions import DataIdValueError
from ..interfaces import CollectionRecord, Database
from ._query_backend import QueryBackend
from ._sql_query_context import SqlQueryContext

if TYPE_CHECKING:
    from ..managers import RegistryManagerInstances

class SqlQueryBackend(QueryBackend[SqlQueryContext]):
    """An implementation of `QueryBackend` for `SqlRegistry`.

    Parameters
    ----------
    db : `Database`
        Object that abstracts the database engine.
    managers : `RegistryManagerInstances`
        Struct containing the manager objects that back a `SqlRegistry`.
    """

    def __init__(
        self,
        db: Database,
        managers: RegistryManagerInstances,
    ):
        self._db = db
        self._managers = managers

    @property
    def universe(self) -> DimensionUniverse:
        # Docstring inherited.
        return self._managers.dimensions.universe

    def context(self) -> SqlQueryContext:
        # Docstring inherited.
        return SqlQueryContext(self._db, self._managers.column_types)

    def get_collection_name(self, key: Any) -> str:
        return self._managers.collections[key].name

    def resolve_collection_wildcard(
        self,
        expression: Any,
        *,
        collection_types: Set[CollectionType] = CollectionType.all(),
        done: set[str] | None = None,
        flatten_chains: bool = True,
        include_chains: bool | None = None,
    ) -> list[CollectionRecord]:
        # Docstring inherited.
        return self._managers.collections.resolve_wildcard(
            expression,
            collection_types=collection_types,
            done=done,
            flatten_chains=flatten_chains,
            include_chains=include_chains,
        )

    def resolve_dataset_type_wildcard(
        self,
        expression: Any,
        components: bool | None = None,
        missing: list[str] | None = None,
        explicit_only: bool = False,
        components_deprecated: bool = True,
    ) -> dict[DatasetType, list[str | None]]:
        # Docstring inherited.
        return self._managers.datasets.resolve_wildcard(
            expression, components, missing, explicit_only, components_deprecated
        )

    def filter_dataset_collections(
        self,
        dataset_types: Iterable[DatasetType],
        collections: Sequence[CollectionRecord],
        *,
        governor_constraints: Mapping[str, Set[str]],
        rejections: list[str] | None = None,
    ) -> dict[DatasetType, list[CollectionRecord]]:
        # Docstring inherited.
        result: dict[DatasetType, list[CollectionRecord]] = {
            dataset_type: [] for dataset_type in dataset_types
        }
        for dataset_type, filtered_collections in result.items():
            for collection_record in collections:
                if not dataset_type.isCalibration() and collection_record.type is CollectionType.CALIBRATION:
                    if rejections is not None:
                        rejections.append(
                            f"Not searching for non-calibration dataset of type {dataset_type.name!r} "
                            f"in CALIBRATION collection {collection_record.name!r}."
                        )
                else:
                    collection_summary = self._managers.datasets.getCollectionSummary(collection_record)
                    if collection_summary.is_compatible_with(
                        dataset_type,
                        governor_constraints,
                        rejections=rejections,
                        name=collection_record.name,
                    ):
                        filtered_collections.append(collection_record)
        return result

    def make_dataset_query_relation(
        self,
        dataset_type: DatasetType,
        collections: Sequence[CollectionRecord],
        columns: Set[str],
        context: SqlQueryContext,
    ) -> Relation:
        # Docstring inherited.
        assert len(collections) > 0, (
            "Caller is responsible for handling the case of all collections being rejected (we can't "
            "write a good error message without knowing why collections were rejected)."
        )
        dataset_storage = self._managers.datasets.find(dataset_type.name)
        if dataset_storage is None:
            # Unrecognized dataset type means no results.
            return self.make_doomed_dataset_relation(
                dataset_type,
                columns,
                messages=[
                    f"Dataset type {dataset_type.name!r} is not registered, "
                    "so no instances of it can exist in any collection."
                ],
                context=context,
            )
        else:
            return dataset_storage.make_relation(
                *collections,
                columns=columns,
                context=context,
            )

    def make_dimension_relation(
        self,
        dimensions: DimensionGraph,
        columns: Set[ColumnTag],
        context: SqlQueryContext,
        *,
        initial_relation: Relation | None = None,
        initial_join_max_columns: frozenset[ColumnTag] | None = None,
        initial_dimension_relationships: Set[frozenset[str]] | None = None,
        spatial_joins: Iterable[tuple[str, str]] = (),
        governor_constraints: Mapping[str, Set[str]],
    ) -> Relation:
        # Docstring inherited.

        default_join = Join(max_columns=initial_join_max_columns)

        # Set up the relation variable we'll update as we join more relations
        # in, and ensure it is in the SQL engine.
        relation = context.make_initial_relation(initial_relation)

        if initial_dimension_relationships is None:
            relationships = self.extract_dimension_relationships(relation)
        else:
            relationships = set(initial_dimension_relationships)

        # Make a mutable copy of the columns argument.
        columns_required = set(columns)

        # Sort spatial joins to put those involving the commonSkyPix dimension
        # first, since those join subqueries might get reused in implementing
        # other joins later.
        spatial_joins = list(spatial_joins)
        spatial_joins.sort(key=lambda j: self.universe.commonSkyPix.name not in j)

        # Next we'll handle spatial joins, since those can require refinement
        # predicates that will need region columns to be included in the
        # relations we'll join.
        predicate: Predicate = Predicate.literal(True)
        for element1, element2 in spatial_joins:
            (overlaps, needs_refinement) = self._managers.dimensions.make_spatial_join_relation(
                element1,
                element2,
                context=context,
                governor_constraints=governor_constraints,
                existing_relationships=relationships,
            )
            if needs_refinement:
                predicate = predicate.logical_and(
                    context.make_spatial_region_overlap_predicate(
                        ColumnExpression.reference(DimensionRecordColumnTag(element1, "region")),
                        ColumnExpression.reference(DimensionRecordColumnTag(element2, "region")),
                    )
                )
                columns_required.add(DimensionRecordColumnTag(element1, "region"))
                columns_required.add(DimensionRecordColumnTag(element2, "region"))
            relation = relation.join(overlaps)
            relationships.add(
                frozenset(self.universe[element1].dimensions.names | self.universe[element2].dimensions.names)
            )

        # All skypix columns need to come from either the initial_relation or a
        # spatial join, since we need all dimension key columns present in the
        # SQL engine and skypix regions are added by postprocessing in the
        # native iteration engine.
        for dimension in dimensions:
            if DimensionKeyColumnTag(dimension.name) not in relation.columns and isinstance(
                dimension, SkyPixDimension
            ):
                raise NotImplementedError(
                    f"Cannot construct query involving skypix dimension {dimension.name} unless "
                    "it is part of a dataset subquery, spatial join, or other initial relation."
                )

        # Categorize columns not yet included in the relation to associate them
        # with dimension elements and detect bad inputs.
        missing_columns = ColumnCategorization.from_iterable(columns_required - relation.columns)
        if not (missing_columns.dimension_keys <= dimensions.names):
            raise ColumnError(
                "Cannot add dimension key column(s) "
                f"{{{', '.join(name for name in missing_columns.dimension_keys)}}} "
                f"that were not included in the given dimensions {dimensions}."
            )
        if missing_columns.datasets:
            raise ColumnError(
                f"Unexpected dataset columns {missing_columns.datasets} in call to make_dimension_relation; "
                "use make_dataset_query_relation or make_dataset_search_relation instead, or filter them "
                "out if they have already been added or will be added later."
            )
        for element_name in missing_columns.dimension_records:
            if element_name not in dimensions.elements.names:
                raise ColumnError(
                    f"Cannot join dimension element {element_name} whose dimensions are not a "
                    f"subset of {dimensions}."
                )

        # Iterate over all dimension elements whose relations definitely have
        # to be joined in. The order doesn't matter as long as we can assume
        # the database query optimizer is going to try to reorder them anyway.
        for element in dimensions.elements:
            columns_still_needed = missing_columns.dimension_records[element.name]
            # Two separate conditions in play here:
            # - if we need a record column (not just key columns) from this
            #   element, we have to join in its relation;
            # - if the element establishes a relationship between key columns
            #   that wasn't already established by the initial relation, we
            #   always join that element's relation. Any element with
            #   implied dependencies or the alwaysJoin flag establishes such a
            #   relationship.
            if columns_still_needed or (
                (element.alwaysJoin or element.implied)
                and frozenset(element.dimensions.names) not in relationships
            ):
                storage = self._managers.dimensions[element]
                relation = storage.join(relation, default_join, context)
        # At this point we've joined in all of the element relations that
        # definitely need to be included, but we may not have all of the
        # dimension key columns in the query that we want. To fill out that
        # set, we iterate over just the given DimensionGraph's dimensions (not
        # all dimension *elements*) in reverse topological order. That order
        # should reduce the total number of tables we bring in, since each
        # dimension will bring in keys for its required dependencies before we
        # get to those required dependencies.
        for dimension in self.universe.sorted(dimensions, reverse=True):
            if DimensionKeyColumnTag(dimension.name) not in relation.columns:
                storage = self._managers.dimensions[dimension]
                relation = storage.join(relation, default_join, context)

        # Add the predicates we constructed earlier, with a transfer to native
        # iteration first if necessary.
        if not predicate.as_trivial():
            relation = relation.with_rows_satisfying(
                predicate, preferred_engine=context.iteration_engine, transfer=True
            )

        # Finally project the new relation down to just the columns in the
        # initial relation, the dimension key columns, and the new columns
        # requested.
        columns_kept = set(columns)
        if initial_relation is not None:
            columns_kept.update(initial_relation.columns)
        columns_kept.update(DimensionKeyColumnTag.generate(dimensions.names))
        relation = relation.with_only_columns(columns_kept, preferred_engine=context.preferred_engine)

        return relation

    def resolve_governor_constraints(
        self, dimensions: DimensionGraph, constraints: Mapping[str, Set[str]], context: SqlQueryContext
    ) -> Mapping[str, Set[str]]:
        # Docstring inherited.
        result: dict[str, Set[str]] = {}
        for dimension in dimensions.governors:
            storage = self._managers.dimensions[dimension]
            records = storage.get_record_cache(context)
            assert records is not None, "Governor dimensions are always cached."
            all_values = {cast(str, data_id[dimension.name]) for data_id in records}
            if (constraint_values := constraints.get(dimension.name)) is not None:
                if not (constraint_values <= all_values):
                    raise DataIdValueError(
                        f"Unknown values specified for governor dimension {dimension.name}: "
                        f"{constraint_values - all_values}."
                    )
                result[dimension.name] = constraint_values
            else:
                result[dimension.name] = all_values
        return result

    def get_dimension_record_cache(
        self,
        element_name: str,
        context: SqlQueryContext,
    ) -> Mapping[DataCoordinate, DimensionRecord] | None:
        return self._managers.dimensions[element_name].get_record_cache(context)
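
The listing above is the complete module. The snippet below is a rough, illustrative sketch only and is not part of _sql_query_backend.py: it shows how a caller that already holds a `Database` and a `RegistryManagerInstances` might drive this backend to build a dataset query relation, using only the method signatures shown in the listing. The variables `db` and `managers`, the collection name "HSC/defaults", the dataset type name "raw", and the column names "dataset_id" and "run" are placeholders and assumptions, not values taken from the source.

# Hypothetical usage sketch, not part of the module above. Assumes `db`
# (a Database) and `managers` (a RegistryManagerInstances) already exist.
backend = SqlQueryBackend(db, managers)
context = backend.context()

# Resolve user-facing expressions into collection records and dataset types.
collections = backend.resolve_collection_wildcard("HSC/defaults")  # placeholder collection
dataset_types = backend.resolve_dataset_type_wildcard("raw")  # placeholder dataset type

# Drop collections that cannot hold each dataset type, then build a relation
# for each surviving (dataset type, collections) pair.
rejections: list[str] = []
filtered = backend.filter_dataset_collections(
    dataset_types.keys(),
    collections,
    governor_constraints={},
    rejections=rejections,
)
for dataset_type, records in filtered.items():
    if records:  # make_dataset_query_relation requires a non-empty sequence
        relation = backend.make_dataset_query_relation(
            dataset_type, records, {"dataset_id", "run"}, context
        )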