Coverage for python/lsst/daf/butler/registry/queries/_sql_query_context.py: 11%

236 statements  

« prev     ^ index     » next       coverage.py v7.2.5, created at 2023-05-09 02:11 -0700

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ("SqlQueryContext",) 

24 

25import itertools 

26from collections.abc import Iterable, Iterator, Mapping, Set 

27from contextlib import ExitStack 

28from typing import TYPE_CHECKING, Any, cast 

29 

30import sqlalchemy 

31from lsst.daf.relation import ( 

32 BinaryOperationRelation, 

33 Calculation, 

34 Chain, 

35 ColumnTag, 

36 Deduplication, 

37 Engine, 

38 EngineError, 

39 Join, 

40 MarkerRelation, 

41 Projection, 

42 Relation, 

43 Transfer, 

44 UnaryOperation, 

45 UnaryOperationRelation, 

46 iteration, 

47 sql, 

48) 

49 

50from ...core import ColumnTypeInfo, LogicalColumn, TimespanDatabaseRepresentation, is_timespan_column 

51from ..nameShrinker import NameShrinker 

52from ._query_context import QueryContext 

53from .butler_sql_engine import ButlerSqlEngine 

54 

55if TYPE_CHECKING: 

56 from ..interfaces import Database 

57 

58 

59class SqlQueryContext(QueryContext): 

60 """An implementation of `sql.QueryContext` for `SqlRegistry`. 

61 

62 Parameters 

63 ---------- 

64 db : `Database` 

65 Object that abstracts the database engine. 

66 sql_engine : `ButlerSqlEngine` 

67 Information about column types that can vary with registry 

68 configuration. 

69 row_chunk_size : `int`, optional 

70 Number of rows to insert into temporary tables at once. If this is 

71 lower than ``db.get_constant_rows_max()`` it will be set to that value. 

72 """ 

73 

74 def __init__( 

75 self, 

76 db: Database, 

77 column_types: ColumnTypeInfo, 

78 row_chunk_size: int = 1000, 

79 ): 

80 super().__init__() 

81 self.sql_engine = ButlerSqlEngine( 

82 column_types=column_types, name_shrinker=NameShrinker(db.dialect.max_identifier_length) 

83 ) 

84 self._row_chunk_size = max(row_chunk_size, db.get_constant_rows_max()) 

85 self._db = db 

86 self._exit_stack: ExitStack | None = None 

87 

88 def __enter__(self) -> SqlQueryContext: 

89 assert self._exit_stack is None, "Context manager already entered." 

90 self._exit_stack = ExitStack().__enter__() 

91 self._exit_stack.enter_context(self._db.session()) 

92 return self 

93 

94 def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> bool: 

95 assert self._exit_stack is not None, "Context manager not yet entered." 

96 result = self._exit_stack.__exit__(exc_type, exc_value, traceback) 

97 self._exit_stack = None 

98 return result 

99 

100 @property 

101 def is_open(self) -> bool: 

102 return self._exit_stack is not None 

103 

104 @property 

105 def column_types(self) -> ColumnTypeInfo: 

106 """Information about column types that depend on registry configuration 

107 (`ColumnTypeInfo`). 

108 """ 

109 return self.sql_engine.column_types 

110 

111 @property 

112 def preferred_engine(self) -> Engine: 

113 # Docstring inherited. 

114 return self.sql_engine 

115 

116 def count(self, relation: Relation, *, exact: bool = True, discard: bool = False) -> int: 

117 # Docstring inherited. 

118 relation = self._strip_count_invariant_operations(relation, exact).with_only_columns(frozenset()) 

119 if relation.engine == self.sql_engine: 

120 sql_executable = self.sql_engine.to_executable( 

121 relation, extra_columns=[sqlalchemy.sql.func.count()] 

122 ) 

123 with self._db.query(sql_executable) as sql_result: 

124 return cast(int, sql_result.scalar_one()) 

125 elif (rows := relation.payload) is not None: 

126 assert isinstance( 

127 rows, iteration.MaterializedRowIterable 

128 ), "Query guarantees that only materialized payloads are attached to its relations." 

129 return len(rows) 

130 elif discard: 

131 n = 0 

132 for _ in self.fetch_iterable(relation): 

133 n += 1 

134 return n 

135 else: 

136 raise RuntimeError( 

137 f"Query with relation {relation} has deferred operations that " 

138 "must be executed in Python in order to obtain an exact " 

139 "count. Pass discard=True to run the query and discard the " 

140 "result rows while counting them, run the query first, or " 

141 "pass exact=False to obtain an upper bound." 

142 ) 

