Coverage for python/lsst/daf/butler/registry/queries/_sql_query_context.py: 12%

235 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-03-11 02:06 -0800

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ("SqlQueryContext",) 

24 

25import itertools 

26from collections.abc import Iterable, Iterator, Mapping, Set 

27from contextlib import ExitStack 

28from typing import TYPE_CHECKING, Any, cast 

29 

30import sqlalchemy 

31from lsst.daf.relation import ( 

32 BinaryOperationRelation, 

33 Calculation, 

34 Chain, 

35 ColumnTag, 

36 Deduplication, 

37 Engine, 

38 EngineError, 

39 Join, 

40 MarkerRelation, 

41 Projection, 

42 Relation, 

43 Transfer, 

44 UnaryOperation, 

45 UnaryOperationRelation, 

46 iteration, 

47 sql, 

48) 

49 

50from ...core import ColumnTypeInfo, LogicalColumn, TimespanDatabaseRepresentation, is_timespan_column 

51from ..nameShrinker import NameShrinker 

52from ._query_context import QueryContext 

53from .butler_sql_engine import ButlerSqlEngine 

54 

55if TYPE_CHECKING: 55 ↛ 56line 55 didn't jump to line 56, because the condition on line 55 was never true

56 from ..interfaces import Database 

57 

58 

