Coverage for python/lsst/daf/butler/remote_butler/_query_driver.py: 0% of 116 statements

# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

__all__ = ("RemoteQueryDriver",)

from collections.abc import Iterable, Iterator
from contextlib import ExitStack
from typing import Any, Literal, overload
from uuid import uuid4

import httpx

from ...butler import Butler
from .._dataset_type import DatasetType
from ..column_spec import make_tuple_type_adapter
from ..dimensions import (
    DataCoordinate,
    DataIdValue,
    DimensionGroup,
    DimensionRecord,
    DimensionRecordSet,
    DimensionUniverse,
)
from ..queries.driver import (
    DataCoordinateResultPage,
    DatasetRefResultPage,
    DimensionRecordResultPage,
    GeneralResultPage,
    QueryDriver,
    ResultPage,
)
from ..queries.result_specs import (
    DataCoordinateResultSpec,
    DatasetRefResultSpec,
    DimensionRecordResultSpec,
    GeneralResultSpec,
    ResultSpec,
    SerializedResultSpec,
)
from ..queries.tree import DataCoordinateUploadKey, MaterializationKey, QueryTree, SerializedQueryTree
from ..registry import NoDefaultCollectionError
from ._http_connection import RemoteButlerHttpConnection, parse_model
from ._query_results import convert_dataset_ref_results, read_query_results
from .server_models import (
    AdditionalQueryInput,
    DataCoordinateUpload,
    GeneralResultModel,
    MaterializedQuery,
    QueryAnyRequestModel,
    QueryAnyResponseModel,
    QueryCountRequestModel,
    QueryCountResponseModel,
    QueryExecuteRequestModel,
    QueryExecuteResultData,
    QueryExplainRequestModel,
    QueryExplainResponseModel,
    QueryInputs,
)


class RemoteQueryDriver(QueryDriver):
    """Implementation of QueryDriver for client/server Butler.

    Parameters
    ----------
    butler : `Butler`
        Butler instance that will use this QueryDriver.
    connection : `RemoteButlerHttpConnection`
        HTTP connection used to send queries to the Butler server.

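    Notes
    -----
    Instances are normally created internally when a query is started against
    a remote butler, rather than being constructed directly by users. A rough
    sketch of the lifecycle (illustrative only)::

        driver = RemoteQueryDriver(butler, connection)
        with driver:
            for page in driver.execute(result_spec, tree):
                ...  # consume streamed result pages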
    """

    def __init__(self, butler: Butler, connection: RemoteButlerHttpConnection):
        self._butler = butler
        self._connection = connection

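        # Query inputs (data coordinate uploads and materializations) that
        # have been accumulated on the client; they are sent back to the
        # server with every subsequent query request.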
        self._stored_query_inputs: list[AdditionalQueryInput] = []

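        # Streaming HTTP responses for result iterations that have not been
        # fully consumed; any still pending are closed in __exit__.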
        self._pending_queries: set[httpx.Response] = set()
        self._closed = False

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> Literal[False]:
        self._closed = True
        # Clean up any queries that the user didn't finish iterating. The exit
        # stack helps handle any exceptions that may be thrown during cleanup.
        stack = ExitStack().__enter__()
        for pending in self._pending_queries:
            stack.callback(pending.close)
        self._pending_queries = set()
        stack.__exit__(exc_type, exc_value, traceback)
        return False

    @property
    def universe(self) -> DimensionUniverse:
        return self._butler.dimensions

    @overload
    def execute(
        self, result_spec: DataCoordinateResultSpec, tree: QueryTree
    ) -> Iterator[DataCoordinateResultPage]: ...

    @overload
    def execute(
        self, result_spec: DimensionRecordResultSpec, tree: QueryTree
    ) -> Iterator[DimensionRecordResultPage]: ...

    @overload
    def execute(
        self, result_spec: DatasetRefResultSpec, tree: QueryTree
    ) -> Iterator[DatasetRefResultPage]: ...

    @overload
    def execute(self, result_spec: GeneralResultSpec, tree: QueryTree) -> Iterator[GeneralResultPage]: ...

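    # Implementation behind the typed overloads above. Result pages are
    # streamed from the server, so the HTTP response stays open until the
    # caller finishes iterating or the query context is closed.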
    def execute(self, result_spec: ResultSpec, tree: QueryTree) -> Iterator[ResultPage]:
        if self._closed:
            raise RuntimeError("Cannot execute query: query context has been closed")

        request = QueryExecuteRequestModel(
            query=self._create_query_input(tree), result_spec=SerializedResultSpec(result_spec)
        )
        universe = self.universe
        with self._connection.post_with_stream_response("query/execute", request) as response:
            self._pending_queries.add(response)
            try:
                # There is one result page JSON object per line of the
                # response.
                for result_chunk in read_query_results(response):
                    yield _convert_query_result_page(result_spec, result_chunk, universe)
                    if self._closed:
                        raise RuntimeError(
                            "Cannot continue query result iteration: query context has been closed"
                        )
            finally:
                self._pending_queries.discard(response)

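    # Nothing is executed here: the query tree is copied and stored
    # client-side under a fresh UUID key, to be sent to the server as an
    # additional input with subsequent requests.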
    def materialize(
        self,
        tree: QueryTree,
        dimensions: DimensionGroup,
        datasets: frozenset[str],
        allow_duplicate_overlaps: bool = False,
    ) -> MaterializationKey:
        key = uuid4()
        self._stored_query_inputs.append(
            MaterializedQuery(
                key=key,
                tree=SerializedQueryTree(tree.model_copy(deep=True)),
                dimensions=dimensions.to_simple(),
                datasets=datasets,
                allow_duplicate_overlaps=allow_duplicate_overlaps,
            ),
        )
        return key

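    # As with materialize, the uploaded rows are stored client-side under a
    # UUID key and attached to subsequent query requests.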
    def upload_data_coordinates(
        self, dimensions: DimensionGroup, rows: Iterable[tuple[DataIdValue, ...]]
    ) -> DataCoordinateUploadKey:
        key = uuid4()
        self._stored_query_inputs.append(
            DataCoordinateUpload(key=key, dimensions=dimensions.to_simple(), rows=list(rows))
        )
        return key

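    # Counting happens server-side in a single POST round trip; `exact` and
    # `discard` are forwarded to the server unchanged.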
    def count(
        self,
        tree: QueryTree,
        result_spec: ResultSpec,
        *,
        exact: bool,
        discard: bool,
    ) -> int:
        request = QueryCountRequestModel(
            query=self._create_query_input(tree),
            result_spec=SerializedResultSpec(result_spec),
            exact=exact,
            discard=discard,
        )
        response = self._connection.post("query/count", request)
        result = parse_model(response, QueryCountResponseModel)
        return result.count

    def any(self, tree: QueryTree, *, execute: bool, exact: bool) -> bool:
        request = QueryAnyRequestModel(
            query=self._create_query_input(tree),
            exact=exact,
            execute=execute,
        )
        response = self._connection.post("query/any", request)
        result = parse_model(response, QueryAnyResponseModel)
        return result.found_rows

    def explain_no_results(self, tree: QueryTree, execute: bool) -> Iterable[str]:
        request = QueryExplainRequestModel(
            query=self._create_query_input(tree),
            execute=execute,
        )
        response = self._connection.post("query/explain", request)
        result = parse_model(response, QueryExplainResponseModel)
        return result.messages

    def get_default_collections(self) -> tuple[str, ...]:
        collections = tuple(self._butler.collections.defaults)
        if not collections:
            raise NoDefaultCollectionError("No collections provided and no default collections.")
        return collections

    def get_dataset_type(self, name: str) -> DatasetType:
        return self._butler.get_dataset_type(name)

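    # Bundle the serialized query tree with the default data ID and any
    # stored uploads/materializations so each request carries the full
    # context needed to run the query.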
    def _create_query_input(self, tree: QueryTree) -> QueryInputs:
        return QueryInputs(
            tree=SerializedQueryTree(tree),
            default_data_id=self._butler.registry.defaults.dataId.to_simple(),
            additional_query_inputs=self._stored_query_inputs,
        )

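
# Convert one page of serialized query results from the server into the
# ResultPage variant matching the requested result type.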
def _convert_query_result_page(
    result_spec: ResultSpec, result: QueryExecuteResultData, universe: DimensionUniverse
) -> ResultPage:
    if result_spec.result_type == "dimension_record":
        assert result.type == "dimension_record"
        return DimensionRecordResultPage(
            spec=result_spec,
            rows=[DimensionRecord.from_simple(r, universe) for r in result.rows],
        )
    elif result_spec.result_type == "data_coordinate":
        assert result.type == "data_coordinate"
        return DataCoordinateResultPage(
            spec=result_spec,
            rows=[DataCoordinate.from_simple(r, universe) for r in result.rows],
        )
    elif result_spec.result_type == "dataset_ref":
        return DatasetRefResultPage(spec=result_spec, rows=convert_dataset_ref_results(result, universe))
    elif result_spec.result_type == "general":
        assert result.type == "general"
        return _convert_general_result(result_spec, result)
    else:
        raise NotImplementedError(f"Unhandled result type {result_spec.result_type}")


def _convert_general_result(spec: GeneralResultSpec, model: GeneralResultModel) -> GeneralResultPage:
    """Convert a GeneralResultModel to a general result page."""
    if spec.include_dimension_records:
        # dimension_records must not be None when `include_dimension_records`
        # is True, but it will be None if the remote server has not been
        # upgraded.
        if model.dimension_records is None:
            raise ValueError(
                "Missing dimension records in general result -- it is likely that the server needs an"
                " upgrade."
            )

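    # Rows arrive as serialized tuples; a pydantic type adapter built from
    # the column specifications validates each row and converts its values
    # back to their Python column types.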
    columns = spec.get_result_columns()
    row_type_adapter = make_tuple_type_adapter(
        [columns.get_column_spec(column.logical_table, column.field) for column in columns]
    )
    rows = [row_type_adapter.validate_python(row) for row in model.rows]

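    # Re-hydrate any dimension records attached to the page, grouping them
    # into a DimensionRecordSet per dimension element.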
    universe = spec.dimensions.universe
    dimension_records = None
    if model.dimension_records is not None:
        dimension_records = {}
        for name, records in model.dimension_records.items():
            element = universe[name]
            dimension_records[element] = DimensionRecordSet(
                element, (DimensionRecord.from_simple(r, universe) for r in records)
            )

    return GeneralResultPage(spec=spec, rows=rows, dimension_records=dimension_records)