143 

144 def any(self, relation: Relation, *, execute: bool = True, exact: bool = True) -> bool: 

145 # Docstring inherited. 

146 relation = self._strip_count_invariant_operations(relation, exact).with_only_columns(frozenset())[:1] 

147 if relation.engine == self.sql_engine: 

148 sql_executable = self.sql_engine.to_executable(relation) 

149 with self._db.query(sql_executable) as sql_result: 

150 return sql_result.one_or_none() is not None 

151 elif (rows := relation.payload) is not None: 

152 assert isinstance( 

153 rows, iteration.MaterializedRowIterable 

154 ), "Query guarantees that only materialized payloads are attached to its relations." 

155 return bool(rows) 

156 elif execute: 

157 for _ in self.fetch_iterable(relation): 

158 return True 

159 return False 

160 else: 

161 raise RuntimeError( 

162 f"Query with relation {relation} has deferred operations that " 

163 "must be executed in Python in order to obtain an exact " 

164 "check for whether any rows would be returned. Pass " 

165 "execute=True to run the query until at least " 

166 "one row is found, run the query first, or pass exact=False " 

167 "test only whether the query _might_ have result rows." 

168 ) 

169 

170 def transfer(self, source: Relation, destination: Engine, materialize_as: str | None) -> Any: 

171 # Docstring inherited from lsst.daf.relation.Processor. 

172 if source.engine == self.sql_engine and destination == self.iteration_engine: 

173 return self._sql_to_iteration(source, materialize_as) 

174 if source.engine == self.iteration_engine and destination == self.sql_engine: 

175 return self._iteration_to_sql(source, materialize_as) 

176 raise EngineError(f"Unexpected transfer for SqlQueryContext: {source.engine} -> {destination}.") 

177 

178 def materialize(self, target: Relation, name: str) -> Any: 

179 # Docstring inherited from lsst.daf.relation.Processor. 

180 if target.engine == self.sql_engine: 

181 sql_executable = self.sql_engine.to_executable(target) 

182 table_spec = self.column_types.make_relation_table_spec(target.columns) 

183 if self._exit_stack is None: 

184 raise RuntimeError("This operation requires the QueryContext to have been entered.") 

185 table = self._exit_stack.enter_context(self._db.temporary_table(table_spec, name=name)) 

186 self._db.insert(table, select=sql_executable) 

187 payload = sql.Payload[LogicalColumn](table) 

188 payload.columns_available = self.sql_engine.extract_mapping(target.columns, table.columns) 

189 return payload 

190 return super().materialize(target, name) 

191 

192 def restore_columns( 

193 self, 

194 relation: Relation, 

195 columns_required: Set[ColumnTag], 

196 ) -> tuple[Relation, set[ColumnTag]]: 

197 # Docstring inherited. 

198 if relation.is_locked: 

199 return relation, set(relation.columns & columns_required) 

200 match relation: 

201 case UnaryOperationRelation(operation=operation, target=target): 

202 new_target, columns_found = self.restore_columns(target, columns_required) 

203 if not columns_found: 

204 return relation, columns_found 

205 match operation: 

206 case Projection(columns=columns): 

207 new_columns = columns | columns_found 

208 if new_columns != columns: 

209 return ( 

210 Projection(new_columns).apply(new_target), 

211 columns_found, 

212 ) 

213 case Calculation(tag=tag): 

214 if tag in columns_required: 

215 columns_found.add(tag) 

216 case Deduplication(): 

217 # Pulling key columns through the deduplication would 

218 # fundamentally change what it does; have to limit 

219 # ourselves to non-key columns here, and put a 

220 # Projection back in place. 

221 columns_found = {c for c in columns_found if not c.is_key} 

222 new_target = new_target.with_only_columns(relation.columns | columns_found) 

223 return relation.reapply(new_target), columns_found 

224 case BinaryOperationRelation(operation=operation, lhs=lhs, rhs=rhs): 

225 new_lhs, columns_found_lhs = self.restore_columns(lhs, columns_required) 

226 new_rhs, columns_found_rhs = self.restore_columns(rhs, columns_required) 

227 match operation: 

228 case Join(): 

229 return relation.reapply(new_lhs, new_rhs), columns_found_lhs | columns_found_rhs 

230 case Chain(): 

231 if columns_found_lhs != columns_found_rhs: 

232 # Got different answers for different join 

233 # branches; let's try again with just columns found 

234 # in both branches. 

235 new_columns_required = columns_found_lhs & columns_found_rhs 

236 if not new_columns_required: 

237 return relation, set() 

