# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

__all__ = ("SqlQueryContext",)

import itertools
from collections.abc import Iterable, Iterator, Mapping, Set
from contextlib import ExitStack
from typing import TYPE_CHECKING, Any

import sqlalchemy
from lsst.daf.relation import (
    BinaryOperationRelation,
    Calculation,
    Chain,
    ColumnTag,
    Deduplication,
    Engine,
    EngineError,
    Join,
    MarkerRelation,
    Projection,
    Relation,
    Transfer,
    UnaryOperation,
    UnaryOperationRelation,
    iteration,
    sql,
)

from ...core import ColumnTypeInfo, LogicalColumn, TimespanDatabaseRepresentation, is_timespan_column
from ._query_context import QueryContext
from .butler_sql_engine import ButlerSqlEngine

if TYPE_CHECKING:
    from ..interfaces import Database


class SqlQueryContext(QueryContext):
    """An implementation of `sql.QueryContext` for `SqlRegistry`.

    Parameters
    ----------
    db : `Database`
        Object that abstracts the database engine.
    column_types : `ColumnTypeInfo`
        Information about column types that can vary with registry
        configuration.
    row_chunk_size : `int`, optional
        Number of rows to insert into temporary tables at once. If this is
        lower than ``db.get_constant_rows_max()``, it will be set to that
        value.
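
    Examples
    --------
    A minimal usage sketch; ``db``, ``column_types``, and ``relation`` stand
    in for real objects obtained from a registry::

        context = SqlQueryContext(db, column_types)
        with context:
            # Within the session, relations in context.sql_engine can be
            # counted, tested for rows, or fetched.
            n = context.count(relation, discard=True)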
71 """

    def __init__(
        self,
        db: Database,
        column_types: ColumnTypeInfo,
        row_chunk_size: int = 1000,
    ):
        super().__init__()
        self.sql_engine = ButlerSqlEngine(column_types=column_types)
        self._row_chunk_size = max(row_chunk_size, db.get_constant_rows_max())
        self._db = db
        self._exit_stack: ExitStack | None = None

    def __enter__(self) -> SqlQueryContext:
        assert self._exit_stack is None, "Context manager already entered."
        self._exit_stack = ExitStack().__enter__()
        self._exit_stack.enter_context(self._db.session())
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> bool:
        assert self._exit_stack is not None, "Context manager not yet entered."
        result = self._exit_stack.__exit__(exc_type, exc_value, traceback)
        self._exit_stack = None
        return result

    @property
    def is_open(self) -> bool:
        return self._exit_stack is not None

    @property
    def column_types(self) -> ColumnTypeInfo:
        """Information about column types that depend on registry configuration
        (`ColumnTypeInfo`).
        """
        return self.sql_engine.column_types

    @property
    def preferred_engine(self) -> Engine:
        # Docstring inherited.
        return self.sql_engine

    def count(self, relation: Relation, *, exact: bool = True, discard: bool = False) -> int:
        # Docstring inherited.
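        # Strip operations that cannot change the row count (or, when
        # exact=False, any trailing iteration-engine operation) and drop all
        # columns, since only the number of rows matters here.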
        relation = self._strip_count_invariant_operations(relation, exact).with_only_columns(frozenset())
        if relation.engine == self.sql_engine:
            sql_executable = self.sql_engine.to_executable(
                relation, extra_columns=[sqlalchemy.sql.func.count()]
            )
            with self._db.query(sql_executable) as sql_result:
                return sql_result.scalar()
        elif (rows := relation.payload) is not None:
            assert isinstance(
                rows, iteration.MaterializedRowIterable
            ), "Query guarantees that only materialized payloads are attached to its relations."
            return len(rows)
        elif discard:
            n = 0
            for _ in self.fetch_iterable(relation):
                n += 1
            return n
        else:
            raise RuntimeError(
                f"Query with relation {relation} has deferred operations that "
                "must be executed in Python in order to obtain an exact "
                "count. Pass discard=True to run the query and discard the "
                "result rows while counting them, run the query first, or "
                "pass exact=False to obtain an upper bound."
            )

    def any(self, relation: Relation, *, execute: bool = True, exact: bool = True) -> bool:
        # Docstring inherited.
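        # Strip operations that cannot change the number of rows (and hence
        # whether there are any), drop all columns, and cut the relation down
        # to at most one row.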
        relation = self._strip_count_invariant_operations(relation, exact).with_only_columns(frozenset())[:1]
        if relation.engine == self.sql_engine:
            sql_executable = self.sql_engine.to_executable(relation)
            with self._db.query(sql_executable) as sql_result:
                return sql_result.one_or_none() is not None
        elif (rows := relation.payload) is not None:
            assert isinstance(
                rows, iteration.MaterializedRowIterable
            ), "Query guarantees that only materialized payloads are attached to its relations."
            return bool(rows)
        elif execute:
            for _ in self.fetch_iterable(relation):
                return True
            return False
        else:
            raise RuntimeError(
                f"Query with relation {relation} has deferred operations that "
                "must be executed in Python in order to obtain an exact "
                "check for whether any rows would be returned. Pass "
                "execute=True to run the query until at least "
                "one row is found, run the query first, or pass exact=False "
                "to test only whether the query _might_ have result rows."
            )

    def transfer(self, source: Relation, destination: Engine, materialize_as: str | None) -> Any:
        # Docstring inherited from lsst.daf.relation.Processor.
        if source.engine == self.sql_engine and destination == self.iteration_engine:
            return self._sql_to_iteration(source, materialize_as)
        if source.engine == self.iteration_engine and destination == self.sql_engine:
            return self._iteration_to_sql(source, materialize_as)
        raise EngineError(f"Unexpected transfer for SqlQueryContext: {source.engine} -> {destination}.")

    def materialize(self, target: Relation, name: str) -> Any:
        # Docstring inherited from lsst.daf.relation.Processor.
        if target.engine == self.sql_engine:
            sql_executable = self.sql_engine.to_executable(target)
            table_spec = self.column_types.make_relation_table_spec(target.columns)
            if self._exit_stack is None:
                raise RuntimeError("This operation requires the QueryContext to have been entered.")
            table = self._exit_stack.enter_context(self._db.temporary_table(table_spec, name=name))
            self._db.insert(table, select=sql_executable)
            payload = sql.Payload[LogicalColumn](table)
            payload.columns_available = self.sql_engine.extract_mapping(target.columns, table.columns)
            return payload
        return super().materialize(target, name)

    def restore_columns(
        self,
        relation: Relation,
        columns_required: Set[ColumnTag],
    ) -> tuple[Relation, set[ColumnTag]]:
        # Docstring inherited.
        if relation.is_locked:
            return relation, set(relation.columns & columns_required)
        match relation:
            case UnaryOperationRelation(operation=operation, target=target):
                new_target, columns_found = self.restore_columns(target, columns_required)
                if not columns_found:
                    return relation, columns_found
                match operation:
                    case Projection(columns=columns):
                        new_columns = columns | columns_found
                        if new_columns != columns:
                            return (
                                Projection(new_columns).apply(new_target),
                                columns_found,
                            )
                    case Calculation(tag=tag):
                        if tag in columns_required:
                            columns_found.add(tag)
                    case Deduplication():
                        # Pulling key columns through the deduplication would
                        # fundamentally change what it does; have to limit
                        # ourselves to non-key columns here, and put a
                        # Projection back in place.
                        columns_found = {c for c in columns_found if not c.is_key}
                        new_target = new_target.with_only_columns(relation.columns | columns_found)
                return relation.reapply(new_target), columns_found
            case BinaryOperationRelation(operation=operation, lhs=lhs, rhs=rhs):
                new_lhs, columns_found_lhs = self.restore_columns(lhs, columns_required)
                new_rhs, columns_found_rhs = self.restore_columns(rhs, columns_required)
                match operation:
                    case Join():
                        return relation.reapply(new_lhs, new_rhs), columns_found_lhs | columns_found_rhs
                    case Chain():
                        if columns_found_lhs != columns_found_rhs:
                            # Got different answers for the two chain
                            # branches; let's try again with just the columns
                            # found in both branches.
                            new_columns_required = columns_found_lhs & columns_found_rhs
                            if not new_columns_required:
                                return relation, set()
                            new_lhs, columns_found_lhs = self.restore_columns(lhs, new_columns_required)
                            new_rhs, columns_found_rhs = self.restore_columns(rhs, new_columns_required)
                            assert columns_found_lhs == columns_found_rhs
                        return relation.reapply(new_lhs, new_rhs), columns_found_lhs
            case MarkerRelation(target=target):
                new_target, columns_found = self.restore_columns(target, columns_required)
                return relation.reapply(new_target), columns_found
        raise AssertionError("Match should be exhaustive and all branches should return.")

    def strip_postprocessing(self, relation: Relation) -> tuple[Relation, list[UnaryOperation]]:
        # Docstring inherited.
        if relation.engine != self.iteration_engine or relation.is_locked:
            return relation, []
        match relation:
            case UnaryOperationRelation(operation=operation, target=target):
                new_target, stripped = self.strip_postprocessing(target)
                stripped.append(operation)
                return new_target, stripped
            case Transfer(destination=self.iteration_engine, target=target):
                return target, []
        return relation, []

    def drop_invalidated_postprocessing(self, relation: Relation, new_columns: Set[ColumnTag]) -> Relation:
        # Docstring inherited.
        if relation.engine != self.iteration_engine or relation.is_locked:
            return relation
        match relation:
            case UnaryOperationRelation(operation=operation, target=target):
                columns_added_here = relation.columns - target.columns
                new_target = self.drop_invalidated_postprocessing(target, new_columns - columns_added_here)
                if operation.columns_required <= new_columns:
                    # This operation can stay.
                    return relation.reapply(new_target)
                else:
                    # This operation needs to be dropped, as we're about to
                    # project away one or more of the columns it needs.
                    return new_target
        return relation

    def _sql_to_iteration(self, source: Relation, materialize_as: str | None) -> iteration.RowIterable:
        """Execute a relation transfer from the SQL engine to the iteration
        engine.

        Parameters
        ----------
        source : `~lsst.daf.relation.Relation`
            Input relation, in what is assumed to be `sql_engine`.
        materialize_as : `str` or `None`
            If not `None`, the name of a persistent materialization to apply
            in the iteration engine. This fetches all rows up front.

        Returns
        -------
        destination : `~lsst.daf.relation.iteration.RowIterable`
            Iteration engine payload iterable with the same content as the
            given SQL relation.
        """
        sql_executable = self.sql_engine.to_executable(source)
        rows: iteration.RowIterable = _SqlRowIterable(
            self, sql_executable, _SqlRowTransformer(source.columns, self.sql_engine.column_types)
        )
        if materialize_as is not None:
            rows = rows.materialized()
        return rows

    def _iteration_to_sql(self, source: Relation, materialize_as: str | None) -> sql.Payload:
        """Execute a relation transfer from the iteration engine to the SQL
        engine.

        Parameters
        ----------
        source : `~lsst.daf.relation.Relation`
            Input relation, in what is assumed to be `iteration_engine`.
        materialize_as : `str` or `None`
            If not `None`, the name of a persistent materialization to apply
            in the SQL engine. This sets the name of the temporary table
            and ensures that one is used (instead of `Database.constant_rows`,
            which might otherwise be used for small row sets).

        Returns
        -------
        destination : `~lsst.daf.relation.sql.Payload`
            SQL engine payload struct with the same content as the given
            iteration relation.
        """
        iterable = self.iteration_engine.execute(source)
        table_spec = self.column_types.make_relation_table_spec(source.columns)
        row_transformer = _SqlRowTransformer(source.columns, self.sql_engine.column_types)
        sql_rows = []
        iterator = iter(iterable)
        # Iterate over the first chunk of rows and transform them to hold only
        # types recognized by SQLAlchemy (i.e. expand out
        # TimespanDatabaseRepresentations). Note that this advances
        # `iterator`.
        for relation_row in itertools.islice(iterator, self._row_chunk_size):
            sql_rows.append(row_transformer.relation_to_sql(relation_row))
        name = materialize_as if materialize_as is not None else self.sql_engine.get_relation_name("upload")
        if materialize_as is not None or len(sql_rows) > self._db.get_constant_rows_max():
            if self._exit_stack is None:
                raise RuntimeError("This operation requires the QueryContext to have been entered.")
            # Either we're being asked to insert into a temporary table with a
            # "persistent" lifetime (the duration of the QueryContext), or we
            # have enough rows that the database says we can't use its
            # "constant_rows" construct (e.g. VALUES). Note that
            # `SqlQueryContext.__init__` guarantees that `self._row_chunk_size`
            # is greater than or equal to `Database.get_constant_rows_max`.
            table = self._exit_stack.enter_context(self._db.temporary_table(table_spec, name))
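            # Insert the rows in chunks: each pass through this loop writes
            # the current chunk to the table, clears it, and refills it from
            # the iterator, stopping when a refill produces no rows.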
            while sql_rows:
                self._db.insert(table, *sql_rows)
                sql_rows.clear()
                for relation_row in itertools.islice(iterator, self._row_chunk_size):
                    sql_rows.append(row_transformer.relation_to_sql(relation_row))
            payload = sql.Payload[LogicalColumn](table)
        else:
            # Small number of rows; use Database.constant_rows.
            payload = sql.Payload[LogicalColumn](self._db.constant_rows(table_spec.fields, *sql_rows))
        payload.columns_available = self.sql_engine.extract_mapping(
            source.columns, payload.from_clause.columns
        )
        return payload

    def _strip_empty_invariant_operations(
        self,
        relation: Relation,
        exact: bool,
    ) -> Relation:
        """Return a modified relation tree that strips, from the end of the
        tree, relation operations that do not affect whether there are any
        rows.

        Parameters
        ----------
        relation : `Relation`
            Original relation tree.
        exact : `bool`
            If `False`, strip all iteration-engine operations, even those that
            can remove all rows.

        Returns
        -------
        modified : `Relation`
            Modified relation tree.
        """
        if relation.payload is not None or relation.is_locked:
            return relation
        match relation:
            case UnaryOperationRelation(operation=operation, target=target):
                if operation.is_empty_invariant or (not exact and target.engine == self.iteration_engine):
                    return self._strip_empty_invariant_operations(target, exact)
            case Transfer(target=target):
                return self._strip_empty_invariant_operations(target, exact)
            case MarkerRelation(target=target):
                return relation.reapply(self._strip_empty_invariant_operations(target, exact))
        return relation

    def _strip_count_invariant_operations(
        self,
        relation: Relation,
        exact: bool,
    ) -> Relation:
        """Return a modified relation tree that strips, from the end of the
        tree, relation operations that do not affect the number of rows.

        Parameters
        ----------
        relation : `Relation`
            Original relation tree.
        exact : `bool`
            If `False`, strip all iteration-engine operations, even those that
            can affect the number of rows.

        Returns
        -------
        modified : `Relation`
            Modified relation tree.
        """
        if relation.payload is not None or relation.is_locked:
            return relation
        match relation:
            case UnaryOperationRelation(operation=operation, target=target):
                if operation.is_count_invariant or (not exact and target.engine == self.iteration_engine):
                    return self._strip_count_invariant_operations(target, exact)
            case Transfer(target=target):
                return self._strip_count_invariant_operations(target, exact)
            case MarkerRelation(target=target):
                return relation.reapply(self._strip_count_invariant_operations(target, exact))
        return relation


class _SqlRowIterable(iteration.RowIterable):
    """An implementation of `lsst.daf.relation.iteration.RowIterable` that
    executes a SQL query.

    Parameters
    ----------
    context : `SqlQueryContext`
        Context to execute the query with. If this context has already been
        entered, a lazy iterable will be returned that may or may not use
        server-side cursors, since the context's lifetime can be used to
        manage that cursor's lifetime. If the context has not been entered,
        all results will be fetched up front in raw form, but will still be
        processed into mappings keyed by `ColumnTag` lazily.
    sql_executable : `sqlalchemy.sql.expression.SelectBase`
        SQL query to execute; assumed to have been built by
        `lsst.daf.relation.sql.Engine.to_executable` or similar.
    row_transformer : `_SqlRowTransformer`
        Object that converts SQLAlchemy result-row mappings (with `str`
        keys and possibly-unpacked timespan values) to relation row
        mappings (with `ColumnTag` keys and `LogicalColumn` values).
    """

    def __init__(
        self,
        context: SqlQueryContext,
        sql_executable: sqlalchemy.sql.expression.SelectBase,
        row_transformer: _SqlRowTransformer,
    ):
        self._context = context
        self._sql_executable = sql_executable
        self._row_transformer = row_transformer

    def __iter__(self) -> Iterator[Mapping[ColumnTag, Any]]:
        with self._context._db.query(self._sql_executable) as sql_result:
            raw_rows = sql_result.mappings()
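            # If the context has not been entered, there is no session whose
            # lifetime can manage a lazy (possibly server-side) cursor, so
            # fetch all results up front.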
            if self._context._exit_stack is None:
                raw_rows = raw_rows.fetchall()
            for sql_row in raw_rows:
                yield self._row_transformer.sql_to_relation(sql_row)


class _SqlRowTransformer:
    """Object that converts SQLAlchemy result-row mappings to relation row
    mappings.

    Parameters
    ----------
    columns : `Iterable` [ `ColumnTag` ]
        Set of columns to handle. Rows must have at least these columns, but
        may have more.
    column_types : `ColumnTypeInfo`
        Information about column types that can vary with registry
        configuration.
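
    Examples
    --------
    Sketch of the intended round trip; ``tags`` and ``types`` stand in for
    real `ColumnTag` and `ColumnTypeInfo` instances::

        transformer = _SqlRowTransformer(tags, types)
        sql_row = transformer.relation_to_sql(relation_row)
        # Timespan columns are packed into their database representation and
        # unpacked again, so this should recover the original row.
        assert transformer.sql_to_relation(sql_row) == relation_row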
476 """

    def __init__(self, columns: Iterable[ColumnTag], column_types: ColumnTypeInfo):
        # Guard against one-shot iterators, since ``columns`` is traversed
        # more than once below.
        columns = frozenset(columns)
        self._scalar_columns = set(columns)
        self._timespan_columns: dict[ColumnTag, type[TimespanDatabaseRepresentation]] = {
            tag: column_types.timespan_cls for tag in columns if is_timespan_column(tag)
        }
        self._scalar_columns.difference_update(self._timespan_columns.keys())
        self._has_no_columns = not columns

    __slots__ = ("_scalar_columns", "_timespan_columns", "_has_no_columns")

    def sql_to_relation(self, sql_row: Mapping[str, Any]) -> dict[ColumnTag, Any]:
        """Convert a result row from a SQLAlchemy result into the form expected
        by `lsst.daf.relation.iteration`.

        Parameters
        ----------
        sql_row : `Mapping`
            Mapping with `str` keys and possibly-unpacked values for timespan
            columns.

        Returns
        -------
        relation_row : `Mapping`
            Mapping with `ColumnTag` keys and `Timespan` objects for timespan
            columns.
        """
        relation_row = {tag: sql_row[tag.qualified_name] for tag in self._scalar_columns}
        for tag, db_repr_cls in self._timespan_columns.items():
            relation_row[tag] = db_repr_cls.extract(sql_row, name=tag.qualified_name)
        return relation_row

    def relation_to_sql(self, relation_row: Mapping[ColumnTag, Any]) -> dict[str, Any]:
        """Convert a `lsst.daf.relation.iteration` row into the mapping form
        used by SQLAlchemy.

        Parameters
        ----------
        relation_row : `Mapping`
            Mapping with `ColumnTag` keys and `Timespan` objects for timespan
            columns.

        Returns
        -------
        sql_row : `Mapping`
            Mapping with `str` keys and possibly-unpacked values for timespan
            columns.
        """
        sql_row = {tag.qualified_name: relation_row[tag] for tag in self._scalar_columns}
        for tag, db_repr_cls in self._timespan_columns.items():
            db_repr_cls.update(relation_row[tag], result=sql_row, name=tag.qualified_name)
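        # A row needs at least one column to be insertable; for relations
        # with no real columns, the table spec is assumed to provide a single
        # "IGNORED" field to fill.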
        if self._has_no_columns:
            sql_row["IGNORED"] = None
        return sql_row