Coverage for python/lsst/daf/butler/registry/queries/_sql_query_context.py: 13%

236 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-10-02 08:00 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27from __future__ import annotations 

28 

29__all__ = ("SqlQueryContext",) 

30 

31import itertools 

32from collections.abc import Iterable, Iterator, Mapping, Set 

33from contextlib import ExitStack 

34from typing import TYPE_CHECKING, Any, cast 

35 

36import sqlalchemy 

37from lsst.daf.relation import ( 

38 BinaryOperationRelation, 

39 Calculation, 

40 Chain, 

41 ColumnTag, 

42 Deduplication, 

43 Engine, 

44 EngineError, 

45 Join, 

46 MarkerRelation, 

47 Projection, 

48 Relation, 

49 Transfer, 

50 UnaryOperation, 

51 UnaryOperationRelation, 

52 iteration, 

53 sql, 

54) 

55 

56from ...core import ColumnTypeInfo, LogicalColumn, TimespanDatabaseRepresentation, is_timespan_column 

57from ..nameShrinker import NameShrinker 

58from ._query_context import QueryContext 

59from .butler_sql_engine import ButlerSqlEngine 

60 

61if TYPE_CHECKING: 

62 from ..interfaces import Database 

63 

64 

class SqlQueryContext(QueryContext):
    """An implementation of `sql.QueryContext` for `SqlRegistry`.

    Parameters
    ----------
    db : `Database`
        Object that abstracts the database engine.
    column_types : `ColumnTypeInfo`
        Information about column types that can vary with registry
        configuration.
    row_chunk_size : `int`, optional
        Number of rows to insert into temporary tables at once. If this is
        lower than ``db.get_constant_rows_max()`` it will be set to that value.
    """

    def __init__(
        self,
        db: Database,
        column_types: ColumnTypeInfo,
        row_chunk_size: int = 1000,
    ):
        super().__init__()
        self.sql_engine = ButlerSqlEngine(
            column_types=column_types, name_shrinker=NameShrinker(db.dialect.max_identifier_length)
        )
        # Guarantee chunks are at least as large as the biggest row set the
        # database can inline via its "constant rows" construct; the transfer
        # logic in `_iteration_to_sql` relies on this invariant.
        self._row_chunk_size = max(row_chunk_size, db.get_constant_rows_max())
        self._db = db
        # Non-None only while the context manager is entered; owns the
        # database session and any temporary tables created for queries.
        self._exit_stack: ExitStack | None = None

93 

94 def __enter__(self) -> SqlQueryContext: 

95 assert self._exit_stack is None, "Context manager already entered." 

96 self._exit_stack = ExitStack().__enter__() 

97 self._exit_stack.enter_context(self._db.session()) 

98 return self 

99 

100 def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> bool: 

101 assert self._exit_stack is not None, "Context manager not yet entered." 

102 result = self._exit_stack.__exit__(exc_type, exc_value, traceback) 

103 self._exit_stack = None 

104 return result 

105 

106 @property 

107 def is_open(self) -> bool: 

108 return self._exit_stack is not None 

109 

110 @property 

111 def column_types(self) -> ColumnTypeInfo: 

112 """Information about column types that depend on registry configuration 

113 (`ColumnTypeInfo`). 

114 """ 

115 return self.sql_engine.column_types 

116 

117 @property 

118 def preferred_engine(self) -> Engine: 

119 # Docstring inherited. 

120 return self.sql_engine 

121 

122 def count(self, relation: Relation, *, exact: bool = True, discard: bool = False) -> int: 

123 # Docstring inherited. 

124 relation = self._strip_count_invariant_operations(relation, exact).with_only_columns(frozenset()) 

125 if relation.engine == self.sql_engine: 

126 sql_executable = self.sql_engine.to_executable( 

127 relation, extra_columns=[sqlalchemy.sql.func.count()] 

128 ) 

129 with self._db.query(sql_executable) as sql_result: 

130 return cast(int, sql_result.scalar_one()) 

131 elif (rows := relation.payload) is not None: 

132 assert isinstance( 

133 rows, iteration.MaterializedRowIterable 

134 ), "Query guarantees that only materialized payloads are attached to its relations." 

135 return len(rows) 

136 elif discard: 

137 n = 0 

138 for _ in self.fetch_iterable(relation): 

139 n += 1 

140 return n 

141 else: 

142 raise RuntimeError( 

143 f"Query with relation {relation} has deferred operations that " 

144 "must be executed in Python in order to obtain an exact " 

145 "count. Pass discard=True to run the query and discard the " 

146 "result rows while counting them, run the query first, or " 

147 "pass exact=False to obtain an upper bound." 

148 ) 

