Coverage for python/lsst/daf/butler/registry/queries/_sql_query_context.py: 13% of 238 statements

# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

__all__ = ("SqlQueryContext",)

import itertools
from collections.abc import Iterable, Iterator, Mapping, Set
from contextlib import ExitStack
from typing import TYPE_CHECKING, Any, cast

import sqlalchemy
from lsst.daf.relation import (
    BinaryOperationRelation,
    Calculation,
    Chain,
    ColumnTag,
    Deduplication,
    Engine,
    EngineError,
    Join,
    MarkerRelation,
    Projection,
    Relation,
    Transfer,
    UnaryOperation,
    UnaryOperationRelation,
    iteration,
    sql,
)

from ..._column_tags import is_timespan_column
from ..._column_type_info import ColumnTypeInfo, LogicalColumn
from ...name_shrinker import NameShrinker
from ...timespan_database_representation import TimespanDatabaseRepresentation
from ._query_context import QueryContext
from .butler_sql_engine import ButlerSqlEngine

if TYPE_CHECKING:
    from ..interfaces import Database


class SqlQueryContext(QueryContext):
    """An implementation of `sql.QueryContext` for `SqlRegistry`.

    Parameters
    ----------
    db : `Database`
        Object that abstracts the database engine.
    column_types : `ColumnTypeInfo`
        Information about column types that can vary with registry
        configuration.
    row_chunk_size : `int`, optional
        Number of rows to insert into temporary tables at once.  If this is
        lower than ``db.get_constant_rows_max()`` it will be set to that
        value.
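
    Examples
    --------
    A minimal usage sketch (``db``, ``column_types``, and ``relation`` are
    placeholders here; real instances normally come from `SqlRegistry` and
    its query machinery)::

        context = SqlQueryContext(db, column_types)
        with context:
            n_rows = context.count(relation, exact=True, discard=True)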

    """

    def __init__(
        self,
        db: Database,
        column_types: ColumnTypeInfo,
        row_chunk_size: int = 1000,
    ):
        super().__init__()
        self.sql_engine = ButlerSqlEngine(
            column_types=column_types, name_shrinker=NameShrinker(db.dialect.max_identifier_length)
        )
        self._row_chunk_size = max(row_chunk_size, db.get_constant_rows_max())
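        # The chunk size is raised to at least db.get_constant_rows_max();
        # _iteration_to_sql relies on this guarantee when choosing between
        # Database.constant_rows and a temporary table after reading a
        # single chunk of rows.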

        self._db = db
        self._exit_stack: ExitStack | None = None

    def __enter__(self) -> SqlQueryContext:
        assert self._exit_stack is None, "Context manager already entered."
        self._exit_stack = ExitStack().__enter__()
        self._exit_stack.enter_context(self._db.session())
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> bool:
        assert self._exit_stack is not None, "Context manager not yet entered."
        result = self._exit_stack.__exit__(exc_type, exc_value, traceback)
        self._exit_stack = None
        return result

    @property
    def is_open(self) -> bool:
        return self._exit_stack is not None

    @property
    def column_types(self) -> ColumnTypeInfo:
        """Information about column types that depend on registry
        configuration (`ColumnTypeInfo`).
        """
        return self.sql_engine.column_types

    @property
    def preferred_engine(self) -> Engine:
        # Docstring inherited.
        return self.sql_engine

    def count(self, relation: Relation, *, exact: bool = True, discard: bool = False) -> int:
        # Docstring inherited.
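        # A count never needs the result columns themselves, so we drop them
        # all and strip any operations that cannot change the row count.
        # Then: run a SQL COUNT(*) if the tree is fully in the SQL engine,
        # use the length of an already-materialized payload, or (only with
        # discard=True) execute the query and count rows in Python.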

        relation = self._strip_count_invariant_operations(relation, exact).with_only_columns(frozenset())
        if relation.engine == self.sql_engine:
            sql_executable = self.sql_engine.to_executable(
                relation, extra_columns=[sqlalchemy.sql.func.count()]
            )
            with self._db.query(sql_executable) as sql_result:
                return cast(int, sql_result.scalar_one())
        elif (rows := relation.payload) is not None:
            assert isinstance(
                rows, iteration.MaterializedRowIterable
            ), "Query guarantees that only materialized payloads are attached to its relations."
            return len(rows)
        elif discard:
            n = 0
            for _ in self.fetch_iterable(relation):
                n += 1
            return n
        else:
            raise RuntimeError(
                f"Query with relation {relation} has deferred operations that "
                "must be executed in Python in order to obtain an exact "
                "count. Pass discard=True to run the query and discard the "
                "result rows while counting them, run the query first, or "
                "pass exact=False to obtain an upper bound."
            )

    def any(self, relation: Relation, *, execute: bool = True, exact: bool = True) -> bool:
        # Docstring inherited.
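        # The [:1] slice below applies a one-row limit (typically a SQL
        # LIMIT 1 when the tree runs in the SQL engine), so existence checks
        # never fetch more than one row.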

        relation = self._strip_count_invariant_operations(relation, exact).with_only_columns(frozenset())[:1]
        if relation.engine == self.sql_engine:
            sql_executable = self.sql_engine.to_executable(relation)
            with self._db.query(sql_executable) as sql_result:
                return sql_result.one_or_none() is not None
        elif (rows := relation.payload) is not None:
            assert isinstance(
                rows, iteration.MaterializedRowIterable
            ), "Query guarantees that only materialized payloads are attached to its relations."
            return bool(rows)
        elif execute:
            for _ in self.fetch_iterable(relation):
                return True
            return False
        else:
            raise RuntimeError(
                f"Query with relation {relation} has deferred operations that "
                "must be executed in Python in order to obtain an exact "
                "check for whether any rows would be returned. Pass "
                "execute=True to run the query until at least "
                "one row is found, run the query first, or pass exact=False "
                "to test only whether the query _might_ have result rows."
            )

    def transfer(self, source: Relation, destination: Engine, materialize_as: str | None) -> Any:
        # Docstring inherited from lsst.daf.relation.Processor.
        if source.engine == self.sql_engine and destination == self.iteration_engine:
            return self._sql_to_iteration(source, materialize_as)
        if source.engine == self.iteration_engine and destination == self.sql_engine:
            return self._iteration_to_sql(source, materialize_as)
        raise EngineError(f"Unexpected transfer for SqlQueryContext: {source.engine} -> {destination}.")

    def materialize(self, target: Relation, name: str) -> Any:
        # Docstring inherited from lsst.daf.relation.Processor.
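        # In the SQL engine, materialization executes the query into a
        # temporary table whose lifetime is tied to this context's exit
        # stack, so the context must have been entered first.  Other engines
        # fall back to the base-class implementation.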

        if target.engine == self.sql_engine:
            sql_executable = self.sql_engine.to_executable(target)
            table_spec = self.column_types.make_relation_table_spec(target.columns)
            if self._exit_stack is None:
                raise RuntimeError("This operation requires the QueryContext to have been entered.")
            table = self._exit_stack.enter_context(self._db.temporary_table(table_spec, name=name))
            self._db.insert(table, select=sql_executable)
            payload = sql.Payload[LogicalColumn](table)
            payload.columns_available = self.sql_engine.extract_mapping(target.columns, table.columns)
            return payload
        return super().materialize(target, name)

    def restore_columns(
        self,
        relation: Relation,
        columns_required: Set[ColumnTag],
    ) -> tuple[Relation, set[ColumnTag]]:
        # Docstring inherited.
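        # Walk down the relation tree, re-adding requested columns that
        # upstream operations (chiefly Projections) dropped.  Each branch
        # returns both the rewritten relation and the subset of
        # columns_required that it was actually able to restore.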

        if relation.is_locked:
            return relation, set(relation.columns & columns_required)
        match relation:
            case UnaryOperationRelation(operation=operation, target=target):
                new_target, columns_found = self.restore_columns(target, columns_required)
                if not columns_found:
                    return relation, columns_found
                match operation:
                    case Projection(columns=columns):
                        new_columns = columns | columns_found
                        if new_columns != columns:
                            return (
                                Projection(new_columns).apply(new_target),
                                columns_found,
                            )
                    case Calculation(tag=tag):
                        if tag in columns_required:
                            columns_found.add(tag)
                    case Deduplication():
                        # Pulling key columns through the deduplication would
                        # fundamentally change what it does; have to limit
                        # ourselves to non-key columns here, and put a
                        # Projection back in place.
                        columns_found = {c for c in columns_found if not c.is_key}
                        new_target = new_target.with_only_columns(relation.columns | columns_found)
                return relation.reapply(new_target), columns_found
            case BinaryOperationRelation(operation=operation, lhs=lhs, rhs=rhs):
                new_lhs, columns_found_lhs = self.restore_columns(lhs, columns_required)
                new_rhs, columns_found_rhs = self.restore_columns(rhs, columns_required)
                match operation:
                    case Join():
                        return relation.reapply(new_lhs, new_rhs), columns_found_lhs | columns_found_rhs
                    case Chain():
                        if columns_found_lhs != columns_found_rhs:
                            # Got different answers for the two chain
                            # branches; try again with just the columns found
                            # in both branches.
                            new_columns_required = columns_found_lhs & columns_found_rhs
                            if not new_columns_required:
                                return relation, set()
                            new_lhs, columns_found_lhs = self.restore_columns(lhs, new_columns_required)
                            new_rhs, columns_found_rhs = self.restore_columns(rhs, new_columns_required)
                            assert columns_found_lhs == columns_found_rhs
                        return relation.reapply(new_lhs, new_rhs), columns_found_lhs
            case MarkerRelation(target=target):
                new_target, columns_found = self.restore_columns(target, columns_required)
                return relation.reapply(new_target), columns_found
        raise AssertionError("Match should be exhaustive and all branches should return.")

    def strip_postprocessing(self, relation: Relation) -> tuple[Relation, list[UnaryOperation]]:
        # Docstring inherited.
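        # "Postprocessing" here means iteration-engine operations layered on
        # top of the transfer out of the SQL engine; peel them off and return
        # them in application order (innermost first).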

        if relation.engine != self.iteration_engine or relation.is_locked:
            return relation, []
        match relation:
            case UnaryOperationRelation(operation=operation, target=target):
                new_target, stripped = self.strip_postprocessing(target)
                stripped.append(operation)
                return new_target, stripped
            case Transfer(destination=self.iteration_engine, target=target):
                return target, []
        return relation, []

    def drop_invalidated_postprocessing(self, relation: Relation, new_columns: Set[ColumnTag]) -> Relation:
        # Docstring inherited.
        if relation.engine != self.iteration_engine or relation.is_locked:
            return relation
        match relation:
            case UnaryOperationRelation(operation=operation, target=target):
                columns_added_here = relation.columns - target.columns
                new_target = self.drop_invalidated_postprocessing(target, new_columns - columns_added_here)
                if operation.columns_required <= new_columns:
                    # This operation can stay...
                    if new_target is target:
                        # ...and nothing has actually changed upstream of it.
                        return relation
                    else:
                        # ...but we should see if it can now be performed in
                        # the preferred engine, since something it didn't
                        # commute with may have been removed.
                        return operation.apply(new_target, preferred_engine=self.preferred_engine)
                else:
                    # This operation needs to be dropped, as we're about to
                    # project away one or more of the columns it needs.
                    return new_target
        return relation

    def _sql_to_iteration(self, source: Relation, materialize_as: str | None) -> iteration.RowIterable:
        """Execute a relation transfer from the SQL engine to the iteration
        engine.

        Parameters
        ----------
        source : `~lsst.daf.relation.Relation`
            Input relation, in what is assumed to be `sql_engine`.
        materialize_as : `str` or `None`
            If not `None`, the name of a persistent materialization to apply
            in the iteration engine.  This fetches all rows up front.

        Returns
        -------
        destination : `~lsst.daf.relation.iteration.RowIterable`
            Iteration engine payload iterable with the same content as the
            given SQL relation.
        """
        sql_executable = self.sql_engine.to_executable(source)
        rows: iteration.RowIterable = _SqlRowIterable(
            self, sql_executable, _SqlRowTransformer(source.columns, self.sql_engine)
        )
        if materialize_as is not None:
            rows = rows.materialized()
        return rows

    def _iteration_to_sql(self, source: Relation, materialize_as: str | None) -> sql.Payload:
        """Execute a relation transfer from the iteration engine to the SQL
        engine.

        Parameters
        ----------
        source : `~lsst.daf.relation.Relation`
            Input relation, in what is assumed to be `iteration_engine`.
        materialize_as : `str` or `None`
            If not `None`, the name of a persistent materialization to apply
            in the SQL engine.  This sets the name of the temporary table
            and ensures that one is used (instead of `Database.constant_rows`,
            which might otherwise be used for small row sets).

        Returns
        -------
        destination : `~lsst.daf.relation.sql.Payload`
            SQL engine payload struct with the same content as the given
            iteration relation.
        """
        iterable = self.iteration_engine.execute(source)
        table_spec = self.column_types.make_relation_table_spec(source.columns)
        row_transformer = _SqlRowTransformer(source.columns, self.sql_engine)
        sql_rows = []
        iterator = iter(iterable)
        # Iterate over the first chunk of rows and transform them to hold only
        # types recognized by SQLAlchemy (i.e. expand out
        # TimespanDatabaseRepresentations).  Note that this advances
        # `iterator`.
        for relation_row in itertools.islice(iterator, self._row_chunk_size):
            sql_rows.append(row_transformer.relation_to_sql(relation_row))
        name = materialize_as if materialize_as is not None else self.sql_engine.get_relation_name("upload")
        if materialize_as is not None or len(sql_rows) > self._db.get_constant_rows_max():
            if self._exit_stack is None:
                raise RuntimeError("This operation requires the QueryContext to have been entered.")
            # Either we're being asked to insert into a temporary table with a
            # "persistent" lifetime (the duration of the QueryContext), or we
            # have enough rows that the database says we can't use its
            # "constant_rows" construct (e.g. VALUES).  Note that
            # `SqlQueryContext.__init__` guarantees that `self._row_chunk_size`
            # is greater than or equal to `Database.get_constant_rows_max`.
            table = self._exit_stack.enter_context(self._db.temporary_table(table_spec, name))
            while sql_rows:
                self._db.insert(table, *sql_rows)
                sql_rows.clear()
                for relation_row in itertools.islice(iterator, self._row_chunk_size):
                    sql_rows.append(row_transformer.relation_to_sql(relation_row))
            payload = sql.Payload[LogicalColumn](table)
        else:
            # Small number of rows; use Database.constant_rows.
            payload = sql.Payload[LogicalColumn](self._db.constant_rows(table_spec.fields, *sql_rows))
        payload.columns_available = self.sql_engine.extract_mapping(
            source.columns, payload.from_clause.columns
        )
        return payload

    def _strip_empty_invariant_operations(
        self,
        relation: Relation,
        exact: bool,
    ) -> Relation:
        """Return a modified relation tree with operations that cannot affect
        whether the relation has any rows stripped from its end (e.g. a
        `Calculation` or `Deduplication`, neither of which can empty a
        non-empty relation).

        Parameters
        ----------
        relation : `Relation`
            Original relation tree.
        exact : `bool`
            If `False`, strip all iteration-engine operations, even those that
            can remove all rows.

        Returns
        -------
        modified : `Relation`
            Modified relation tree.
        """
        if relation.payload is not None or relation.is_locked:
            return relation
        match relation:
            case UnaryOperationRelation(operation=operation, target=target):
                if operation.is_empty_invariant or (not exact and target.engine == self.iteration_engine):
                    return self._strip_empty_invariant_operations(target, exact)
            case Transfer(target=target):
                return self._strip_empty_invariant_operations(target, exact)
            case MarkerRelation(target=target):
                return relation.reapply(self._strip_empty_invariant_operations(target, exact))
        return relation

    def _strip_count_invariant_operations(
        self,
        relation: Relation,
        exact: bool,
    ) -> Relation:
        """Return a modified relation tree with operations that cannot affect
        its number of rows stripped from its end (e.g. a `Projection` or
        `Calculation`; a `Deduplication`, by contrast, can change the row
        count).

        Parameters
        ----------
        relation : `Relation`
            Original relation tree.
        exact : `bool`
            If `False`, strip all iteration-engine operations, even those that
            can affect the number of rows.

        Returns
        -------
        modified : `Relation`
            Modified relation tree.
        """
        if relation.payload is not None or relation.is_locked:
            return relation
        match relation:
            case UnaryOperationRelation(operation=operation, target=target):
                if operation.is_count_invariant or (not exact and target.engine == self.iteration_engine):
                    return self._strip_count_invariant_operations(target, exact)
            case Transfer(target=target):
                return self._strip_count_invariant_operations(target, exact)
            case MarkerRelation(target=target):
                return relation.reapply(self._strip_count_invariant_operations(target, exact))
        return relation


class _SqlRowIterable(iteration.RowIterable):
    """An implementation of `lsst.daf.relation.iteration.RowIterable` that
    executes a SQL query.

    Parameters
    ----------
    context : `SqlQueryContext`
        Context to execute the query with.  If this context has already been
        entered, a lazy iterable will be returned that may or may not use
        server-side cursors, since the context's lifetime can be used to
        manage that cursor's lifetime.  If the context has not been entered,
        all results will be fetched up front in raw form, but will still be
        processed into mappings keyed by `ColumnTag` lazily.
    sql_executable : `sqlalchemy.sql.expression.SelectBase`
        SQL query to execute; assumed to have been built by
        `lsst.daf.relation.sql.Engine.to_executable` or similar.
    row_transformer : `_SqlRowTransformer`
        Object that converts SQLAlchemy result-row mappings (with `str`
        keys and possibly-unpacked timespan values) to relation row
        mappings (with `ColumnTag` keys and `LogicalColumn` values).
    """

    def __init__(
        self,
        context: SqlQueryContext,
        sql_executable: sqlalchemy.sql.expression.SelectBase,
        row_transformer: _SqlRowTransformer,
    ):
        self._context = context
        self._sql_executable = sql_executable
        self._row_transformer = row_transformer

    def __iter__(self) -> Iterator[Mapping[ColumnTag, Any]]:
        if self._context._exit_stack is None:
            # Have to read all results into memory and close the database
            # connection before yielding anything.
            with self._context._db.query(self._sql_executable) as sql_result:
                rows = sql_result.mappings().fetchall()
            for sql_row in rows:
                yield self._row_transformer.sql_to_relation(sql_row)
        else:
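            # The context is open, so rows can be streamed lazily; the
            # context's exit stack keeps the underlying connection (and any
            # cursor) alive while we iterate.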

            with self._context._db.query(self._sql_executable) as sql_result:
                raw_rows = sql_result.mappings()
                for sql_row in raw_rows:
                    yield self._row_transformer.sql_to_relation(sql_row)


class _SqlRowTransformer:
    """Object that converts SQLAlchemy result-row mappings to relation row
    mappings.

    Parameters
    ----------
    columns : `~collections.abc.Iterable` [ `ColumnTag` ]
        Set of columns to handle.  Rows must have at least these columns, but
        may have more.
    engine : `ButlerSqlEngine`
        Relation engine; used to transform column tags into SQL identifiers
        and obtain column type information.
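
    Examples
    --------
    A schematic round trip (``columns``, ``engine``, ``tag``, and ``value``
    are placeholders here; timespan-typed tags expand to multiple SQL
    fields)::

        transformer = _SqlRowTransformer(columns, engine)
        sql_row = transformer.relation_to_sql({tag: value})
        relation_row = transformer.sql_to_relation(sql_row)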

    """

    def __init__(self, columns: Iterable[ColumnTag], engine: ButlerSqlEngine):
        self._scalar_columns: list[tuple[str, ColumnTag]] = []
        self._timespan_columns: list[tuple[str, ColumnTag, type[TimespanDatabaseRepresentation]]] = []
        for tag in columns:
            if is_timespan_column(tag):
                self._timespan_columns.append(
                    (engine.get_identifier(tag), tag, engine.column_types.timespan_cls)
                )
            else:
                self._scalar_columns.append((engine.get_identifier(tag), tag))
        self._has_no_columns = not (self._scalar_columns or self._timespan_columns)

    __slots__ = ("_scalar_columns", "_timespan_columns", "_has_no_columns")

    def sql_to_relation(self, sql_row: sqlalchemy.RowMapping) -> dict[ColumnTag, Any]:
        """Convert a result row from a SQLAlchemy result into the form
        expected by `lsst.daf.relation.iteration`.

        Parameters
        ----------
        sql_row : `~collections.abc.Mapping`
            Mapping with `str` keys and possibly-unpacked values for timespan
            columns.

        Returns
        -------
        relation_row : `~collections.abc.Mapping`
            Mapping with `ColumnTag` keys and `Timespan` objects for timespan
            columns.
        """
        relation_row = {tag: sql_row[identifier] for identifier, tag in self._scalar_columns}
        for identifier, tag, db_repr_cls in self._timespan_columns:
            relation_row[tag] = db_repr_cls.extract(sql_row, name=identifier)
        return relation_row

    def relation_to_sql(self, relation_row: Mapping[ColumnTag, Any]) -> dict[str, Any]:
        """Convert a `lsst.daf.relation.iteration` row into the mapping form
        used by SQLAlchemy.

        Parameters
        ----------
        relation_row : `~collections.abc.Mapping`
            Mapping with `ColumnTag` keys and `Timespan` objects for timespan
            columns.

        Returns
        -------
        sql_row : `~collections.abc.Mapping`
            Mapping with `str` keys and possibly-unpacked values for timespan
            columns.
        """
        sql_row = {identifier: relation_row[tag] for identifier, tag in self._scalar_columns}
        for identifier, tag, db_repr_cls in self._timespan_columns:
            db_repr_cls.update(relation_row[tag], result=sql_row, name=identifier)
        if self._has_no_columns:
            sql_row["IGNORED"] = None
        return sql_row