238 new_lhs, columns_found_lhs = self.restore_columns(lhs, new_columns_required) 

239 new_rhs, columns_found_rhs = self.restore_columns(rhs, new_columns_required) 

240 assert columns_found_lhs == columns_found_rhs 

241 return relation.reapply(new_lhs, new_rhs), columns_found_lhs 

242 case MarkerRelation(target=target): 

243 new_target, columns_found = self.restore_columns(target, columns_required) 

244 return relation.reapply(new_target), columns_found 

245 raise AssertionError("Match should be exhaustive and all branches should return.") 

246 

247 def strip_postprocessing(self, relation: Relation) -> tuple[Relation, list[UnaryOperation]]: 

248 # Docstring inherited. 

249 if relation.engine != self.iteration_engine or relation.is_locked: 

250 return relation, [] 

251 match relation: 

252 case UnaryOperationRelation(operation=operation, target=target): 

253 new_target, stripped = self.strip_postprocessing(target) 

254 stripped.append(operation) 

255 return new_target, stripped 

256 case Transfer(destination=self.iteration_engine, target=target): 

257 return target, [] 

258 return relation, [] 

259 

260 def drop_invalidated_postprocessing(self, relation: Relation, new_columns: Set[ColumnTag]) -> Relation: 

261 # Docstring inherited. 

262 if relation.engine != self.iteration_engine or relation.is_locked: 

263 return relation 

264 match relation: 

265 case UnaryOperationRelation(operation=operation, target=target): 

266 columns_added_here = relation.columns - target.columns 

267 new_target = self.drop_invalidated_postprocessing(target, new_columns - columns_added_here) 

268 if operation.columns_required <= new_columns: 

269 # This operation can stay... 

270 if new_target is target: 

271 # ...and nothing has actually changed upstream of it. 

272 return relation 

273 else: 

274 # ...but we should see if it can now be performed in 

275 # the preferred engine, since something it didn't 

276 # commute with may have been removed. 

277 return operation.apply(new_target, preferred_engine=self.preferred_engine) 

278 else: 

279 # This operation needs to be dropped, as we're about to 

280 # project away one more more of the columns it needs. 

281 return new_target 

282 return relation 

283 

284 def _sql_to_iteration(self, source: Relation, materialize_as: str | None) -> iteration.RowIterable: 

285 """Execute a relation transfer from the SQL engine to the iteration 

286 engine. 

287 

288 Parameters 

289 ---------- 

290 source : `~lsst.daf.relation.Relation` 

291 Input relation, in what is assumed to be `sql_engine`. 

292 materialize_as : `str` or `None` 

293 If not `None`, the name of a persistent materialization to apply 

294 in the iteration engine. This fetches all rows up front. 

295 

296 Returns 

297 ------- 

298 destination : `~lsst.daf.relation.iteration.RowIterable` 

299 Iteration engine payload iterable with the same content as the 

300 given SQL relation. 

301 """ 

302 sql_executable = self.sql_engine.to_executable(source) 

303 rows: iteration.RowIterable = _SqlRowIterable( 

304 self, sql_executable, _SqlRowTransformer(source.columns, self.sql_engine) 

305 ) 

306 if materialize_as is not None: 

307 rows = rows.materialized() 

308 return rows 

309 

310 def _iteration_to_sql(self, source: Relation, materialize_as: str | None) -> sql.Payload: 

311 """Execute a relation transfer from the iteration engine to the SQL 

312 engine. 

313 

314 Parameters 

315 ---------- 

316 source : `~lsst.daf.relation.Relation` 

317 Input relation, in what is assumed to be `iteration_engine`. 

318 materialize_as : `str` or `None` 

319 If not `None`, the name of a persistent materialization to apply 

320 in the SQL engine. This sets the name of the temporary table 

321 and ensures that one is used (instead of `Database.constant_rows`, 

322 which might otherwise be used for small row sets). 

323 

324 Returns 

325 ------- 

326 destination : `~lsst.daf.relation.sql.Payload` 

327 SQL engine payload struct with the same content as the given 

328 iteration relation. 

329 """ 

330 iterable = self.iteration_engine.execute(source) 

331 table_spec = self.column_types.make_relation_table_spec(source.columns) 

332 row_transformer = _SqlRowTransformer(source.columns, self.sql_engine) 

333 sql_rows = [] 

334 iterator = iter(iterable) 

335 # Iterate over the first chunk of rows and transform them to hold only 

336 # types recognized by SQLAlchemy (i.e. expand out 

337 # TimespanDatabaseRepresentations). Note that this advances 

