Coverage for python/lsst/daf/butler/registry/queries/_sql_query_context.py: 12%

231 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-01-19 02:07 -0800

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ("SqlQueryContext",) 

24 

25import itertools 

26from collections.abc import Iterable, Iterator, Mapping, Set 

27from contextlib import ExitStack 

28from typing import TYPE_CHECKING, Any 

29 

30import sqlalchemy 

31from lsst.daf.relation import ( 

32 BinaryOperationRelation, 

33 Calculation, 

34 Chain, 

35 ColumnTag, 

36 Deduplication, 

37 Engine, 

38 EngineError, 

39 Join, 

40 MarkerRelation, 

41 Projection, 

42 Relation, 

43 Transfer, 

44 UnaryOperation, 

45 UnaryOperationRelation, 

46 iteration, 

47 sql, 

48) 

49 

50from ...core import ColumnTypeInfo, LogicalColumn, TimespanDatabaseRepresentation, is_timespan_column 

51from ._query_context import QueryContext 

52from .butler_sql_engine import ButlerSqlEngine 

53 

54if TYPE_CHECKING: 54 ↛ 55line 54 didn't jump to line 55, because the condition on line 54 was never true

55 from ..interfaces import Database 

56 

57 

58class SqlQueryContext(QueryContext): 

59 """An implementation of `sql.QueryContext` for `SqlRegistry`. 

60 

61 Parameters 

62 ---------- 

63 db : `Database` 

64 Object that abstracts the database engine. 

65 sql_engine : `ButlerSqlEngine` 

66 Information about column types that can vary with registry 

67 configuration. 

68 row_chunk_size : `int`, optional 

69 Number of rows to insert into temporary tables at once. If this is 

70 lower than ``db.get_constant_rows_max()`` it will be set to that value. 

71 """ 

72 

73 def __init__( 

74 self, 

75 db: Database, 

76 column_types: ColumnTypeInfo, 

77 row_chunk_size: int = 1000, 

78 ): 

79 super().__init__() 

80 self.sql_engine = ButlerSqlEngine(column_types=column_types) 

81 self._row_chunk_size = max(row_chunk_size, db.get_constant_rows_max()) 

82 self._db = db 

83 self._exit_stack: ExitStack | None = None 

84 

85 def __enter__(self) -> SqlQueryContext: 

86 assert self._exit_stack is None, "Context manager already entered." 

87 self._exit_stack = ExitStack().__enter__() 

88 self._exit_stack.enter_context(self._db.session()) 

89 return self 

90 

91 def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> bool: 

92 assert self._exit_stack is not None, "Context manager not yet entered." 

93 result = self._exit_stack.__exit__(exc_type, exc_value, traceback) 

94 self._exit_stack = None 

95 return result 

96 

97 @property 

98 def is_open(self) -> bool: 

99 return self._exit_stack is not None 

100 

101 @property 

102 def column_types(self) -> ColumnTypeInfo: 

103 """Information about column types that depend on registry configuration 

104 (`ColumnTypeInfo`). 

105 """ 

106 return self.sql_engine.column_types 

107 

108 @property 

109 def preferred_engine(self) -> Engine: 

110 # Docstring inherited. 

111 return self.sql_engine 

112 

113 def count(self, relation: Relation, *, exact: bool = True, discard: bool = False) -> int: 

114 # Docstring inherited. 

115 relation = self._strip_count_invariant_operations(relation, exact).with_only_columns(frozenset()) 

116 if relation.engine == self.sql_engine: 

117 sql_executable = self.sql_engine.to_executable( 

118 relation, extra_columns=[sqlalchemy.sql.func.count()] 

119 ) 

120 with self._db.query(sql_executable) as sql_result: 

121 return sql_result.scalar() 

122 elif (rows := relation.payload) is not None: 

123 assert isinstance( 

124 rows, iteration.MaterializedRowIterable 

125 ), "Query guarantees that only materialized payloads are attached to its relations." 

126 return len(rows) 

127 elif discard: 

128 n = 0 

129 for _ in self.fetch_iterable(relation): 

130 n += 1 

131 return n 

132 else: 

133 raise RuntimeError( 

134 f"Query with relation {relation} has deferred operations that " 

135 "must be executed in Python in order to obtain an exact " 

136 "count. Pass discard=True to run the query and discard the " 

137 "result rows while counting them, run the query first, or " 

138 "pass exact=False to obtain an upper bound." 

139 ) 