149 

150 def any(self, relation: Relation, *, execute: bool = True, exact: bool = True) -> bool: 

151 # Docstring inherited. 

152 relation = self._strip_count_invariant_operations(relation, exact).with_only_columns(frozenset())[:1] 

153 if relation.engine == self.sql_engine: 

154 sql_executable = self.sql_engine.to_executable(relation) 

155 with self._db.query(sql_executable) as sql_result: 

156 return sql_result.one_or_none() is not None 

157 elif (rows := relation.payload) is not None: 

158 assert isinstance( 

159 rows, iteration.MaterializedRowIterable 

160 ), "Query guarantees that only materialized payloads are attached to its relations." 

161 return bool(rows) 

162 elif execute: 

163 for _ in self.fetch_iterable(relation): 

164 return True 

165 return False 

166 else: 

167 raise RuntimeError( 

168 f"Query with relation {relation} has deferred operations that " 

169 "must be executed in Python in order to obtain an exact " 

170 "check for whether any rows would be returned. Pass " 

171 "execute=True to run the query until at least " 

172 "one row is found, run the query first, or pass exact=False " 

173 "test only whether the query _might_ have result rows." 

174 ) 

175 

176 def transfer(self, source: Relation, destination: Engine, materialize_as: str | None) -> Any: 

177 # Docstring inherited from lsst.daf.relation.Processor. 

178 if source.engine == self.sql_engine and destination == self.iteration_engine: 

179 return self._sql_to_iteration(source, materialize_as) 

180 if source.engine == self.iteration_engine and destination == self.sql_engine: 

181 return self._iteration_to_sql(source, materialize_as) 

182 raise EngineError(f"Unexpected transfer for SqlQueryContext: {source.engine} -> {destination}.") 

183 

184 def materialize(self, target: Relation, name: str) -> Any: 

185 # Docstring inherited from lsst.daf.relation.Processor. 

186 if target.engine == self.sql_engine: 

187 sql_executable = self.sql_engine.to_executable(target) 

188 table_spec = self.column_types.make_relation_table_spec(target.columns) 

189 if self._exit_stack is None: 

190 raise RuntimeError("This operation requires the QueryContext to have been entered.") 

191 table = self._exit_stack.enter_context(self._db.temporary_table(table_spec, name=name)) 

192 self._db.insert(table, select=sql_executable) 

193 payload = sql.Payload[LogicalColumn](table) 

194 payload.columns_available = self.sql_engine.extract_mapping(target.columns, table.columns) 

195 return payload 

196 return super().materialize(target, name) 

197 

198 def restore_columns( 

199 self, 

200 relation: Relation, 

201 columns_required: Set[ColumnTag], 

202 ) -> tuple[Relation, set[ColumnTag]]: 

203 # Docstring inherited. 

204 if relation.is_locked: 

205 return relation, set(relation.columns & columns_required) 

206 match relation: 

207 case UnaryOperationRelation(operation=operation, target=target): 

208 new_target, columns_found = self.restore_columns(target, columns_required) 

209 if not columns_found: 

210 return relation, columns_found 

211 match operation: 

212 case Projection(columns=columns): 

213 new_columns = columns | columns_found 

214 if new_columns != columns: 

215 return ( 

216 Projection(new_columns).apply(new_target), 

217 columns_found, 

218 ) 

219 case Calculation(tag=tag): 

220 if tag in columns_required: 

221 columns_found.add(tag) 

222 case Deduplication(): 

223 # Pulling key columns through the deduplication would 

224 # fundamentally change what it does; have to limit 

225 # ourselves to non-key columns here, and put a 

226 # Projection back in place. 

227 columns_found = {c for c in columns_found if not c.is_key} 

228 new_target = new_target.with_only_columns(relation.columns | columns_found) 

229 return relation.reapply(new_target), columns_found 

230 case BinaryOperationRelation(operation=operation, lhs=lhs, rhs=rhs): 

231 new_lhs, columns_found_lhs = self.restore_columns(lhs, columns_required) 

232 new_rhs, columns_found_rhs = self.restore_columns(rhs, columns_required) 

233 match operation: 

234 case Join(): 

235 return relation.reapply(new_lhs, new_rhs), columns_found_lhs | columns_found_rhs 

236 case Chain(): 

237 if columns_found_lhs != columns_found_rhs: 

238 # Got different answers for different join 

239 # branches; let's try again with just columns found 

240 # in both branches. 