338 # `iterator`. 

339 for relation_row in itertools.islice(iterator, self._row_chunk_size): 

340 sql_rows.append(row_transformer.relation_to_sql(relation_row)) 

341 name = materialize_as if materialize_as is not None else self.sql_engine.get_relation_name("upload") 

342 if materialize_as is not None or len(sql_rows) > self._db.get_constant_rows_max(): 

343 if self._exit_stack is None: 

344 raise RuntimeError("This operation requires the QueryContext to have been entered.") 

345 # Either we're being asked to insert into a temporary table with a 

346 # "persistent" lifetime (the duration of the QueryContext), or we 

347 # have enough rows that the database says we can't use its 

348 # "constant_rows" construct (e.g. VALUES). Note that 

349 # `QueryContext.__init__` guarantees that `self._row_chunk_size` is 

350 # greater than or equal to `Database.get_constant_rows_max`. 

351 table = self._exit_stack.enter_context(self._db.temporary_table(table_spec, name)) 

352 while sql_rows: 

353 self._db.insert(table, *sql_rows) 

354 sql_rows.clear() 

355 for relation_row in itertools.islice(iterator, self._row_chunk_size): 

356 sql_rows.append(row_transformer.relation_to_sql(relation_row)) 

357 payload = sql.Payload[LogicalColumn](table) 

358 else: 

359 # Small number of rows; use Database.constant_rows. 

360 payload = sql.Payload[LogicalColumn](self._db.constant_rows(table_spec.fields, *sql_rows)) 

361 payload.columns_available = self.sql_engine.extract_mapping( 

362 source.columns, payload.from_clause.columns 

363 ) 

364 return payload 

365 

366 def _strip_empty_invariant_operations( 

367 self, 

368 relation: Relation, 

369 exact: bool, 

370 ) -> Relation: 

371 """Return a modified relation tree that strips out relation operations 

372 that do not affect whether there are any rows from the end of the tree. 

373 

374 Parameters 

375 ---------- 

376 relation : `Relation` 

377 Original relation tree. 

378 exact : `bool` 

379 If `False` strip all iteration-engine operations, even those that 

380 can remove all rows. 

381 

382 Returns 

383 ------- 

384 modified : `Relation` 

385 Modified relation tree. 

386 """ 

387 if relation.payload is not None or relation.is_locked: 

388 return relation 

389 match relation: 

390 case UnaryOperationRelation(operation=operation, target=target): 

391 if operation.is_empty_invariant or (not exact and target.engine == self.iteration_engine): 

392 return self._strip_empty_invariant_operations(target, exact) 

393 case Transfer(target=target): 

394 return self._strip_empty_invariant_operations(target, exact) 

395 case MarkerRelation(target=target): 

396 return relation.reapply(self._strip_empty_invariant_operations(target, exact)) 

397 return relation 

398 

399 def _strip_count_invariant_operations( 

400 self, 

401 relation: Relation, 

402 exact: bool, 

403 ) -> Relation: 

404 """Return a modified relation tree that strips out relation operations 

405 that do not affect the number of rows from the end of the tree. 

406 

407 Parameters 

408 ---------- 

409 relation : `Relation` 

410 Original relation tree. 

411 exact : `bool` 

412 If `False` strip all iteration-engine operations, even those that 

413 can affect the number of rows. 

414 

415 Returns 

416 ------- 

417 modified : `Relation` 

418 Modified relation tree. 

419 """ 

420 if relation.payload is not None or relation.is_locked: 

421 return relation 

422 match relation: 

423 case UnaryOperationRelation(operation=operation, target=target): 

424 if operation.is_count_invariant or (not exact and target.engine == self.iteration_engine): 

425 return self._strip_count_invariant_operations(target, exact) 

426 case Transfer(target=target): 

427 return self._strip_count_invariant_operations(target, exact) 

428 case MarkerRelation(target=target): 

429 return relation.reapply(self._strip_count_invariant_operations(target, exact)) 

430 return relation 

431 

432 