class SqlQueryContext(QueryContext):
    """An implementation of `QueryContext` for `SqlRegistry`.

    Parameters
    ----------
    db : `Database`
        Object that abstracts the database engine.
    column_types : `ColumnTypeInfo`
        Information about column types that can vary with registry
        configuration.
    row_chunk_size : `int`, optional
        Number of rows to insert into temporary tables at once. If this is
        lower than ``db.get_constant_rows_max()`` it will be set to that value.
    """

    def __init__(
        self,
        db: Database,
        column_types: ColumnTypeInfo,
        row_chunk_size: int = 1000,
    ):
        super().__init__()
        self.sql_engine = ButlerSqlEngine(
            column_types=column_types, name_shrinker=NameShrinker(db.dialect.max_identifier_length)
        )
        # Clamp the chunk size up to the database's constant-rows limit;
        # _iteration_to_sql relies on this guarantee when deciding whether
        # a temporary table is needed.
        self._row_chunk_size = max(row_chunk_size, db.get_constant_rows_max())
        self._db = db
        # Manages the database session and any temporary tables created while
        # the context is entered; None whenever the context is not open.
        self._exit_stack: ExitStack | None = None

    def __enter__(self) -> SqlQueryContext:
        # Open a database session whose lifetime (along with that of any
        # temporary tables registered later) is managed by an ExitStack.
        assert self._exit_stack is None, "Context manager already entered."
        self._exit_stack = ExitStack().__enter__()
        self._exit_stack.enter_context(self._db.session())
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> bool:
        # Close the session and drop any temporary tables registered on the
        # stack, then mark the context as closed.
        assert self._exit_stack is not None, "Context manager not yet entered."
        result = self._exit_stack.__exit__(exc_type, exc_value, traceback)
        self._exit_stack = None
        return result

    @property
    def is_open(self) -> bool:
        """Whether this context has been entered and not yet exited (`bool`)."""
        return self._exit_stack is not None

    @property
    def column_types(self) -> ColumnTypeInfo:
        """Information about column types that depend on registry configuration
        (`ColumnTypeInfo`).
        """
        return self.sql_engine.column_types

    @property
    def preferred_engine(self) -> Engine:
        # Docstring inherited.
        return self.sql_engine

    def count(self, relation: Relation, *, exact: bool = True, discard: bool = False) -> int:
        # Docstring inherited.
        # Operations that cannot change the row count may be dropped, and no
        # columns are needed just to count rows.
        relation = self._strip_count_invariant_operations(relation, exact).with_only_columns(frozenset())
        if relation.engine == self.sql_engine:
            # Whole tree is in the SQL engine: let the database do the count.
            sql_executable = self.sql_engine.to_executable(
                relation, extra_columns=[sqlalchemy.sql.func.count()]
            )
            with self._db.query(sql_executable) as sql_result:
                return cast(int, sql_result.scalar_one())
        elif (rows := relation.payload) is not None:
            # Rows were already materialized in the iteration engine.
            assert isinstance(
                rows, iteration.MaterializedRowIterable
            ), "Query guarantees that only materialized payloads are attached to its relations."
            return len(rows)
        elif discard:
            # Caller permits executing the query and throwing rows away.
            n = 0
            for _ in self.fetch_iterable(relation):
                n += 1
            return n
        else:
            raise RuntimeError(
                f"Query with relation {relation} has deferred operations that "
                "must be executed in Python in order to obtain an exact "
                "count. Pass discard=True to run the query and discard the "
                "result rows while counting them, run the query first, or "
                "pass exact=False to obtain an upper bound."
            )

    def any(self, relation: Relation, *, execute: bool = True, exact: bool = True) -> bool:
        # Docstring inherited.
        # Stripping count-invariant operations plus the [:1] slice lets the
        # query stop after the first row.
        relation = self._strip_count_invariant_operations(relation, exact).with_only_columns(frozenset())[:1]
        if relation.engine == self.sql_engine:
            sql_executable = self.sql_engine.to_executable(relation)
            with self._db.query(sql_executable) as sql_result:
                return sql_result.one_or_none() is not None
        elif (rows := relation.payload) is not None:
            assert isinstance(
                rows, iteration.MaterializedRowIterable
            ), "Query guarantees that only materialized payloads are attached to its relations."
            return bool(rows)
        elif execute:
            # Run the deferred Python-side operations, but bail out as soon
            # as a single row is produced.
            for _ in self.fetch_iterable(relation):
                return True
            return False
        else:
            raise RuntimeError(
                f"Query with relation {relation} has deferred operations that "
                "must be executed in Python in order to obtain an exact "
                "check for whether any rows would be returned. Pass "
                "execute=True to run the query until at least "
                "one row is found, run the query first, or pass exact=False "
                "test only whether the query _might_ have result rows."
            )

    def transfer(self, source: Relation, destination: Engine, materialize_as: str | None) -> Any:
        # Docstring inherited from lsst.daf.relation.Processor.
        # Only SQL <-> iteration transfers are supported, in either direction.
        if source.engine == self.sql_engine and destination == self.iteration_engine:
            return self._sql_to_iteration(source, materialize_as)
        if source.engine == self.iteration_engine and destination == self.sql_engine:
            return self._iteration_to_sql(source, materialize_as)
        raise EngineError(f"Unexpected transfer for SqlQueryContext: {source.engine} -> {destination}.")

    def materialize(self, target: Relation, name: str) -> Any:
        # Docstring inherited from lsst.daf.relation.Processor.
        if target.engine == self.sql_engine:
            # Execute the query into a temporary table whose lifetime is tied
            # to this (entered) context.
            sql_executable = self.sql_engine.to_executable(target)
            table_spec = self.column_types.make_relation_table_spec(target.columns)
            if self._exit_stack is None:
                raise RuntimeError("This operation requires the QueryContext to have been entered.")
            table = self._exit_stack.enter_context(self._db.temporary_table(table_spec, name=name))
            self._db.insert(table, select=sql_executable)
            payload = sql.Payload[LogicalColumn](table)
            payload.columns_available = self.sql_engine.extract_mapping(target.columns, table.columns)
            return payload
        # Other engines (e.g. iteration) are handled by the base class.
        return super().materialize(target, name)

    def restore_columns(
        self,
        relation: Relation,
        columns_required: Set[ColumnTag],
    ) -> tuple[Relation, set[ColumnTag]]:
        # Docstring inherited.
        if relation.is_locked:
            # Locked relations cannot be modified; report only the required
            # columns they already provide.
            return relation, set(relation.columns & columns_required)
        match relation:
            case UnaryOperationRelation(operation=operation, target=target):
                new_target, columns_found = self.restore_columns(target, columns_required)
                if not columns_found:
                    return relation, columns_found
                match operation:
                    case Projection(columns=columns):
                        # Widen the projection to keep the restored columns.
                        new_columns = columns | columns_found
                        if new_columns != columns:
                            return (
                                Projection(new_columns).apply(new_target),
                                columns_found,
                            )
                    case Calculation(tag=tag):
                        if tag in columns_required:
                            columns_found.add(tag)
                    case Deduplication():
                        # Pulling key columns through the deduplication would
                        # fundamentally change what it does; have to limit
                        # ourselves to non-key columns here, and put a
                        # Projection back in place.
                        columns_found = {c for c in columns_found if not c.is_key}
                        new_target = new_target.with_only_columns(relation.columns | columns_found)
                return relation.reapply(new_target), columns_found
            case BinaryOperationRelation(operation=operation, lhs=lhs, rhs=rhs):
                new_lhs, columns_found_lhs = self.restore_columns(lhs, columns_required)
                new_rhs, columns_found_rhs = self.restore_columns(rhs, columns_required)
                match operation:
                    case Join():
                        return relation.reapply(new_lhs, new_rhs), columns_found_lhs | columns_found_rhs
                    case Chain():
                        if columns_found_lhs != columns_found_rhs:
                            # Got different answers for different join
                            # branches; let's try again with just columns found
                            # in both branches.
                            new_columns_required = columns_found_lhs & columns_found_rhs
                            if not new_columns_required:
                                return relation, set()
                            new_lhs, columns_found_lhs = self.restore_columns(lhs, new_columns_required)
                            new_rhs, columns_found_rhs = self.restore_columns(rhs, new_columns_required)
                            assert columns_found_lhs == columns_found_rhs
                        return relation.reapply(new_lhs, new_rhs), columns_found_lhs
            case MarkerRelation(target=target):
                new_target, columns_found = self.restore_columns(target, columns_required)
                return relation.reapply(new_target), columns_found
        raise AssertionError("Match should be exhaustive and all branches should return.")

    def strip_postprocessing(self, relation: Relation) -> tuple[Relation, list[UnaryOperation]]:
        # Docstring inherited.
        if relation.engine != self.iteration_engine or relation.is_locked:
            return relation, []
        match relation:
            case UnaryOperationRelation(operation=operation, target=target):
                # Recurse first so the stripped operations come back in
                # application order.
                new_target, stripped = self.strip_postprocessing(target)
                stripped.append(operation)
                return new_target, stripped
            case Transfer(destination=self.iteration_engine, target=target):
                # Stop at the transfer into the iteration engine and hand
                # back its (SQL-side) target.
                return target, []
        return relation, []

    def drop_invalidated_postprocessing(self, relation: Relation, new_columns: Set[ColumnTag]) -> Relation:
        # Docstring inherited.
        if relation.engine != self.iteration_engine or relation.is_locked:
            return relation
        match relation:
            case UnaryOperationRelation(operation=operation, target=target):
                columns_added_here = relation.columns - target.columns
                new_target = self.drop_invalidated_postprocessing(target, new_columns - columns_added_here)
                if operation.columns_required <= new_columns:
                    # This operation can stay.
                    return relation.reapply(new_target)
                else:
                    # This operation needs to be dropped, as we're about to
                    # project away one or more of the columns it needs.
                    return new_target
        return relation

    def _sql_to_iteration(self, source: Relation, materialize_as: str | None) -> iteration.RowIterable:
        """Execute a relation transfer from the SQL engine to the iteration
        engine.

        Parameters
        ----------
        source : `~lsst.daf.relation.Relation`
            Input relation, in what is assumed to be `sql_engine`.
        materialize_as : `str` or `None`
            If not `None`, the name of a persistent materialization to apply
            in the iteration engine.  This fetches all rows up front.

        Returns
        -------
        destination : `~lsst.daf.relation.iteration.RowIterable`
            Iteration engine payload iterable with the same content as the
            given SQL relation.
        """
        sql_executable = self.sql_engine.to_executable(source)
        rows: iteration.RowIterable = _SqlRowIterable(
            self, sql_executable, _SqlRowTransformer(source.columns, self.sql_engine)
        )
        if materialize_as is not None:
            # Persistent materialization: pull all rows into memory now.
            rows = rows.materialized()
        return rows

    def _iteration_to_sql(self, source: Relation, materialize_as: str | None) -> sql.Payload:
        """Execute a relation transfer from the iteration engine to the SQL
        engine.

        Parameters
        ----------
        source : `~lsst.daf.relation.Relation`
            Input relation, in what is assumed to be `iteration_engine`.
        materialize_as : `str` or `None`
            If not `None`, the name of a persistent materialization to apply
            in the SQL engine.  This sets the name of the temporary table
            and ensures that one is used (instead of `Database.constant_rows`,
            which might otherwise be used for small row sets).

        Returns
        -------
        destination : `~lsst.daf.relation.sql.Payload`
            SQL engine payload struct with the same content as the given
            iteration relation.
        """
        iterable = self.iteration_engine.execute(source)
        table_spec = self.column_types.make_relation_table_spec(source.columns)
        row_transformer = _SqlRowTransformer(source.columns, self.sql_engine)
        sql_rows = []
        iterator = iter(iterable)
        # Iterate over the first chunk of rows and transform them to hold only
        # types recognized by SQLAlchemy (i.e. expand out
        # TimespanDatabaseRepresentations).  Note that this advances
        # `iterator`.
        for relation_row in itertools.islice(iterator, self._row_chunk_size):
            sql_rows.append(row_transformer.relation_to_sql(relation_row))
        name = materialize_as if materialize_as is not None else self.sql_engine.get_relation_name("upload")
        if materialize_as is not None or len(sql_rows) > self._db.get_constant_rows_max():
            if self._exit_stack is None:
                raise RuntimeError("This operation requires the QueryContext to have been entered.")
            # Either we're being asked to insert into a temporary table with a
            # "persistent" lifetime (the duration of the QueryContext), or we
            # have enough rows that the database says we can't use its
            # "constant_rows" construct (e.g. VALUES).  Note that
            # `QueryContext.__init__` guarantees that `self._row_chunk_size` is
            # greater than or equal to `Database.get_constant_rows_max`.
            table = self._exit_stack.enter_context(self._db.temporary_table(table_spec, name))
            # Insert chunk by chunk, reusing `sql_rows` as the buffer, until
            # `iterator` is exhausted.
            while sql_rows:
                self._db.insert(table, *sql_rows)
                sql_rows.clear()
                for relation_row in itertools.islice(iterator, self._row_chunk_size):
                    sql_rows.append(row_transformer.relation_to_sql(relation_row))
            payload = sql.Payload[LogicalColumn](table)
        else:
            # Small number of rows; use Database.constant_rows.
            payload = sql.Payload[LogicalColumn](self._db.constant_rows(table_spec.fields, *sql_rows))
        payload.columns_available = self.sql_engine.extract_mapping(
            source.columns, payload.from_clause.columns
        )
        return payload

    def _strip_empty_invariant_operations(
        self,
        relation: Relation,
        exact: bool,
    ) -> Relation:
        """Return a modified relation tree that strips out relation operations
        that do not affect whether there are any rows from the end of the tree.

        Parameters
        ----------
        relation : `Relation`
            Original relation tree.
        exact : `bool`
            If `False` strip all iteration-engine operations, even those that
            can remove all rows.

        Returns
        -------
        modified : `Relation`
            Modified relation tree.
        """
        # Relations with payloads already attached (or locked relations) must
        # be kept as-is.
        if relation.payload is not None or relation.is_locked:
            return relation
        match relation:
            case UnaryOperationRelation(operation=operation, target=target):
                if operation.is_empty_invariant or (not exact and target.engine == self.iteration_engine):
                    return self._strip_empty_invariant_operations(target, exact)
            case Transfer(target=target):
                # A transfer holds the same content in a different engine, so
                # it can be dropped entirely for this purpose.
                return self._strip_empty_invariant_operations(target, exact)
            case MarkerRelation(target=target):
                return relation.reapply(self._strip_empty_invariant_operations(target, exact))
        return relation

    def _strip_count_invariant_operations(
        self,
        relation: Relation,
        exact: bool,
    ) -> Relation:
        """Return a modified relation tree that strips out relation operations
        that do not affect the number of rows from the end of the tree.

        Parameters
        ----------
        relation : `Relation`
            Original relation tree.
        exact : `bool`
            If `False` strip all iteration-engine operations, even those that
            can affect the number of rows.

        Returns
        -------
        modified : `Relation`
            Modified relation tree.
        """
        # Relations with payloads already attached (or locked relations) must
        # be kept as-is.
        if relation.payload is not None or relation.is_locked:
            return relation
        match relation:
            case UnaryOperationRelation(operation=operation, target=target):
                if operation.is_count_invariant or (not exact and target.engine == self.iteration_engine):
                    return self._strip_count_invariant_operations(target, exact)
            case Transfer(target=target):
                # A transfer holds the same content in a different engine, so
                # it can be dropped entirely for this purpose.
                return self._strip_count_invariant_operations(target, exact)
            case MarkerRelation(target=target):
                return relation.reapply(self._strip_count_invariant_operations(target, exact))
        return relation