241 new_columns_required = columns_found_lhs & columns_found_rhs 

242 if not new_columns_required: 

243 return relation, set() 

244 new_lhs, columns_found_lhs = self.restore_columns(lhs, new_columns_required) 

245 new_rhs, columns_found_rhs = self.restore_columns(rhs, new_columns_required) 

246 assert columns_found_lhs == columns_found_rhs 

247 return relation.reapply(new_lhs, new_rhs), columns_found_lhs 

248 case MarkerRelation(target=target): 

249 new_target, columns_found = self.restore_columns(target, columns_required) 

250 return relation.reapply(new_target), columns_found 

251 raise AssertionError("Match should be exhaustive and all branches should return.") 

252 

253 def strip_postprocessing(self, relation: Relation) -> tuple[Relation, list[UnaryOperation]]: 

254 # Docstring inherited. 

255 if relation.engine != self.iteration_engine or relation.is_locked: 

256 return relation, [] 

257 match relation: 

258 case UnaryOperationRelation(operation=operation, target=target): 

259 new_target, stripped = self.strip_postprocessing(target) 

260 stripped.append(operation) 

261 return new_target, stripped 

262 case Transfer(destination=self.iteration_engine, target=target): 

263 return target, [] 

264 return relation, [] 

265 

266 def drop_invalidated_postprocessing(self, relation: Relation, new_columns: Set[ColumnTag]) -> Relation: 

267 # Docstring inherited. 

268 if relation.engine != self.iteration_engine or relation.is_locked: 

269 return relation 

270 match relation: 

271 case UnaryOperationRelation(operation=operation, target=target): 

272 columns_added_here = relation.columns - target.columns 

273 new_target = self.drop_invalidated_postprocessing(target, new_columns - columns_added_here) 

274 if operation.columns_required <= new_columns: 

275 # This operation can stay... 

276 if new_target is target: 

277 # ...and nothing has actually changed upstream of it. 

278 return relation 

279 else: 

280 # ...but we should see if it can now be performed in 

281 # the preferred engine, since something it didn't 

282 # commute with may have been removed. 

283 return operation.apply(new_target, preferred_engine=self.preferred_engine) 

284 else: 

285 # This operation needs to be dropped, as we're about to 

286 # project away one more more of the columns it needs. 

287 return new_target 

288 return relation 

289 

290 def _sql_to_iteration(self, source: Relation, materialize_as: str | None) -> iteration.RowIterable: 

291 """Execute a relation transfer from the SQL engine to the iteration 

292 engine. 

293 

294 Parameters 

295 ---------- 

296 source : `~lsst.daf.relation.Relation` 

297 Input relation, in what is assumed to be `sql_engine`. 

298 materialize_as : `str` or `None` 

299 If not `None`, the name of a persistent materialization to apply 

300 in the iteration engine. This fetches all rows up front. 

301 

302 Returns 

303 ------- 

304 destination : `~lsst.daf.relation.iteration.RowIterable` 

305 Iteration engine payload iterable with the same content as the 

306 given SQL relation. 

307 """ 

308 sql_executable = self.sql_engine.to_executable(source) 

309 rows: iteration.RowIterable = _SqlRowIterable( 

310 self, sql_executable, _SqlRowTransformer(source.columns, self.sql_engine) 

311 ) 

312 if materialize_as is not None: 

313 rows = rows.materialized() 

314 return rows 

315 

316 def _iteration_to_sql(self, source: Relation, materialize_as: str | None) -> sql.Payload: 

317 """Execute a relation transfer from the iteration engine to the SQL 

318 engine. 

319 

320 Parameters 

321 ---------- 

322 source : `~lsst.daf.relation.Relation` 

323 Input relation, in what is assumed to be `iteration_engine`. 

324 materialize_as : `str` or `None` 

325 If not `None`, the name of a persistent materialization to apply 

326 in the SQL engine. This sets the name of the temporary table 

327 and ensures that one is used (instead of `Database.constant_rows`, 

328 which might otherwise be used for small row sets). 

329 

330 Returns 

331 ------- 

332 destination : `~lsst.daf.relation.sql.Payload` 

333 SQL engine payload struct with the same content as the given 

334 iteration relation. 

335 """ 

336 iterable = self.iteration_engine.execute(source) 

337 table_spec = self.column_types.make_relation_table_spec(source.columns) 

338 row_transformer = _SqlRowTransformer(source.columns, self.sql_engine) 

339 sql_rows = [] 

340 iterator = iter(iterable) 

341 # Iterate over the first chunk of rows and transform them to hold only 