140 

141 def any(self, relation: Relation, *, execute: bool = True, exact: bool = True) -> bool: 

142 # Docstring inherited. 

143 relation = self._strip_count_invariant_operations(relation, exact).with_only_columns(frozenset())[:1] 

144 if relation.engine == self.sql_engine: 

145 sql_executable = self.sql_engine.to_executable(relation) 

146 with self._db.query(sql_executable) as sql_result: 

147 return sql_result.one_or_none() is not None 

148 elif (rows := relation.payload) is not None: 

149 assert isinstance( 

150 rows, iteration.MaterializedRowIterable 

151 ), "Query guarantees that only materialized payloads are attached to its relations." 

152 return bool(rows) 

153 elif execute: 

154 for _ in self.fetch_iterable(relation): 

155 return True 

156 return False 

157 else: 

158 raise RuntimeError( 

159 f"Query with relation {relation} has deferred operations that " 

160 "must be executed in Python in order to obtain an exact " 

161 "check for whether any rows would be returned. Pass " 

162 "execute=True to run the query until at least " 

163 "one row is found, run the query first, or pass exact=False " 

164 "test only whether the query _might_ have result rows." 

165 ) 

166 

167 def transfer(self, source: Relation, destination: Engine, materialize_as: str | None) -> Any: 

168 # Docstring inherited from lsst.daf.relation.Processor. 

169 if source.engine == self.sql_engine and destination == self.iteration_engine: 

170 return self._sql_to_iteration(source, materialize_as) 

171 if source.engine == self.iteration_engine and destination == self.sql_engine: 

172 return self._iteration_to_sql(source, materialize_as) 

173 raise EngineError(f"Unexpected transfer for SqlQueryContext: {source.engine} -> {destination}.") 

174 

175 def materialize(self, target: Relation, name: str) -> Any: 

176 # Docstring inherited from lsst.daf.relation.Processor. 

177 if target.engine == self.sql_engine: 

178 sql_executable = self.sql_engine.to_executable(target) 

179 table_spec = self.column_types.make_relation_table_spec(target.columns) 

180 if self._exit_stack is None: 

181 raise RuntimeError("This operation requires the QueryContext to have been entered.") 

182 table = self._exit_stack.enter_context(self._db.temporary_table(table_spec, name=name)) 

183 self._db.insert(table, select=sql_executable) 

184 payload = sql.Payload[LogicalColumn](table) 

185 payload.columns_available = self.sql_engine.extract_mapping(target.columns, table.columns) 

186 return payload 

187 return super().materialize(target, name) 

188 

189 def restore_columns( 

190 self, 

191 relation: Relation, 

192 columns_required: Set[ColumnTag], 

193 ) -> tuple[Relation, set[ColumnTag]]: 

194 # Docstring inherited. 

195 if relation.is_locked: 

196 return relation, set(relation.columns & columns_required) 

197 match relation: 

198 case UnaryOperationRelation(operation=operation, target=target): 

199 new_target, columns_found = self.restore_columns(target, columns_required) 

200 if not columns_found: 

201 return relation, columns_found 

202 match operation: 

203 case Projection(columns=columns): 

204 new_columns = columns | columns_found 

205 if new_columns != columns: 

206 return ( 

207 Projection(new_columns).apply(new_target), 

208 columns_found, 

209 ) 

210 case Calculation(tag=tag): 

211 if tag in columns_required: 

212 columns_found.add(tag) 

213 case Deduplication(): 

214 # Pulling key columns through the deduplication would 

215 # fundamentally change what it does; have to limit 

216 # ourselves to non-key columns here, and put a 

217 # Projection back in place. 

218 columns_found = {c for c in columns_found if not c.is_key} 

219 new_target = new_target.with_only_columns(relation.columns | columns_found) 

220 return relation.reapply(new_target), columns_found 

221 case BinaryOperationRelation(operation=operation, lhs=lhs, rhs=rhs): 

222 new_lhs, columns_found_lhs = self.restore_columns(lhs, columns_required) 

223 new_rhs, columns_found_rhs = self.restore_columns(rhs, columns_required) 

224 match operation: 

225 case Join(): 

226 return relation.reapply(new_lhs, new_rhs), columns_found_lhs | columns_found_rhs 

227 case Chain(): 

228 if columns_found_lhs != columns_found_rhs: 

229 # Got different answers for different join 

230 # branches; let's try again with just columns found 