424 

425 

class _SqlRowIterable(iteration.RowIterable):
    """A `lsst.daf.relation.iteration.RowIterable` that executes a SQL query
    each time it is iterated.

    Parameters
    ----------
    context : `SqlQueryContext`
        Context to execute the query with.  If this context has already been
        entered, a lazy iterable will be returned that may or may not use
        server side cursors, since the context's lifetime can be used to manage
        that cursor's lifetime.  If the context has not been entered, all
        results will be fetched up front in raw form, but will still be
        processed into mappings keyed by `ColumnTag` lazily.
    sql_executable : `sqlalchemy.sql.expression.SelectBase`
        SQL query to execute; assumed to have been built by
        `lsst.daf.relation.sql.Engine.to_executable` or similar.
    row_transformer : `_SqlRowTransformer`
        Object that converts SQLAlchemy result-row mappings (with `str`
        keys and possibly-unpacked timespan values) to relation row
        mappings (with `ColumnTag` keys and `LogicalColumn` values).
    """

    def __init__(
        self,
        context: SqlQueryContext,
        sql_executable: sqlalchemy.sql.expression.SelectBase,
        row_transformer: _SqlRowTransformer,
    ):
        self._context = context
        self._sql_executable = sql_executable
        self._row_transformer = row_transformer

    def __iter__(self) -> Iterator[Mapping[ColumnTag, Any]]:
        # Hoist the bound method out of the loop.
        transform = self._row_transformer.sql_to_relation
        with self._context._db.query(self._sql_executable) as sql_result:
            raw_rows = sql_result.mappings()
            if self._context._exit_stack is None:
                # The context is not open, so nothing manages a cursor's
                # lifetime beyond this call: fetch every raw row eagerly.
                raw_rows = raw_rows.fetchall()
            for sql_row in raw_rows:
                yield transform(sql_row)

