Coverage for python/lsst/daf/butler/registry/queries/_sql_query_backend.py: 19%

110 statements  

coverage.py v7.4.3, created at 2024-03-07 11:04 +0000

# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

__all__ = ("SqlQueryBackend",)

from collections.abc import Iterable, Mapping, Sequence, Set
from contextlib import AbstractContextManager
from typing import TYPE_CHECKING, Any, cast

from lsst.daf.relation import ColumnError, ColumnExpression, ColumnTag, Join, Predicate, Relation

from ..._column_categorization import ColumnCategorization
from ..._column_tags import DimensionKeyColumnTag, DimensionRecordColumnTag
from ..._dataset_type import DatasetType
from ...dimensions import DimensionGroup, DimensionRecordSet, DimensionUniverse
from ...dimensions.record_cache import DimensionRecordCache
from .._collection_type import CollectionType
from .._exceptions import DataIdValueError
from ..interfaces import CollectionRecord, Database
from ._query_backend import QueryBackend
from ._sql_query_context import SqlQueryContext

if TYPE_CHECKING:
    from ..managers import RegistryManagerInstances

class SqlQueryBackend(QueryBackend[SqlQueryContext]):
    """An implementation of `QueryBackend` for `SqlRegistry`.

    Parameters
    ----------
    db : `Database`
        Object that abstracts the database engine.
    managers : `RegistryManagerInstances`
        Struct containing the manager objects that back a `SqlRegistry`.
    dimension_record_cache : `DimensionRecordCache`
        Cache of all records for dimension elements with
        `~DimensionElement.is_cached` `True`.
    """

    def __init__(
        self, db: Database, managers: RegistryManagerInstances, dimension_record_cache: DimensionRecordCache
    ):
        self._db = db
        self._managers = managers
        self._dimension_record_cache = dimension_record_cache

    @property
    def universe(self) -> DimensionUniverse:
        # Docstring inherited.
        return self._managers.dimensions.universe

    def caching_context(self) -> AbstractContextManager[None]:
        # Docstring inherited.
        return self._managers.caching_context_manager()

    def context(self) -> SqlQueryContext:
        # Docstring inherited.
        return SqlQueryContext(self._db, self._managers.column_types)

    def get_collection_name(self, key: Any) -> str:
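        # Look up a collection's name from its primary-key value via the
        # (cached) collection-record manager.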

        assert (
            self._managers.caching_context.is_enabled
        ), "Collection-record caching should already have been enabled any time this is called."
        return self._managers.collections[key].name

    def resolve_collection_wildcard(
        self,
        expression: Any,
        *,
        collection_types: Set[CollectionType] = CollectionType.all(),
        done: set[str] | None = None,
        flatten_chains: bool = True,
        include_chains: bool | None = None,
    ) -> list[CollectionRecord]:
        # Docstring inherited.
        return self._managers.collections.resolve_wildcard(
            expression,
            collection_types=collection_types,
            done=done,
            flatten_chains=flatten_chains,
            include_chains=include_chains,
        )

    def resolve_dataset_type_wildcard(
        self,
        expression: Any,
        missing: list[str] | None = None,
        explicit_only: bool = False,
    ) -> list[DatasetType]:
        # Docstring inherited.
        return self._managers.datasets.resolve_wildcard(
            expression,
            missing,
            explicit_only,
        )

    def filter_dataset_collections(
        self,
        dataset_types: Iterable[DatasetType],
        collections: Sequence[CollectionRecord],
        *,
        governor_constraints: Mapping[str, Set[str]],
        rejections: list[str] | None = None,
    ) -> dict[DatasetType, list[CollectionRecord]]:
        # Docstring inherited.
        result: dict[DatasetType, list[CollectionRecord]] = {
            dataset_type: [] for dataset_type in dataset_types
        }
        summaries = self._managers.datasets.fetch_summaries(collections, result.keys())
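        # Each summary records which dataset types and governor data ID values
        # its collection can contain, letting the loop below reject
        # incompatible collections without querying the datasets themselves.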

        for dataset_type, filtered_collections in result.items():
            for collection_record in collections:
                if not dataset_type.isCalibration() and collection_record.type is CollectionType.CALIBRATION:
                    if rejections is not None:
                        rejections.append(
                            f"Not searching for non-calibration dataset of type {dataset_type.name!r} "
                            f"in CALIBRATION collection {collection_record.name!r}."
                        )
                else:
                    collection_summary = summaries[collection_record.key]
                    if collection_summary.is_compatible_with(
                        dataset_type,
                        governor_constraints,
                        rejections=rejections,
                        name=collection_record.name,
                    ):
                        filtered_collections.append(collection_record)
        return result

    def _make_dataset_query_relation_impl(
        self,
        dataset_type: DatasetType,
        collections: Sequence[CollectionRecord],
        columns: Set[str],
        context: SqlQueryContext,
    ) -> Relation:
        # Docstring inherited.
        assert len(collections) > 0, (
            "Caller is responsible for handling the case of all collections being rejected (we can't "
            "write a good error message without knowing why collections were rejected)."
        )
        dataset_storage = self._managers.datasets.find(dataset_type.name)
        if dataset_storage is None:
            # Unrecognized dataset type means no results.
            return self.make_doomed_dataset_relation(
                dataset_type,
                columns,
                messages=[
                    f"Dataset type {dataset_type.name!r} is not registered, "
                    "so no instances of it can exist in any collection."
                ],
                context=context,
            )
        else:
            return dataset_storage.make_relation(
                *collections,
                columns=columns,
                context=context,
            )

    def make_dimension_relation(
        self,
        dimensions: DimensionGroup,
        columns: Set[ColumnTag],
        context: SqlQueryContext,
        *,
        initial_relation: Relation | None = None,
        initial_join_max_columns: frozenset[ColumnTag] | None = None,
        initial_dimension_relationships: Set[frozenset[str]] | None = None,
        spatial_joins: Iterable[tuple[str, str]] = (),
        governor_constraints: Mapping[str, Set[str]],
    ) -> Relation:
        # Docstring inherited.

        default_join = Join(max_columns=initial_join_max_columns)
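        # This join operation is reused for every dimension-element join made
        # below.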

        # Set up the relation variable we'll update as we join more relations
        # in, and ensure it is in the SQL engine.
        relation = context.make_initial_relation(initial_relation)

        if initial_dimension_relationships is None:
            relationships = self.extract_dimension_relationships(relation)
        else:
            relationships = set(initial_dimension_relationships)

        # Make a mutable copy of the columns argument.
        columns_required = set(columns)

        # Sort spatial joins to put those involving the commonSkyPix dimension
        # first, since those join subqueries might get reused in implementing
        # other joins later.
        spatial_joins = list(spatial_joins)
        spatial_joins.sort(key=lambda j: self.universe.commonSkyPix.name not in j)

        # Next we'll handle spatial joins, since those can require refinement
        # predicates that will need region columns to be included in the
        # relations we'll join.
        predicate: Predicate = Predicate.literal(True)
        for element1, element2 in spatial_joins:
            (overlaps, needs_refinement) = self._managers.dimensions.make_spatial_join_relation(
                element1,
                element2,
                context=context,
                existing_relationships=relationships,
            )
            if needs_refinement:
                predicate = predicate.logical_and(
                    context.make_spatial_region_overlap_predicate(
                        ColumnExpression.reference(DimensionRecordColumnTag(element1, "region")),
                        ColumnExpression.reference(DimensionRecordColumnTag(element2, "region")),
                    )
                )
                columns_required.add(DimensionRecordColumnTag(element1, "region"))
                columns_required.add(DimensionRecordColumnTag(element2, "region"))
            relation = relation.join(overlaps)
            relationships.add(
                frozenset(self.universe[element1].dimensions.names | self.universe[element2].dimensions.names)
            )

        # All skypix columns need to come from either the initial_relation or a
        # spatial join, since we need all dimension key columns present in the
        # SQL engine and skypix regions are added by postprocessing in the
        # native iteration engine.
        for skypix_dimension_name in dimensions.skypix:
            if DimensionKeyColumnTag(skypix_dimension_name) not in relation.columns:
                raise NotImplementedError(
                    f"Cannot construct query involving skypix dimension {skypix_dimension_name} unless "
                    "it is part of a dataset subquery, spatial join, or other initial relation."
                )

        # Before joining in new tables to provide columns, attempt to restore
        # them from the given relation by weakening projections applied to it.
        relation, _ = context.restore_columns(relation, columns_required)

        # Categorize columns not yet included in the relation to associate them
        # with dimension elements and detect bad inputs.
        missing_columns = ColumnCategorization.from_iterable(columns_required - relation.columns)
        if not (missing_columns.dimension_keys <= dimensions.names):
            raise ColumnError(
                "Cannot add dimension key column(s) "
                f"{{{', '.join(name for name in missing_columns.dimension_keys)}}} "
                f"that were not included in the given dimensions {dimensions}."
            )
        if missing_columns.datasets:
            raise ColumnError(
                f"Unexpected dataset columns {missing_columns.datasets} in call to make_dimension_relation; "
                "use make_dataset_query_relation or make_dataset_search_relation instead, or filter them "
                "out if they have already been added or will be added later."
            )
        for element_name in missing_columns.dimension_records:
            if element_name not in dimensions.elements.names:
                raise ColumnError(
                    f"Cannot join dimension element {element_name} whose dimensions are not a "
                    f"subset of {dimensions}."
                )

        # Iterate over all dimension elements whose relations definitely have
        # to be joined in. The order doesn't matter as long as we can assume
        # the database query optimizer is going to try to reorder them anyway.
        for element_name in dimensions.elements:
            columns_still_needed = missing_columns.dimension_records[element_name]
            element = self.universe[element_name]
            # Two separate conditions in play here:
            # - if we need a record column (not just key columns) from this
            #   element, we have to join in its relation;
            # - if the element establishes a relationship between key columns
            #   that wasn't already established by the initial relation, we
            #   always join that element's relation. Any element with
            #   implied dependencies or the alwaysJoin flag establishes such a
            #   relationship.
            if columns_still_needed or (
                element.defines_relationships and frozenset(element.dimensions.names) not in relationships
            ):
                relation = self._managers.dimensions.join(element_name, relation, default_join, context)
        # At this point we've joined in all of the element relations that
        # definitely need to be included, but we may not have all of the
        # dimension key columns in the query that we want. To fill out that
        # set, we iterate over just the given DimensionGroup's dimensions (not
        # all dimension *elements*) in reverse topological order. That order
        # should reduce the total number of tables we bring in, since each
        # dimension will bring in keys for its required dependencies before we
        # get to those required dependencies.
        for dimension_name in reversed(dimensions.names.as_tuple()):
            if DimensionKeyColumnTag(dimension_name) not in relation.columns:
                relation = self._managers.dimensions.join(dimension_name, relation, default_join, context)

        # Add the predicates we constructed earlier, with a transfer to native
        # iteration first if necessary.
        if not predicate.as_trivial():
            relation = relation.with_rows_satisfying(
                predicate, preferred_engine=context.iteration_engine, transfer=True
            )

        # Finally project the new relation down to just the columns in the
        # initial relation, the dimension key columns, and the new columns
        # requested.
        columns_kept = set(columns)
        if initial_relation is not None:
            columns_kept.update(initial_relation.columns)
        columns_kept.update(DimensionKeyColumnTag.generate(dimensions.names))
        relation = relation.with_only_columns(columns_kept, preferred_engine=context.preferred_engine)

        return relation

    def resolve_governor_constraints(
        self, dimensions: DimensionGroup, constraints: Mapping[str, Set[str]]
    ) -> Mapping[str, Set[str]]:
        # Docstring inherited.
        result: dict[str, Set[str]] = {}
        for dimension_name in dimensions.governors:
            all_values = {
                cast(str, record.dataId[dimension_name])
                for record in self._dimension_record_cache[dimension_name]
            }
            if (constraint_values := constraints.get(dimension_name)) is not None:
                if not (constraint_values <= all_values):
                    raise DataIdValueError(
                        f"Unknown values specified for governor dimension {dimension_name}: "
                        f"{constraint_values - all_values}."
                    )
                result[dimension_name] = constraint_values
            else:
                result[dimension_name] = all_values
        return result

    def get_dimension_record_cache(self, element_name: str) -> DimensionRecordSet | None:
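        # Return the cached record set for this element if there is one; None
        # tells the caller it must query the database for records instead.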

        return (
            self._dimension_record_cache[element_name]
            if element_name in self._dimension_record_cache
            else None
        )
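
A minimal usage sketch (not part of the module), assuming `db`, `managers`, and
`dimension_record_cache` are an already-initialized `Database`,
`RegistryManagerInstances`, and `DimensionRecordCache` as a `SqlRegistry` would
normally supply, and that the collection name is hypothetical:

backend = SqlQueryBackend(db, managers, dimension_record_cache)

# A plain collection name resolves to its record(s); CHAINED collections are
# flattened into their children by default.
records = backend.resolve_collection_wildcard("HSC/defaults")

# Per-query state and execution machinery live in a SqlQueryContext created
# by the backend.
context = backend.context()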