342 # types recognized by SQLAlchemy (i.e. expand out 

343 # TimespanDatabaseRepresentations). Note that this advances 

344 # `iterator`. 

345 for relation_row in itertools.islice(iterator, self._row_chunk_size): 

346 sql_rows.append(row_transformer.relation_to_sql(relation_row)) 

347 name = materialize_as if materialize_as is not None else self.sql_engine.get_relation_name("upload") 

348 if materialize_as is not None or len(sql_rows) > self._db.get_constant_rows_max(): 

349 if self._exit_stack is None: 

350 raise RuntimeError("This operation requires the QueryContext to have been entered.") 

351 # Either we're being asked to insert into a temporary table with a 

352 # "persistent" lifetime (the duration of the QueryContext), or we 

353 # have enough rows that the database says we can't use its 

354 # "constant_rows" construct (e.g. VALUES). Note that 

355 # `QueryContext.__init__` guarantees that `self._row_chunk_size` is 

356 # greater than or equal to `Database.get_constant_rows_max`. 

357 table = self._exit_stack.enter_context(self._db.temporary_table(table_spec, name)) 

358 while sql_rows: 

359 self._db.insert(table, *sql_rows) 

360 sql_rows.clear() 

361 for relation_row in itertools.islice(iterator, self._row_chunk_size): 

362 sql_rows.append(row_transformer.relation_to_sql(relation_row)) 

363 payload = sql.Payload[LogicalColumn](table) 

364 else: 

365 # Small number of rows; use Database.constant_rows. 

366 payload = sql.Payload[LogicalColumn](self._db.constant_rows(table_spec.fields, *sql_rows)) 

367 payload.columns_available = self.sql_engine.extract_mapping( 

368 source.columns, payload.from_clause.columns 

369 ) 

370 return payload 

371 

372 def _strip_empty_invariant_operations( 

373 self, 

374 relation: Relation, 

375 exact: bool, 

376 ) -> Relation: 

377 """Return a modified relation tree that strips out relation operations 

378 that do not affect whether there are any rows from the end of the tree. 

379 

380 Parameters 

381 ---------- 

382 relation : `Relation` 

383 Original relation tree. 

384 exact : `bool` 

385 If `False` strip all iteration-engine operations, even those that 

386 can remove all rows. 

387 

388 Returns 

389 ------- 

390 modified : `Relation` 

391 Modified relation tree. 

392 """ 

393 if relation.payload is not None or relation.is_locked: 

394 return relation 

395 match relation: 

396 case UnaryOperationRelation(operation=operation, target=target): 

397 if operation.is_empty_invariant or (not exact and target.engine == self.iteration_engine): 

398 return self._strip_empty_invariant_operations(target, exact) 

399 case Transfer(target=target): 

400 return self._strip_empty_invariant_operations(target, exact) 

401 case MarkerRelation(target=target): 

402 return relation.reapply(self._strip_empty_invariant_operations(target, exact)) 

403 return relation 

404 

405 def _strip_count_invariant_operations( 

406 self, 

407 relation: Relation, 

408 exact: bool, 

409 ) -> Relation: 

410 """Return a modified relation tree that strips out relation operations 

411 that do not affect the number of rows from the end of the tree. 

412 

413 Parameters 

414 ---------- 

415 relation : `Relation` 

416 Original relation tree. 

417 exact : `bool` 

418 If `False` strip all iteration-engine operations, even those that 

419 can affect the number of rows. 

420 

421 Returns 

422 ------- 

423 modified : `Relation` 

424 Modified relation tree. 

425 """ 

426 if relation.payload is not None or relation.is_locked: 

427 return relation 

428 match relation: 

429 case UnaryOperationRelation(operation=operation, target=target): 

430 if operation.is_count_invariant or (not exact and target.engine == self.iteration_engine): 

431 return self._strip_count_invariant_operations(target, exact) 

432 case Transfer(target=target): 

433 return self._strip_count_invariant_operations(target, exact) 

434 case MarkerRelation(target=target): 

435 return relation.reapply(self._strip_count_invariant_operations(target, exact)) 

436 return relation 

437 

438 