231 # in both branches. 

232 new_columns_required = columns_found_lhs & columns_found_rhs 

233 if not new_columns_required: 

234 return relation, set() 

235 new_lhs, columns_found_lhs = self.restore_columns(lhs, new_columns_required) 

236 new_rhs, columns_found_rhs = self.restore_columns(rhs, new_columns_required) 

237 assert columns_found_lhs == columns_found_rhs 

238 return relation.reapply(new_lhs, new_rhs), columns_found_lhs 

239 case MarkerRelation(target=target): 

240 new_target, columns_found = self.restore_columns(target, columns_required) 

241 return relation.reapply(new_target), columns_found 

242 raise AssertionError("Match should be exhaustive and all branches should return.") 

243 

244 def strip_postprocessing(self, relation: Relation) -> tuple[Relation, list[UnaryOperation]]: 

245 # Docstring inherited. 

246 if relation.engine != self.iteration_engine or relation.is_locked: 

247 return relation, [] 

248 match relation: 

249 case UnaryOperationRelation(operation=operation, target=target): 

250 new_target, stripped = self.strip_postprocessing(target) 

251 stripped.append(operation) 

252 return new_target, stripped 

253 case Transfer(destination=self.iteration_engine, target=target): 

254 return target, [] 

255 return relation, [] 

256 

257 def drop_invalidated_postprocessing(self, relation: Relation, new_columns: Set[ColumnTag]) -> Relation: 

258 # Docstring inherited. 

259 if relation.engine != self.iteration_engine or relation.is_locked: 

260 return relation 

261 match relation: 

262 case UnaryOperationRelation(operation=operation, target=target): 

263 columns_added_here = relation.columns - target.columns 

264 new_target = self.drop_invalidated_postprocessing(target, new_columns - columns_added_here) 

265 if operation.columns_required <= new_columns: 

266 # This operation can stay. 

267 return relation.reapply(new_target) 

268 else: 

269 # This operation needs to be dropped, as we're about to 

270 # project away one more more of the columns it needs. 

271 return new_target 

272 return relation 

273 

274 def _sql_to_iteration(self, source: Relation, materialize_as: str | None) -> iteration.RowIterable: 

275 """Execute a relation transfer from the SQL engine to the iteration 

276 engine. 

277 

278 Parameters 

279 ---------- 

280 source : `~lsst.daf.relation.Relation` 

281 Input relation, in what is assumed to be `sql_engine`. 

282 materialize_as : `str` or `None` 

283 If not `None`, the name of a persistent materialization to apply 

284 in the iteration engine. This fetches all rows up front. 

285 

286 Returns 

287 ------- 

288 destination : `~lsst.daf.relation.iteration.RowIterable` 

289 Iteration engine payload iterable with the same content as the 

290 given SQL relation. 

291 """ 

292 sql_executable = self.sql_engine.to_executable(source) 

293 rows: iteration.RowIterable = _SqlRowIterable( 

294 self, sql_executable, _SqlRowTransformer(source.columns, self.sql_engine.column_types) 

295 ) 

296 if materialize_as is not None: 

297 rows = rows.materialized() 

298 return rows 

299 

300 def _iteration_to_sql(self, source: Relation, materialize_as: str | None) -> sql.Payload: 

301 """Execute a relation transfer from the iteration engine to the SQL 

302 engine. 

303 

304 Parameters 

305 ---------- 

306 source : `~lsst.daf.relation.Relation` 

307 Input relation, in what is assumed to be `iteration_engine`. 

308 materialize_as : `str` or `None` 

309 If not `None`, the name of a persistent materialization to apply 

310 in the SQL engine. This sets the name of the temporary table 

311 and ensures that one is used (instead of `Database.constant_rows`, 

312 which might otherwise be used for small row sets). 

313 

314 Returns 

315 ------- 

316 destination : `~lsst.daf.relation.sql.Payload` 

317 SQL engine payload struct with the same content as the given 

318 iteration relation. 

319 """ 

320 iterable = self.iteration_engine.execute(source) 

321 table_spec = self.column_types.make_relation_table_spec(source.columns) 

322 row_transformer = _SqlRowTransformer(source.columns, self.sql_engine.column_types) 

323 sql_rows = [] 

324 iterator = iter(iterable) 

325 # Iterate over the first chunk of rows and transform them to hold only 

326 # types recognized by SQLAlchemy (i.e. expand out 