class _SqlRowIterable(iteration.RowIterable):
    """A `lsst.daf.relation.iteration.RowIterable` implementation that
    executes a SQL query.

    Parameters
    ----------
    context : `SqlQueryContext`
        Context to execute the query with.  If this context has already been
        entered, a lazy iterable will be returned that may or may not use
        server side cursors, since the context's lifetime can be used to
        manage that cursor's lifetime.  If the context has not been entered,
        all results will be fetched up front in raw form, but will still be
        processed into mappings keyed by `ColumnTag` lazily.
    sql_executable : `sqlalchemy.sql.expression.SelectBase`
        SQL query to execute; assumed to have been built by
        `lsst.daf.relation.sql.Engine.to_executable` or similar.
    row_transformer : `_SqlRowTransformer`
        Object that converts SQLAlchemy result-row mappings (with `str` keys
        and possibly-unpacked timespan values) to relation row mappings (with
        `ColumnTag` keys and `LogicalColumn` values).
    """

    def __init__(
        self,
        context: SqlQueryContext,
        sql_executable: sqlalchemy.sql.expression.SelectBase,
        row_transformer: _SqlRowTransformer,
    ):
        self._context = context
        self._sql_executable = sql_executable
        self._row_transformer = row_transformer

    def __iter__(self) -> Iterator[Mapping[ColumnTag, Any]]:
        to_relation = self._row_transformer.sql_to_relation
        if self._context._exit_stack is not None:
            # The context is open, so its lifetime manages the connection and
            # any server-side cursor; stream rows straight out of the result.
            with self._context._db.query(self._sql_executable) as sql_result:
                for sql_row in sql_result.mappings():
                    yield to_relation(sql_row)
        else:
            # No open context: fetch everything while the database connection
            # is alive, then transform rows lazily after it is released.
            with self._context._db.query(self._sql_executable) as sql_result:
                raw_rows = sql_result.mappings().fetchall()
            for sql_row in raw_rows:
                yield to_relation(sql_row)

477 

478 

479class _SqlRowTransformer: 

480 """Object that converts SQLAlchemy result-row mappings to relation row 

481 mappings. 

482 

483 Parameters 

484 ---------- 

485 columns : `Iterable` [ `ColumnTag` ] 

486 Set of columns to handle. Rows must have at least these columns, but 

487 may have more. 

488 engine : `ButlerSqlEngine` 

489 Relation engine; used to transform column tags into SQL identifiers and 

490 obtain column type information. 

491 """ 

492 

493 def __init__(self, columns: Iterable[ColumnTag], engine: ButlerSqlEngine): 

494 self._scalar_columns: list[tuple[str, ColumnTag]] = [] 

495 self._timespan_columns: list[tuple[str, ColumnTag, type[TimespanDatabaseRepresentation]]] = [] 

496 for tag in columns: 

497 if is_timespan_column(tag): 

498 self._timespan_columns.append( 

499 (engine.get_identifier(tag), tag, engine.column_types.timespan_cls) 

500 ) 

501 else: 

502 self._scalar_columns.append((engine.get_identifier(tag), tag)) 

503 self._has_no_columns = not (self._scalar_columns or self._timespan_columns) 

504 

505 __slots__ = ("_scalar_columns", "_timespan_columns", "_has_no_columns") 

506 

507 def sql_to_relation(self, sql_row: sqlalchemy.RowMapping) -> dict[ColumnTag, Any]: 

508 """Convert a result row from a SQLAlchemy result into the form expected 

509 by `lsst.daf.relation.iteration`. 

510 

511 Parameters 

512 ---------- 

513 sql_row : `Mapping` 

514 Mapping with `str` keys and possibly-unpacked values for timespan 

515 columns. 

516 

517 Returns 

518 ------- 

519 relation_row : `Mapping` 

520 Mapping with `ColumnTag` keys and `Timespan` objects for timespan 

521 columns. 

522 """ 

523 relation_row = {tag: sql_row[identifier] for identifier, tag in self._scalar_columns} 

524 for identifier, tag, db_repr_cls in self._timespan_columns: 

525 relation_row[tag] = db_repr_cls.extract(sql_row, name=identifier) 

526 return relation_row 

527 

528 def relation_to_sql(self, relation_row: Mapping[ColumnTag, Any]) -> dict[str, Any]: 

529 """Convert a `lsst.daf.relation.iteration` row into the mapping form 

530 used by SQLAlchemy. 

531 

532 Parameters 

533 ---------- 

534 relation_row : `Mapping` 

535 Mapping with `ColumnTag` keys and `Timespan` objects for timespan 

536 columns. 

537 

538 Returns 

539 ------- 

540 sql_row : `Mapping` 

541 Mapping with `str` keys and possibly-unpacked values for timespan 

542 columns. 

543 """ 

544 sql_row = {identifier: relation_row[tag] for identifier, tag in self._scalar_columns} 

545 for identifier, tag, db_repr_cls in self._timespan_columns: 

546 db_repr_cls.update(relation_row[tag], result=sql_row, name=identifier) 

547 if self._has_no_columns: 

548 sql_row["IGNORED"] = None 

549 return sql_row