467 

468 

469class _SqlRowTransformer: 

470 """Object that converts SQLAlchemy result-row mappings to relation row 

471 mappings. 

472 

473 Parameters 

474 ---------- 

475 columns : `Iterable` [ `ColumnTag` ] 

476 Set of columns to handle. Rows must have at least these columns, but 

477 may have more. 

478 engine : `ButlerSqlEngine` 

479 Relation engine; used to transform column tags into SQL identifiers and 

480 obtain column type information. 

481 """ 

482 

483 def __init__(self, columns: Iterable[ColumnTag], engine: ButlerSqlEngine): 

484 self._scalar_columns: list[tuple[str, ColumnTag]] = [] 

485 self._timespan_columns: list[tuple[str, ColumnTag, type[TimespanDatabaseRepresentation]]] = [] 

486 for tag in columns: 

487 if is_timespan_column(tag): 

488 self._timespan_columns.append( 

489 (engine.get_identifier(tag), tag, engine.column_types.timespan_cls) 

490 ) 

491 else: 

492 self._scalar_columns.append((engine.get_identifier(tag), tag)) 

493 self._has_no_columns = not (self._scalar_columns or self._timespan_columns) 

494 

495 __slots__ = ("_scalar_columns", "_timespan_columns", "_has_no_columns") 