327 # TimespanDatabaseRepresentations). Note that this advances 

328 # `iterator`. 

329 for relation_row in itertools.islice(iterator, self._row_chunk_size): 

330 sql_rows.append(row_transformer.relation_to_sql(relation_row)) 

331 name = materialize_as if materialize_as is not None else self.sql_engine.get_relation_name("upload") 

332 if materialize_as is not None or len(sql_rows) > self._db.get_constant_rows_max(): 

333 if self._exit_stack is None: 

334 raise RuntimeError("This operation requires the QueryContext to have been entered.") 

335 # Either we're being asked to insert into a temporary table with a 

336 # "persistent" lifetime (the duration of the QueryContext), or we 

337 # have enough rows that the database says we can't use its 

338 # "constant_rows" construct (e.g. VALUES). Note that 

339 # `QueryContext.__init__` guarantees that `self._row_chunk_size` is 

340 # greater than or equal to `Database.get_constant_rows_max`. 

341 table = self._exit_stack.enter_context(self._db.temporary_table(table_spec, name)) 

342 while sql_rows: 

343 self._db.insert(table, *sql_rows) 

344 sql_rows.clear() 

345 for relation_row in itertools.islice(iterator, self._row_chunk_size): 

346 sql_rows.append(row_transformer.relation_to_sql(relation_row)) 

347 payload = sql.Payload[LogicalColumn](table) 

348 else: 

349 # Small number of rows; use Database.constant_rows. 

350 payload = sql.Payload[LogicalColumn](self._db.constant_rows(table_spec.fields, *sql_rows)) 

351 payload.columns_available = self.sql_engine.extract_mapping( 

352 source.columns, payload.from_clause.columns 

353 ) 

354 return payload 

355 

356 def _strip_empty_invariant_operations( 

357 self, 

358 relation: Relation, 

359 exact: bool, 

360 ) -> Relation: 

361 """Return a modified relation tree that strips out relation operations 

362 that do not affect whether there are any rows from the end of the tree. 

363 

364 Parameters 

365 ---------- 

366 relation : `Relation` 

367 Original relation tree. 

368 exact : `bool` 

369 If `False` strip all iteration-engine operations, even those that 

370 can remove all rows. 

371 

372 Returns 

373 ------- 

374 modified : `Relation` 

375 Modified relation tree. 

376 """ 

377 if relation.payload is not None or relation.is_locked: 

378 return relation 

379 match relation: 

380 case UnaryOperationRelation(operation=operation, target=target): 

381 if operation.is_empty_invariant or (not exact and target.engine == self.iteration_engine): 

382 return self._strip_empty_invariant_operations(target, exact) 

383 case Transfer(target=target): 

384 return self._strip_empty_invariant_operations(target, exact) 

385 case MarkerRelation(target=target): 

386 return relation.reapply(self._strip_empty_invariant_operations(target, exact)) 

387 return relation 

388 

389 def _strip_count_invariant_operations( 

390 self, 

391 relation: Relation, 

392 exact: bool, 

393 ) -> Relation: 

394 """Return a modified relation tree that strips out relation operations 

395 that do not affect the number of rows from the end of the tree. 

396 

397 Parameters 

398 ---------- 

399 relation : `Relation` 

400 Original relation tree. 

401 exact : `bool` 

402 If `False` strip all iteration-engine operations, even those that 

403 can affect the number of rows. 

404 

405 Returns 

406 ------- 

407 modified : `Relation` 

408 Modified relation tree. 

409 """ 

410 if relation.payload is not None or relation.is_locked: 

411 return relation 

412 match relation: 

413 case UnaryOperationRelation(operation=operation, target=target): 

414 if operation.is_count_invariant or (not exact and target.engine == self.iteration_engine): 

415 return self._strip_count_invariant_operations(target, exact) 

416 case Transfer(target=target): 

417 return self._strip_count_invariant_operations(target, exact) 

418 case MarkerRelation(target=target): 

419 return relation.reapply(self._strip_count_invariant_operations(target, exact)) 

420 return relation 

421 

422 