class _SqlRowIterable(iteration.RowIterable):
    """An implementation of `lsst.daf.relation.iteration.RowIterable` that
    executes a SQL query.

    Parameters
    ----------
    context : `SqlQueryContext`
        Context to execute the query with.  If this context has already been
        entered, a lazy iterable will be returned that may or may not use
        server side cursors, since the context's lifetime can be used to
        manage that cursor's lifetime.  If the context has not been entered,
        all results will be fetched up front in raw form, but will still be
        processed into mappings keyed by `ColumnTag` lazily.
    sql_executable : `sqlalchemy.sql.expression.SelectBase`
        SQL query to execute; assumed to have been built by
        `lsst.daf.relation.sql.Engine.to_executable` or similar.
    row_transformer : `_SqlRowTransformer`
        Object that converts SQLAlchemy result-row mappings (with `str` keys
        and possibly-unpacked timespan values) to relation row mappings (with
        `ColumnTag` keys and `LogicalColumn` values).
    """

    def __init__(
        self,
        context: SqlQueryContext,
        sql_executable: sqlalchemy.sql.expression.SelectBase,
        row_transformer: _SqlRowTransformer,
    ):
        self._context = context
        self._sql_executable = sql_executable
        self._row_transformer = row_transformer

    def __iter__(self) -> Iterator[Mapping[ColumnTag, Any]]:
        transform = self._row_transformer.sql_to_relation
        if self._context._exit_stack is None:
            # Have to read results into memory and close database connection.
            with self._context._db.query(self._sql_executable) as sql_result:
                fetched = sql_result.mappings().fetchall()
            for sql_row in fetched:
                yield transform(sql_row)
        else:
            # Context lifetime manages the cursor; stream rows lazily.
            with self._context._db.query(self._sql_executable) as sql_result:
                for sql_row in sql_result.mappings():
                    yield transform(sql_row)

483 

484 

class _SqlRowTransformer:
    """Object that converts SQLAlchemy result-row mappings to relation row
    mappings.

    Parameters
    ----------
    columns : `~collections.abc.Iterable` [ `ColumnTag` ]
        Set of columns to handle.  Rows must have at least these columns, but
        may have more.
    engine : `ButlerSqlEngine`
        Relation engine; used to transform column tags into SQL identifiers
        and obtain column type information.
    """

    def __init__(self, columns: Iterable[ColumnTag], engine: ButlerSqlEngine):
        # Scalar columns map 1:1 to SQL identifiers; timespan columns expand
        # into multiple SQL fields handled by the database representation.
        self._scalar_columns: list[tuple[str, ColumnTag]] = []
        self._timespan_columns: list[tuple[str, ColumnTag, type[TimespanDatabaseRepresentation]]] = []
        timespan_cls = engine.column_types.timespan_cls
        for tag in columns:
            identifier = engine.get_identifier(tag)
            if is_timespan_column(tag):
                self._timespan_columns.append((identifier, tag, timespan_cls))
            else:
                self._scalar_columns.append((identifier, tag))
        self._has_no_columns = not (self._scalar_columns or self._timespan_columns)

    __slots__ = ("_scalar_columns", "_timespan_columns", "_has_no_columns")

    def sql_to_relation(self, sql_row: sqlalchemy.RowMapping) -> dict[ColumnTag, Any]:
        """Convert a result row from a SQLAlchemy result into the form
        expected by `lsst.daf.relation.iteration`.

        Parameters
        ----------
        sql_row : `~collections.abc.Mapping`
            Mapping with `str` keys and possibly-unpacked values for timespan
            columns.

        Returns
        -------
        relation_row : `~collections.abc.Mapping`
            Mapping with `ColumnTag` keys and `Timespan` objects for timespan
            columns.
        """
        relation_row: dict[ColumnTag, Any] = {
            tag: sql_row[identifier] for identifier, tag in self._scalar_columns
        }
        for identifier, tag, db_repr_cls in self._timespan_columns:
            relation_row[tag] = db_repr_cls.extract(sql_row, name=identifier)
        return relation_row

    def relation_to_sql(self, relation_row: Mapping[ColumnTag, Any]) -> dict[str, Any]:
        """Convert a `lsst.daf.relation.iteration` row into the mapping form
        used by SQLAlchemy.

        Parameters
        ----------
        relation_row : `~collections.abc.Mapping`
            Mapping with `ColumnTag` keys and `Timespan` objects for timespan
            columns.

        Returns
        -------
        sql_row : `~collections.abc.Mapping`
            Mapping with `str` keys and possibly-unpacked values for timespan
            columns.
        """
        sql_row: dict[str, Any] = {
            identifier: relation_row[tag] for identifier, tag in self._scalar_columns
        }
        for identifier, tag, db_repr_cls in self._timespan_columns:
            db_repr_cls.update(relation_row[tag], result=sql_row, name=identifier)
        if self._has_no_columns:
            # A row must have at least one key for the database layer to
            # accept it; insert a placeholder column.
            sql_row["IGNORED"] = None
        return sql_row