Coverage for python/lsst/daf/butler/queries/_general_query_results.py: 26% of 109 statements (coverage.py v7.13.5, created at 2026-05-06 08:30 +0000)
# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

__all__ = ("GeneralQueryResults", "GeneralResultTuple")

import itertools
from collections.abc import Iterator
from typing import Any, NamedTuple, final

from .._dataset_ref import DatasetRef
from .._dataset_type import DatasetType
from ..dimensions import DataCoordinate, DimensionElement, DimensionGroup, DimensionRecord, DimensionRecordSet
from ._base import QueryResultsBase
from .driver import QueryDriver
from .result_specs import GeneralResultSpec
from .tree import AnyDatasetFieldName, QueryTree, ResultColumn


class GeneralResultTuple(NamedTuple):
    """Helper class for general results that represents a result row as a
    data coordinate and, optionally, a set of dataset refs extracted from
    that row.
    """

    data_id: DataCoordinate
    """Data coordinate for the current row."""

    refs: list[DatasetRef]
    """Dataset refs extracted from the current row; their order matches the
    order of arguments in the ``iter_tuples`` call."""

    raw_row: dict[str, Any]
    """Original result row; the keys are the names of dimensions, dimension
    fields (separated from the dimension name by a dot), or dataset type
    fields (separated from the dataset type name by a dot).
    """


@final
class GeneralQueryResults(QueryResultsBase):
    """A query for general results: rows that combine dimension keys,
    dimension record fields, and dataset fields, represented as dictionaries.

    Parameters
    ----------
    driver : `QueryDriver`
        Implementation object that knows how to actually execute queries.
    tree : `QueryTree`
        Description of the query as a tree of joins and column expressions.
        The instance returned directly by the `Butler._query` entry point
        should be constructed via `make_unit_query_tree`.
    spec : `GeneralResultSpec`
        Specification of the query result rows, including output columns,
        ordering, and slicing.

    Notes
    -----
    This class should never be constructed directly by users; use `Query`
    methods instead.
    """

    def __init__(self, driver: QueryDriver, tree: QueryTree, spec: GeneralResultSpec):
        spec.validate_tree(tree)
        super().__init__(driver, tree)
        self._spec = spec

    def __iter__(self) -> Iterator[dict[str, Any]]:
        """Iterate over result rows.

        Yields
        ------
        row_dict : `dict` [`str`, `typing.Any`]
            Result row as a dictionary; the keys are the names of dimensions,
            dimension fields (separated from the dimension name by a dot), or
            dataset type fields (separated from the dataset type name by a
            dot).
        """
        for page in self._driver.execute(self._spec, self._tree):
            columns = tuple(str(column) for column in page.spec.get_result_columns())
            for row in page.rows:
                result = dict(zip(columns, row, strict=True))
                if page.dimension_records:
                    records = self._get_cached_dimension_records(result, page.dimension_records)
                    self._add_dimension_records(result, records)
                yield result
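
    # For illustration (a sketch; the instrument, detector values, and the
    # ``raw`` dataset type below are hypothetical), a single yielded row
    # might look like:
    #
    #     {
    #         "instrument": "LSSTCam",
    #         "detector": 42,
    #         "detector.full_name": "R11_S02",
    #         "raw.dataset_id": UUID("..."),
    #         "raw.run": "LSSTCam/raw/all",
    #     }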

    def iter_tuples(self, *dataset_types: DatasetType) -> Iterator[GeneralResultTuple]:
        """Iterate over result rows, returning a data coordinate, the dataset
        refs constructed from each row, and the original row.

        The result rows must include ``dataset_id`` and ``run`` columns for
        each dataset type in ``dataset_types``.

        Parameters
        ----------
        *dataset_types : `DatasetType`
            Zero or more dataset types whose refs should be extracted from
            each row.

        Yields
        ------
        row_tuple : `GeneralResultTuple`
            Structure containing the data coordinate, refs, and a copy of the
            row.
        """
        all_dimensions = self._spec.dimensions
        dataset_keys: list[tuple[DatasetType, DimensionGroup, str, str]] = []
        for dataset_type in dataset_types:
            dimensions = dataset_type.dimensions
            id_key = f"{dataset_type.name}.dataset_id"
            run_key = f"{dataset_type.name}.run"
            dataset_keys.append((dataset_type, dimensions, id_key, run_key))
        for page in self._driver.execute(self._spec, self._tree):
            columns = tuple(str(column) for column in page.spec.get_result_columns())
            for page_row in page.rows:
                row = dict(zip(columns, page_row, strict=True))
                if page.dimension_records:
                    cached_records = self._get_cached_dimension_records(row, page.dimension_records)
                    self._add_dimension_records(row, cached_records)
                else:
                    cached_records = {}
                data_coordinate = self._make_data_id(row, all_dimensions, cached_records)
                refs = []
                for dataset_type, dimensions, id_key, run_key in dataset_keys:
                    data_id = data_coordinate.subset(dimensions)
                    refs.append(DatasetRef(dataset_type, data_id, row[run_key], id=row[id_key]))
                yield GeneralResultTuple(data_id=data_coordinate, refs=refs, raw_row=row)
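
    # Usage sketch: instances normally come from `Query` methods rather than
    # direct construction. Assuming a ``results`` object whose underlying
    # query selected ``dataset_id`` and ``run`` columns for a hypothetical
    # ``raw_type`` dataset type:
    #
    #     for data_id, (raw_ref,), row in results.iter_tuples(raw_type):
    #         assert raw_ref.dataId == data_id.subset(raw_type.dimensions)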

    @property
    def dimensions(self) -> DimensionGroup:
        # Docstring inherited.
        return self._spec.dimensions

    @property
    def has_dimension_records(self) -> bool:
        """Whether all data IDs in this iterable contain dimension records."""
        return self._spec.include_dimension_records

    def with_dimension_records(self) -> GeneralQueryResults:
        """Return a results object for which `has_dimension_records` is
        `True`.
        """
        if self.has_dimension_records:
            return self
        return self._copy(tree=self._tree, include_dimension_records=True)
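
    # Sketch (continuing the hypothetical ``results``/``raw_type`` example):
    # after expansion, each yielded data ID carries its dimension records:
    #
    #     results = results.with_dimension_records()
    #     for data_id, refs, row in results.iter_tuples(raw_type):
    #         assert data_id.hasRecords()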

    def count(self, *, exact: bool = True, discard: bool = False) -> int:
        # Docstring inherited.
        return self._driver.count(self._tree, self._spec, exact=exact, discard=discard)

    def _copy(self, tree: QueryTree, **kwargs: Any) -> GeneralQueryResults:
        # Docstring inherited.
        return GeneralQueryResults(self._driver, tree, self._spec.model_copy(update=kwargs))

    def _with_added_dataset_field(self, dataset_type: str, field: AnyDatasetFieldName) -> GeneralQueryResults:
        dataset_fields = dict(self._spec.dataset_fields)
        field_set = set(dataset_fields.get(dataset_type, set()))
        field_set.add(field)
        dataset_fields[dataset_type] = field_set

        return self._copy(self._tree, dataset_fields=dataset_fields)

    def _get_datasets(self) -> frozenset[str]:
        # Docstring inherited.
        return frozenset(self._spec.dataset_fields)

    def _make_data_id(
        self,
        row: dict[str, Any],
        dimensions: DimensionGroup,
        cached_row_records: dict[DimensionElement, DimensionRecord],
    ) -> DataCoordinate:
        values = tuple(row[key] for key in itertools.chain(dimensions.required, dimensions.implied))
        data_coordinate = DataCoordinate.from_full_values(dimensions, values)
        if self.has_dimension_records:
            records = {}
            for name in dimensions.elements:
                element = dimensions.universe[name]
                record = cached_row_records.get(element)
                if record is None:
                    record = self._make_dimension_record(row, element)
                records[name] = record
            data_coordinate = data_coordinate.expanded(records)
        return data_coordinate

    def _make_dimension_record(self, row: dict[str, Any], element: DimensionElement) -> DimensionRecord:
        # Map record attribute names to result-row column names: dimension
        # key columns are named after the dimensions themselves, while
        # element-specific fields use the dotted "element.field" form.
        column_map = list(
            zip(
                element.schema.dimensions.names,
                element.dimensions.names,
            )
        )
        for field in element.schema.remainder.names:
            column_map.append((field, str(ResultColumn(element.name, field))))
        d = {k: row[v] for k, v in column_map}
        record_cls = element.RecordClass
        return record_cls(**d)

    def _get_cached_dimension_records(
        self, row: dict[str, Any], dimension_records: dict[DimensionElement, DimensionRecordSet]
    ) -> dict[DimensionElement, DimensionRecord]:
        """Find cached dimension records matching this row."""
        records = {}
        for element, element_records in dimension_records.items():
            required_values = tuple(row[key] for key in element.required.names)
            records[element] = element_records.find_with_required_values(required_values)
        return records

    def _add_dimension_records(
        self, row: dict[str, Any], records: dict[DimensionElement, DimensionRecord]
    ) -> None:
        """Extend the row with the fields from cached dimension records."""
        for element, record in records.items():
            for name, value in record.toDict().items():
                if name not in element.schema.required.names:
                    row[f"{element.name}.{name}"] = value