496 

497 def sql_to_relation(self, sql_row: Mapping[str, Any]) -> dict[ColumnTag, Any]: 

498 """Convert a result row from a SQLAlchemy result into the form expected 

499 by `lsst.daf.relation.iteration`. 

500 

501 Parameters 

502 ---------- 

503 sql_row : `Mapping` 

504 Mapping with `str` keys and possibly-unpacked values for timespan 

505 columns. 

506 

507 Returns 

508 ------- 

509 relation_row : `Mapping` 

510 Mapping with `ColumnTag` keys and `Timespan` objects for timespan 

511 columns. 

512 """ 

513 relation_row = {tag: sql_row[identifier] for identifier, tag in self._scalar_columns} 

514 for identifier, tag, db_repr_cls in self._timespan_columns: 

515 relation_row[tag] = db_repr_cls.extract(sql_row, name=identifier) 

516 return relation_row 

517 

518 def relation_to_sql(self, relation_row: Mapping[ColumnTag, Any]) -> dict[str, Any]: 

519 """Convert a `lsst.daf.relation.iteration` row into the mapping form 

520 used by SQLAlchemy. 

521 

522 Parameters 

523 ---------- 

524 relation_row : `Mapping` 

525 Mapping with `ColumnTag` keys and `Timespan` objects for timespan 

526 columns. 

527 

528 Returns 

529 ------- 

530 sql_row : `Mapping` 

531 Mapping with `str` keys and possibly-unpacked values for timespan 

532 columns. 

533 """ 

534 sql_row = {identifier: relation_row[tag] for identifier, tag in self._scalar_columns} 

535 for identifier, tag, db_repr_cls in self._timespan_columns: 

536 db_repr_cls.update(relation_row[tag], result=sql_row, name=identifier) 

537 if self._has_no_columns: 

538 sql_row["IGNORED"] = None 

539 return sql_row