Coverage for python/lsst/daf/butler/registry/queries/_sql_query_backend.py: 19%
110 statements
coverage.py v7.4.1, created at 2024-02-01 11:20 +0000
# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

__all__ = ("SqlQueryBackend",)

from collections.abc import Iterable, Mapping, Sequence, Set
from contextlib import AbstractContextManager
from typing import TYPE_CHECKING, Any, cast

from lsst.daf.relation import ColumnError, ColumnExpression, ColumnTag, Join, Predicate, Relation

from ..._column_categorization import ColumnCategorization
from ..._column_tags import DimensionKeyColumnTag, DimensionRecordColumnTag
from ..._dataset_type import DatasetType
from ...dimensions import DimensionGroup, DimensionRecordSet, DimensionUniverse
from ...dimensions.record_cache import DimensionRecordCache
from .._collection_type import CollectionType
from .._exceptions import DataIdValueError
from ..interfaces import CollectionRecord, Database
from ._query_backend import QueryBackend
from ._sql_query_context import SqlQueryContext

if TYPE_CHECKING:
    from ..managers import RegistryManagerInstances


class SqlQueryBackend(QueryBackend[SqlQueryContext]):
    """An implementation of `QueryBackend` for `SqlRegistry`.

    Parameters
    ----------
    db : `Database`
        Object that abstracts the database engine.
    managers : `RegistryManagerInstances`
        Struct containing the manager objects that back a `SqlRegistry`.
    dimension_record_cache : `DimensionRecordCache`
        Cache of all records for dimension elements with
        `~DimensionElement.is_cached` `True`.
    """

    def __init__(
        self, db: Database, managers: RegistryManagerInstances, dimension_record_cache: DimensionRecordCache
    ):
        self._db = db
        self._managers = managers
        self._dimension_record_cache = dimension_record_cache

    @property
    def universe(self) -> DimensionUniverse:
        # Docstring inherited.
        return self._managers.dimensions.universe

    def caching_context(self) -> AbstractContextManager[None]:
        # Docstring inherited.
        return self._managers.caching_context_manager()

    def context(self) -> SqlQueryContext:
        # Docstring inherited.
        return SqlQueryContext(self._db, self._managers.column_types)

    def get_collection_name(self, key: Any) -> str:
        assert (
            self._managers.caching_context.is_enabled
        ), "Collection-record caching should already have been enabled any time this is called."
        return self._managers.collections[key].name

    def resolve_collection_wildcard(
        self,
        expression: Any,
        *,
        collection_types: Set[CollectionType] = CollectionType.all(),
        done: set[str] | None = None,
        flatten_chains: bool = True,
        include_chains: bool | None = None,
    ) -> list[CollectionRecord]:
        # Docstring inherited.
        return self._managers.collections.resolve_wildcard(
            expression,
            collection_types=collection_types,
            done=done,
            flatten_chains=flatten_chains,
            include_chains=include_chains,
        )

    def resolve_dataset_type_wildcard(
        self,
        expression: Any,
        missing: list[str] | None = None,
        explicit_only: bool = False,
    ) -> list[DatasetType]:
        # Docstring inherited.
        return self._managers.datasets.resolve_wildcard(
            expression,
            missing,
            explicit_only,
        )

    def filter_dataset_collections(
        self,
        dataset_types: Iterable[DatasetType],
        collections: Sequence[CollectionRecord],
        *,
        governor_constraints: Mapping[str, Set[str]],
        rejections: list[str] | None = None,
    ) -> dict[DatasetType, list[CollectionRecord]]:
        # Docstring inherited.
        result: dict[DatasetType, list[CollectionRecord]] = {
            dataset_type: [] for dataset_type in dataset_types
        }
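        # Each summary records which dataset types and governor dimension
        # values a collection can contain, so incompatible collection /
        # dataset-type pairs can be rejected without scanning their contents.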
        summaries = self._managers.datasets.fetch_summaries(collections, result.keys())
        for dataset_type, filtered_collections in result.items():
            for collection_record in collections:
                if not dataset_type.isCalibration() and collection_record.type is CollectionType.CALIBRATION:
                    if rejections is not None:
                        rejections.append(
                            f"Not searching for non-calibration dataset of type {dataset_type.name!r} "
                            f"in CALIBRATION collection {collection_record.name!r}."
                        )
                else:
                    collection_summary = summaries[collection_record.key]
                    if collection_summary.is_compatible_with(
                        dataset_type,
                        governor_constraints,
                        rejections=rejections,
                        name=collection_record.name,
                    ):
                        filtered_collections.append(collection_record)
        return result

    def _make_dataset_query_relation_impl(
        self,
        dataset_type: DatasetType,
        collections: Sequence[CollectionRecord],
        columns: Set[str],
        context: SqlQueryContext,
    ) -> Relation:
        # Docstring inherited.
        assert len(collections) > 0, (
            "Caller is responsible for handling the case of all collections being rejected (we can't "
            "write a good error message without knowing why collections were rejected)."
        )
        dataset_storage = self._managers.datasets.find(dataset_type.name)
        if dataset_storage is None:
            # Unrecognized dataset type means no results.
            return self.make_doomed_dataset_relation(
                dataset_type,
                columns,
                messages=[
                    f"Dataset type {dataset_type.name!r} is not registered, "
                    "so no instances of it can exist in any collection."
                ],
                context=context,
            )
        else:
            return dataset_storage.make_relation(
                *collections,
                columns=columns,
                context=context,
            )

    def make_dimension_relation(
        self,
        dimensions: DimensionGroup,
        columns: Set[ColumnTag],
        context: SqlQueryContext,
        *,
        initial_relation: Relation | None = None,
        initial_join_max_columns: frozenset[ColumnTag] | None = None,
        initial_dimension_relationships: Set[frozenset[str]] | None = None,
        spatial_joins: Iterable[tuple[str, str]] = (),
        governor_constraints: Mapping[str, Set[str]],
    ) -> Relation:
        # Docstring inherited.

        default_join = Join(max_columns=initial_join_max_columns)

        # Set up the relation variable we'll update as we join more relations
        # in, and ensure it is in the SQL engine.
        relation = context.make_initial_relation(initial_relation)

        if initial_dimension_relationships is None:
            relationships = self.extract_dimension_relationships(relation)
        else:
            relationships = set(initial_dimension_relationships)

        # Make a mutable copy of the columns argument.
        columns_required = set(columns)

        # Sort spatial joins to put those involving the commonSkyPix dimension
        # first, since those join subqueries might get reused in implementing
        # other joins later.
        spatial_joins = list(spatial_joins)
        spatial_joins.sort(key=lambda j: self.universe.commonSkyPix.name not in j)

        # Next we'll handle spatial joins, since those can require refinement
        # predicates that will need region columns to be included in the
        # relations we'll join.
        predicate: Predicate = Predicate.literal(True)
        for element1, element2 in spatial_joins:
            (overlaps, needs_refinement) = self._managers.dimensions.make_spatial_join_relation(
                element1,
                element2,
                context=context,
                existing_relationships=relationships,
            )
            if needs_refinement:
                predicate = predicate.logical_and(
                    context.make_spatial_region_overlap_predicate(
                        ColumnExpression.reference(DimensionRecordColumnTag(element1, "region")),
                        ColumnExpression.reference(DimensionRecordColumnTag(element2, "region")),
                    )
                )
                columns_required.add(DimensionRecordColumnTag(element1, "region"))
                columns_required.add(DimensionRecordColumnTag(element2, "region"))
            relation = relation.join(overlaps)
            relationships.add(
                frozenset(self.universe[element1].dimensions.names | self.universe[element2].dimensions.names)
            )

        # All skypix columns need to come from either the initial_relation or
        # a spatial join, since we need all dimension key columns present in
        # the SQL engine and skypix regions are added by postprocessing in the
        # native iteration engine.
        for skypix_dimension_name in dimensions.skypix:
            if DimensionKeyColumnTag(skypix_dimension_name) not in relation.columns:
                raise NotImplementedError(
                    f"Cannot construct query involving skypix dimension {skypix_dimension_name} unless "
                    "it is part of a dataset subquery, spatial join, or other initial relation."
                )

        # Before joining in new tables to provide columns, attempt to restore
        # them from the given relation by weakening projections applied to it.
        relation, _ = context.restore_columns(relation, columns_required)

        # Categorize columns not yet included in the relation to associate
        # them with dimension elements and detect bad inputs.
        missing_columns = ColumnCategorization.from_iterable(columns_required - relation.columns)
        if not (missing_columns.dimension_keys <= dimensions.names):
            raise ColumnError(
                "Cannot add dimension key column(s) "
                f"{{{', '.join(name for name in missing_columns.dimension_keys)}}} "
                f"that were not included in the given dimensions {dimensions}."
            )
        if missing_columns.datasets:
            raise ColumnError(
                f"Unexpected dataset columns {missing_columns.datasets} in call to make_dimension_relation; "
                "use make_dataset_query_relation or make_dataset_search_relation instead, or filter them "
                "out if they have already been added or will be added later."
            )
        for element_name in missing_columns.dimension_records:
            if element_name not in dimensions.elements.names:
                raise ColumnError(
                    f"Cannot join dimension element {element_name} whose dimensions are not a "
                    f"subset of {dimensions}."
                )

        # Iterate over all dimension elements whose relations definitely have
        # to be joined in. The order doesn't matter as long as we can assume
        # the database query optimizer is going to try to reorder them anyway.
        for element_name in dimensions.elements:
            columns_still_needed = missing_columns.dimension_records[element_name]
            element = self.universe[element_name]
            # Two separate conditions in play here:
            # - if we need a record column (not just key columns) from this
            #   element, we have to join in its relation;
            # - if the element establishes a relationship between key columns
            #   that wasn't already established by the initial relation, we
            #   always join that element's relation. Any element with
            #   implied dependencies or the alwaysJoin flag establishes such a
            #   relationship.
            if columns_still_needed or (
                element.defines_relationships and frozenset(element.dimensions.names) not in relationships
            ):
                relation = self._managers.dimensions.join(element_name, relation, default_join, context)
        # At this point we've joined in all of the element relations that
        # definitely need to be included, but we may not have all of the
        # dimension key columns in the query that we want. To fill out that
        # set, we iterate over just the given DimensionGroup's dimensions (not
        # all dimension *elements*) in reverse topological order. That order
        # should reduce the total number of tables we bring in, since each
        # dimension will bring in keys for its required dependencies before we
        # get to those required dependencies.
        for dimension_name in reversed(dimensions.names.as_tuple()):
            if DimensionKeyColumnTag(dimension_name) not in relation.columns:
                relation = self._managers.dimensions.join(dimension_name, relation, default_join, context)

        # Add the predicates we constructed earlier, with a transfer to native
        # iteration first if necessary.
        if not predicate.as_trivial():
            relation = relation.with_rows_satisfying(
                predicate, preferred_engine=context.iteration_engine, transfer=True
            )

        # Finally project the new relation down to just the columns in the
        # initial relation, the dimension key columns, and the new columns
        # requested.
        columns_kept = set(columns)
        if initial_relation is not None:
            columns_kept.update(initial_relation.columns)
        columns_kept.update(DimensionKeyColumnTag.generate(dimensions.names))
        relation = relation.with_only_columns(columns_kept, preferred_engine=context.preferred_engine)

        return relation

    def resolve_governor_constraints(
        self, dimensions: DimensionGroup, constraints: Mapping[str, Set[str]]
    ) -> Mapping[str, Set[str]]:
        # Docstring inherited.
        result: dict[str, Set[str]] = {}
        for dimension_name in dimensions.governors:
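            # All values known for this governor dimension, taken from the
            # in-memory record cache rather than a fresh database query.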
            all_values = {
                cast(str, record.dataId[dimension_name])
                for record in self._dimension_record_cache[dimension_name]
            }
            if (constraint_values := constraints.get(dimension_name)) is not None:
                if not (constraint_values <= all_values):
                    raise DataIdValueError(
                        f"Unknown values specified for governor dimension {dimension_name}: "
                        f"{constraint_values - all_values}."
                    )
                result[dimension_name] = constraint_values
            else:
                result[dimension_name] = all_values
        return result

    def get_dimension_record_cache(self, element_name: str) -> DimensionRecordSet | None:
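        # Only elements whose records are held by the in-memory cache have a
        # record set to return; for anything else this reports `None`.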
        return (
            self._dimension_record_cache[element_name]
            if element_name in self._dimension_record_cache
            else None
        )
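
Usage sketch: a hypothetical illustration, not part of the module above, of how the pieces named in `__init__` fit together. `db`, `managers`, and `record_cache` stand in for objects that a `SqlRegistry` would normally own; only attributes defined in this file are used.

    backend = SqlQueryBackend(db, managers, record_cache)
    context = backend.context()   # a fresh SqlQueryContext bound to the same database
    universe = backend.universe   # the DimensionUniverse shared with the managers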