Coverage for python/lsst/daf/butler/remote_butler/_query_driver.py: 0% (116 statements)
# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

from uuid import uuid4

__all__ = ("RemoteQueryDriver",)

from collections.abc import Iterable, Iterator
from contextlib import ExitStack
from typing import Any, Literal, overload

import httpx

from ...butler import Butler
from .._dataset_type import DatasetType
from ..column_spec import make_tuple_type_adapter
from ..dimensions import (
    DataCoordinate,
    DataIdValue,
    DimensionGroup,
    DimensionRecord,
    DimensionRecordSet,
    DimensionUniverse,
)
from ..queries.driver import (
    DataCoordinateResultPage,
    DatasetRefResultPage,
    DimensionRecordResultPage,
    GeneralResultPage,
    QueryDriver,
    ResultPage,
)
from ..queries.result_specs import (
    DataCoordinateResultSpec,
    DatasetRefResultSpec,
    DimensionRecordResultSpec,
    GeneralResultSpec,
    ResultSpec,
    SerializedResultSpec,
)
from ..queries.tree import DataCoordinateUploadKey, MaterializationKey, QueryTree, SerializedQueryTree
from ..registry import NoDefaultCollectionError
from ._http_connection import RemoteButlerHttpConnection, parse_model
from ._query_results import convert_dataset_ref_results, read_query_results
from .server_models import (
    AdditionalQueryInput,
    DataCoordinateUpload,
    GeneralResultModel,
    MaterializedQuery,
    QueryAnyRequestModel,
    QueryAnyResponseModel,
    QueryCountRequestModel,
    QueryCountResponseModel,
    QueryExecuteRequestModel,
    QueryExecuteResultData,
    QueryExplainRequestModel,
    QueryExplainResponseModel,
    QueryInputs,
)


class RemoteQueryDriver(QueryDriver):
    """Implementation of QueryDriver for client/server Butler.

    Parameters
    ----------
    butler : `Butler`
        Butler instance that will use this QueryDriver.
    connection : `RemoteButlerHttpConnection`
        HTTP connection used to send queries to the Butler server.
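
    Notes
    -----
    ``materialize`` and ``upload_data_coordinates`` do not contact the
    server; their inputs are stored client-side and sent along with each
    subsequent request to the server's ``query/*`` endpoints.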
98 """

    def __init__(self, butler: Butler, connection: RemoteButlerHttpConnection):
        self._butler = butler
        self._connection = connection
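        # Materializations and data coordinate uploads accumulated so far;
        # these are forwarded to the server with every query request.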
        self._stored_query_inputs: list[AdditionalQueryInput] = []
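        # Streaming responses for queries whose results are still being
        # iterated, tracked so they can be closed if the context exits early.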
        self._pending_queries: set[httpx.Response] = set()
        self._closed = False

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> Literal[False]:
        self._closed = True
        # Clean up any queries that the user didn't finish iterating. The exit
        # stack helps handle any exceptions that may be thrown during cleanup.
        stack = ExitStack().__enter__()
        for pending in self._pending_queries:
            stack.callback(pending.close)
        self._pending_queries = set()
        stack.__exit__(exc_type, exc_value, traceback)
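        # Return False so that any exception raised in the caller's with
        # block propagates normally.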
        return False

    @property
    def universe(self) -> DimensionUniverse:
        return self._butler.dimensions

    @overload
    def execute(
        self, result_spec: DataCoordinateResultSpec, tree: QueryTree
    ) -> Iterator[DataCoordinateResultPage]: ...

    @overload
    def execute(
        self, result_spec: DimensionRecordResultSpec, tree: QueryTree
    ) -> Iterator[DimensionRecordResultPage]: ...

    @overload
    def execute(
        self, result_spec: DatasetRefResultSpec, tree: QueryTree
    ) -> Iterator[DatasetRefResultPage]: ...

    @overload
    def execute(self, result_spec: GeneralResultSpec, tree: QueryTree) -> Iterator[GeneralResultPage]: ...

    def execute(self, result_spec: ResultSpec, tree: QueryTree) -> Iterator[ResultPage]:
        if self._closed:
            raise RuntimeError("Cannot execute query: query context has been closed")

        request = QueryExecuteRequestModel(
            query=self._create_query_input(tree), result_spec=SerializedResultSpec(result_spec)
        )
        universe = self.universe
        with self._connection.post_with_stream_response("query/execute", request) as response:
            self._pending_queries.add(response)
            try:
                # There is one result page JSON object per line of the
                # response.
                for result_chunk in read_query_results(response):
                    yield _convert_query_result_page(result_spec, result_chunk, universe)
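                    # The generator may be resumed after the query context
                    # has exited, so re-check the closed flag on every page.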
                    if self._closed:
                        raise RuntimeError(
                            "Cannot continue query result iteration: query context has been closed"
                        )
            finally:
                self._pending_queries.discard(response)

    def materialize(
        self,
        tree: QueryTree,
        dimensions: DimensionGroup,
        datasets: frozenset[str],
        allow_duplicate_overlaps: bool = False,
    ) -> MaterializationKey:
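        # No server round trip happens here: the query is stored client-side
        # and sent along with later requests by _create_query_input. The tree
        # is deep-copied so that subsequent changes to it cannot affect the
        # stored materialization.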
        key = uuid4()
        self._stored_query_inputs.append(
            MaterializedQuery(
                key=key,
                tree=SerializedQueryTree(tree.model_copy(deep=True)),
                dimensions=dimensions.to_simple(),
                datasets=datasets,
                allow_duplicate_overlaps=allow_duplicate_overlaps,
            ),
        )
        return key

    def upload_data_coordinates(
        self, dimensions: DimensionGroup, rows: Iterable[tuple[DataIdValue, ...]]
    ) -> DataCoordinateUploadKey:
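        # As with materialize, the rows are stored client-side for inclusion
        # in later requests rather than being uploaded immediately.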
        key = uuid4()
        self._stored_query_inputs.append(
            DataCoordinateUpload(key=key, dimensions=dimensions.to_simple(), rows=list(rows))
        )
        return key

    def count(
        self,
        tree: QueryTree,
        result_spec: ResultSpec,
        *,
        exact: bool,
        discard: bool,
    ) -> int:
        request = QueryCountRequestModel(
            query=self._create_query_input(tree),
            result_spec=SerializedResultSpec(result_spec),
            exact=exact,
            discard=discard,
        )
        response = self._connection.post("query/count", request)
        result = parse_model(response, QueryCountResponseModel)
        return result.count

    def any(self, tree: QueryTree, *, execute: bool, exact: bool) -> bool:
        request = QueryAnyRequestModel(
            query=self._create_query_input(tree),
            exact=exact,
            execute=execute,
        )
        response = self._connection.post("query/any", request)
        result = parse_model(response, QueryAnyResponseModel)
        return result.found_rows

    def explain_no_results(self, tree: QueryTree, execute: bool) -> Iterable[str]:
        request = QueryExplainRequestModel(
            query=self._create_query_input(tree),
            execute=execute,
        )
        response = self._connection.post("query/explain", request)
        result = parse_model(response, QueryExplainResponseModel)
        return result.messages

    def get_default_collections(self) -> tuple[str, ...]:
        collections = tuple(self._butler.collections.defaults)
        if not collections:
            raise NoDefaultCollectionError("No collections provided and no default collections.")
        return collections

    def get_dataset_type(self, name: str) -> DatasetType:
        return self._butler.get_dataset_type(name)

    def _create_query_input(self, tree: QueryTree) -> QueryInputs:
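        # Package the query tree together with the default data ID and any
        # stored uploads/materializations, giving the server the full context
        # it needs to run the query.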
        return QueryInputs(
            tree=SerializedQueryTree(tree),
            default_data_id=self._butler.registry.defaults.dataId.to_simple(),
            additional_query_inputs=self._stored_query_inputs,
        )


def _convert_query_result_page(
    result_spec: ResultSpec, result: QueryExecuteResultData, universe: DimensionUniverse
) -> ResultPage:
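    """Convert a single page of serialized query results from the server
    into the in-memory page type matching ``result_spec``.
    """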
    if result_spec.result_type == "dimension_record":
        assert result.type == "dimension_record"
        return DimensionRecordResultPage(
            spec=result_spec,
            rows=[DimensionRecord.from_simple(r, universe) for r in result.rows],
        )
    elif result_spec.result_type == "data_coordinate":
        assert result.type == "data_coordinate"
        return DataCoordinateResultPage(
            spec=result_spec,
            rows=[DataCoordinate.from_simple(r, universe) for r in result.rows],
        )
    elif result_spec.result_type == "dataset_ref":
        return DatasetRefResultPage(spec=result_spec, rows=convert_dataset_ref_results(result, universe))
    elif result_spec.result_type == "general":
        assert result.type == "general"
        return _convert_general_result(result_spec, result)
    else:
        raise NotImplementedError(f"Unhandled result type {result_spec.result_type}")


def _convert_general_result(spec: GeneralResultSpec, model: GeneralResultModel) -> GeneralResultPage:
    """Convert GeneralResultModel to a general result page."""
    if spec.include_dimension_records:
        # dimension_records must not be None when `include_dimension_records`
        # is True, but it will be None if the remote server has not been
        # upgraded.
        if model.dimension_records is None:
            raise ValueError(
                "Missing dimension records in general result -- it is likely that the server needs an upgrade."
            )

    columns = spec.get_result_columns()
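    # Rows arrive as plain JSON-decoded values; the tuple type adapter
    # validates and coerces each row to the in-memory column types.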
    row_type_adapter = make_tuple_type_adapter(
        [columns.get_column_spec(column.logical_table, column.field) for column in columns]
    )
    rows = [row_type_adapter.validate_python(row) for row in model.rows]

    universe = spec.dimensions.universe
    dimension_records = None
    if model.dimension_records is not None:
        dimension_records = {}
        for name, records in model.dimension_records.items():
            element = universe[name]
            dimension_records[element] = DimensionRecordSet(
                element, (DimensionRecord.from_simple(r, universe) for r in records)
            )

    return GeneralResultPage(spec=spec, rows=rows, dimension_records=dimension_records)