class _SqlRowIterable(iteration.RowIterable):
    """A `lsst.daf.relation.iteration.RowIterable` backed by a SQL query.

    Parameters
    ----------
    context : `SqlQueryContext`
        Context to execute the query with. If this context has already been
        entered, a lazy iterable will be returned that may or may not use
        server side cursors, since the context's lifetime can be used to manage
        that cursor's lifetime. If the context has not been entered, all
        results will be fetched up front in raw form, but will still be
        processed into mappings keyed by `ColumnTag` lazily.
    sql_executable : `sqlalchemy.sql.expression.SelectBase`
        SQL query to execute; assumed to have been built by
        `lsst.daf.relation.sql.Engine.to_executable` or similar.
    row_transformer : `_SqlRowTransformer`
        Object that converts SQLAlchemy result-row mappings (with `str`
        keys and possibly-unpacked timespan values) to relation row
        mappings (with `ColumnTag` keys and `LogicalColumn` values).
    """

    def __init__(
        self,
        context: SqlQueryContext,
        sql_executable: sqlalchemy.sql.expression.SelectBase,
        row_transformer: _SqlRowTransformer,
    ):
        self._context = context
        self._sql_executable = sql_executable
        self._row_transformer = row_transformer

    def __iter__(self) -> Iterator[Mapping[ColumnTag, Any]]:
        # Hoist the per-row conversion to a local for the hot loop below.
        convert = self._row_transformer.sql_to_relation
        with self._context._db.query(self._sql_executable) as sql_result:
            raw_rows = sql_result.mappings()
            if self._context._exit_stack is None:
                # The context has not been entered, so nothing manages the
                # cursor's lifetime beyond this call: fetch everything now.
                raw_rows = raw_rows.fetchall()
            for sql_row in raw_rows:
                yield convert(sql_row)

462 

463 

464class _SqlRowTransformer: 

465 """Object that converts SQLAlchemy result-row mappings to relation row 

466 mappings. 

467 

468 Parameters 

469 ---------- 

470 columns : `Iterable` [ `ColumnTag` ] 

471 Set of columns to handle. Rows must have at least these columns, but 

472 may have more. 

473 column_types : `ColumnTypeInfo` 

474 Information about column types that can vary with registry 

475 configuration. 

476 """ 

477 

478 def __init__(self, columns: Iterable[ColumnTag], column_types: ColumnTypeInfo): 

479 self._scalar_columns = set(columns) 

480 self._timespan_columns: dict[ColumnTag, type[TimespanDatabaseRepresentation]] = {} 

481 self._timespan_columns.update( 

482 {tag: column_types.timespan_cls for tag in columns if is_timespan_column(tag)} 

483 ) 

484 self._scalar_columns.difference_update(self._timespan_columns.keys()) 

485 self._has_no_columns = not columns 

486 

487 __slots__ = ("_scalar_columns", "_timespan_columns", "_has_no_columns") 

488 

489 def sql_to_relation(self, sql_row: Mapping[str, Any]) -> dict[ColumnTag, Any]: 

490 """Convert a result row from a SQLAlchemy result into the form expected 

491 by `lsst.daf.relation.iteration`. 

492 

493 Parameters 

494 ---------- 

495 sql_row : `Mapping` 

496 Mapping with `str` keys and possibly-unpacked values for timespan 

497 columns. 

498 

499 Returns 

500 ------- 

501 relation_row : `Mapping` 

502 Mapping with `ColumnTag` keys and `Timespan` objects for timespan 

503 columns. 

504 """ 

505 relation_row = {tag: sql_row[tag.qualified_name] for tag in self._scalar_columns} 

506 for tag, db_repr_cls in self._timespan_columns.items(): 

507 relation_row[tag] = db_repr_cls.extract(sql_row, name=tag.qualified_name) 

508 return relation_row 

509 

510 def relation_to_sql(self, relation_row: Mapping[ColumnTag, Any]) -> dict[str, Any]: 

511 """Convert a `lsst.daf.relation.iteration` row into the mapping form 

512 used by SQLAlchemy. 

513 

514 Parameters 

515 ---------- 

516 relation_row : `Mapping` 

517 Mapping with `ColumnTag` keys and `Timespan` objects for timespan 

518 columns. 

519 

520 Returns 

521 ------- 

522 sql_row : `Mapping` 

523 Mapping with `str` keys and possibly-unpacked values for timespan 

524 columns. 

525 """ 

526 sql_row = {tag.qualified_name: relation_row[tag] for tag in self._scalar_columns} 

527 for tag, db_repr_cls in self._timespan_columns.items(): 

528 db_repr_cls.update(relation_row[tag], result=sql_row, name=tag.qualified_name) 

529 if self._has_no_columns: 

530 sql_row["IGNORED"] = None 

531 return sql_row