Coverage for python/lsst/daf/butler/registry/queries/_sql_query_context.py: 12%
235 statements

# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

__all__ = ("SqlQueryContext",)

import itertools
from collections.abc import Iterable, Iterator, Mapping, Set
from contextlib import ExitStack
from typing import TYPE_CHECKING, Any, cast

import sqlalchemy
from lsst.daf.relation import (
    BinaryOperationRelation,
    Calculation,
    Chain,
    ColumnTag,
    Deduplication,
    Engine,
    EngineError,
    Join,
    MarkerRelation,
    Projection,
    Relation,
    Transfer,
    UnaryOperation,
    UnaryOperationRelation,
    iteration,
    sql,
)

from ...core import ColumnTypeInfo, LogicalColumn, TimespanDatabaseRepresentation, is_timespan_column
from ..nameShrinker import NameShrinker
from ._query_context import QueryContext
from .butler_sql_engine import ButlerSqlEngine

if TYPE_CHECKING:
    from ..interfaces import Database


class SqlQueryContext(QueryContext):
    """An implementation of `sql.QueryContext` for `SqlRegistry`.

    Parameters
    ----------
    db : `Database`
        Object that abstracts the database engine.
    column_types : `ColumnTypeInfo`
        Information about column types that can vary with registry
        configuration.
    row_chunk_size : `int`, optional
        Number of rows to insert into temporary tables at once. If this is
        lower than ``db.get_constant_rows_max()`` it will be set to that value.
    """

    def __init__(
        self,
        db: Database,
        column_types: ColumnTypeInfo,
        row_chunk_size: int = 1000,
    ):
        super().__init__()
        self.sql_engine = ButlerSqlEngine(
            column_types=column_types, name_shrinker=NameShrinker(db.dialect.max_identifier_length)
        )
        self._row_chunk_size = max(row_chunk_size, db.get_constant_rows_max())
        self._db = db
        self._exit_stack: ExitStack | None = None

    def __enter__(self) -> SqlQueryContext:
        assert self._exit_stack is None, "Context manager already entered."
        self._exit_stack = ExitStack().__enter__()
        self._exit_stack.enter_context(self._db.session())
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> bool:
        assert self._exit_stack is not None, "Context manager not yet entered."
        result = self._exit_stack.__exit__(exc_type, exc_value, traceback)
        self._exit_stack = None
        return result
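
    # The __enter__/__exit__ pair above delegates all lifetime management to a
    # contextlib.ExitStack. A minimal, self-contained sketch of that pattern
    # follows; it is illustrative only, and `fake_session` is an invented
    # stand-in for `Database.session`, not part of daf_butler.
    @staticmethod
    def _sketch_exit_stack_pattern() -> None:
        from contextlib import ExitStack, contextmanager

        @contextmanager
        def fake_session():
            print("session opened")
            try:
                yield
            finally:
                print("session closed")

        stack = ExitStack().__enter__()
        stack.enter_context(fake_session())
        # ... queries would run here, sharing the one open session ...
        stack.__exit__(None, None, None)  # closes everything registered above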

    @property
    def is_open(self) -> bool:
        return self._exit_stack is not None

    @property
    def column_types(self) -> ColumnTypeInfo:
        """Information about column types that depend on registry
        configuration (`ColumnTypeInfo`).
        """
        return self.sql_engine.column_types

    @property
    def preferred_engine(self) -> Engine:
        # Docstring inherited.
        return self.sql_engine

    def count(self, relation: Relation, *, exact: bool = True, discard: bool = False) -> int:
        # Docstring inherited.
        relation = self._strip_count_invariant_operations(relation, exact).with_only_columns(frozenset())
        if relation.engine == self.sql_engine:
            sql_executable = self.sql_engine.to_executable(
                relation, extra_columns=[sqlalchemy.sql.func.count()]
            )
            with self._db.query(sql_executable) as sql_result:
                return cast(int, sql_result.scalar_one())
        elif (rows := relation.payload) is not None:
            assert isinstance(
                rows, iteration.MaterializedRowIterable
            ), "Query guarantees that only materialized payloads are attached to its relations."
            return len(rows)
        elif discard:
            n = 0
            for _ in self.fetch_iterable(relation):
                n += 1
            return n
        else:
            raise RuntimeError(
                f"Query with relation {relation} has deferred operations that "
                "must be executed in Python in order to obtain an exact "
                "count. Pass discard=True to run the query and discard the "
                "result rows while counting them, run the query first, or "
                "pass exact=False to obtain an upper bound."
            )
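
    # When the relation is fully in the SQL engine, count() above pushes the
    # counting into the database rather than fetching rows. A self-contained
    # sketch of that idea with plain SQLAlchemy; the in-memory table and data
    # are invented for illustration.
    @staticmethod
    def _sketch_server_side_count() -> int:
        import sqlalchemy

        engine = sqlalchemy.create_engine("sqlite://")
        metadata = sqlalchemy.MetaData()
        demo = sqlalchemy.Table("demo", metadata, sqlalchemy.Column("x", sqlalchemy.Integer))
        metadata.create_all(engine)
        with engine.begin() as connection:
            connection.execute(sqlalchemy.insert(demo), [{"x": 1}, {"x": 2}, {"x": 3}])
            # SELECT count(*) FROM demo: only one scalar crosses the wire.
            statement = sqlalchemy.select(sqlalchemy.func.count()).select_from(demo)
            return connection.execute(statement).scalar_one()  # -> 3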

    def any(self, relation: Relation, *, execute: bool = True, exact: bool = True) -> bool:
        # Docstring inherited.
        relation = self._strip_count_invariant_operations(relation, exact).with_only_columns(frozenset())[:1]
        if relation.engine == self.sql_engine:
            sql_executable = self.sql_engine.to_executable(relation)
            with self._db.query(sql_executable) as sql_result:
                return sql_result.one_or_none() is not None
        elif (rows := relation.payload) is not None:
            assert isinstance(
                rows, iteration.MaterializedRowIterable
            ), "Query guarantees that only materialized payloads are attached to its relations."
            return bool(rows)
        elif execute:
            for _ in self.fetch_iterable(relation):
                return True
            return False
        else:
            raise RuntimeError(
                f"Query with relation {relation} has deferred operations that "
                "must be executed in Python in order to obtain an exact "
                "check for whether any rows would be returned. Pass "
                "execute=True to run the query until at least "
                "one row is found, run the query first, or pass exact=False "
                "to test only whether the query _might_ have result rows."
            )
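
    # any() above applies the same strategy, but truncates the query to a
    # single row (the [:1] becomes LIMIT 1 in SQL) and only checks whether
    # that row exists. A self-contained SQLAlchemy sketch; the table is
    # invented for illustration.
    @staticmethod
    def _sketch_any_via_limit() -> bool:
        import sqlalchemy

        engine = sqlalchemy.create_engine("sqlite://")
        metadata = sqlalchemy.MetaData()
        demo = sqlalchemy.Table("demo", metadata, sqlalchemy.Column("x", sqlalchemy.Integer))
        metadata.create_all(engine)
        with engine.begin() as connection:
            connection.execute(sqlalchemy.insert(demo), [{"x": 1}])
            # LIMIT 1 lets the database stop as soon as one row is found.
            row = connection.execute(sqlalchemy.select(demo).limit(1)).one_or_none()
            return row is not None  # -> True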

    def transfer(self, source: Relation, destination: Engine, materialize_as: str | None) -> Any:
        # Docstring inherited from lsst.daf.relation.Processor.
        if source.engine == self.sql_engine and destination == self.iteration_engine:
            return self._sql_to_iteration(source, materialize_as)
        if source.engine == self.iteration_engine and destination == self.sql_engine:
            return self._iteration_to_sql(source, materialize_as)
        raise EngineError(f"Unexpected transfer for SqlQueryContext: {source.engine} -> {destination}.")

    def materialize(self, target: Relation, name: str) -> Any:
        # Docstring inherited from lsst.daf.relation.Processor.
        if target.engine == self.sql_engine:
            sql_executable = self.sql_engine.to_executable(target)
            table_spec = self.column_types.make_relation_table_spec(target.columns)
            if self._exit_stack is None:
                raise RuntimeError("This operation requires the QueryContext to have been entered.")
            table = self._exit_stack.enter_context(self._db.temporary_table(table_spec, name=name))
            self._db.insert(table, select=sql_executable)
            payload = sql.Payload[LogicalColumn](table)
            payload.columns_available = self.sql_engine.extract_mapping(target.columns, table.columns)
            return payload
        return super().materialize(target, name)
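
    # materialize() above is essentially INSERT ... SELECT into a temporary
    # table whose lifetime is tied to the context's exit stack. A minimal
    # sketch of the INSERT-from-SELECT half with plain SQLAlchemy; both tables
    # are invented for illustration.
    @staticmethod
    def _sketch_insert_from_select() -> None:
        import sqlalchemy

        engine = sqlalchemy.create_engine("sqlite://")
        metadata = sqlalchemy.MetaData()
        src = sqlalchemy.Table("src", metadata, sqlalchemy.Column("x", sqlalchemy.Integer))
        dst = sqlalchemy.Table("dst", metadata, sqlalchemy.Column("x", sqlalchemy.Integer))
        metadata.create_all(engine)
        with engine.begin() as connection:
            connection.execute(sqlalchemy.insert(src), [{"x": 1}, {"x": 2}])
            # Rows flow table-to-table inside the database, never into Python.
            connection.execute(dst.insert().from_select(["x"], sqlalchemy.select(src.c.x)))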

    def restore_columns(
        self,
        relation: Relation,
        columns_required: Set[ColumnTag],
    ) -> tuple[Relation, set[ColumnTag]]:
        # Docstring inherited.
        if relation.is_locked:
            return relation, set(relation.columns & columns_required)
        match relation:
            case UnaryOperationRelation(operation=operation, target=target):
                new_target, columns_found = self.restore_columns(target, columns_required)
                if not columns_found:
                    return relation, columns_found
                match operation:
                    case Projection(columns=columns):
                        new_columns = columns | columns_found
                        if new_columns != columns:
                            return (
                                Projection(new_columns).apply(new_target),
                                columns_found,
                            )
                    case Calculation(tag=tag):
                        if tag in columns_required:
                            columns_found.add(tag)
                    case Deduplication():
                        # Pulling key columns through the deduplication would
                        # fundamentally change what it does; have to limit
                        # ourselves to non-key columns here, and put a
                        # Projection back in place.
                        columns_found = {c for c in columns_found if not c.is_key}
                        new_target = new_target.with_only_columns(relation.columns | columns_found)
                return relation.reapply(new_target), columns_found
            case BinaryOperationRelation(operation=operation, lhs=lhs, rhs=rhs):
                new_lhs, columns_found_lhs = self.restore_columns(lhs, columns_required)
                new_rhs, columns_found_rhs = self.restore_columns(rhs, columns_required)
                match operation:
                    case Join():
                        return relation.reapply(new_lhs, new_rhs), columns_found_lhs | columns_found_rhs
                    case Chain():
                        if columns_found_lhs != columns_found_rhs:
                            # Got different answers for the two chain
                            # branches; try again with just the columns
                            # found in both branches.
                            new_columns_required = columns_found_lhs & columns_found_rhs
                            if not new_columns_required:
                                return relation, set()
                            new_lhs, columns_found_lhs = self.restore_columns(lhs, new_columns_required)
                            new_rhs, columns_found_rhs = self.restore_columns(rhs, new_columns_required)
                            assert columns_found_lhs == columns_found_rhs
                        return relation.reapply(new_lhs, new_rhs), columns_found_lhs
            case MarkerRelation(target=target):
                new_target, columns_found = self.restore_columns(target, columns_required)
                return relation.reapply(new_target), columns_found
        raise AssertionError("Match should be exhaustive and all branches should return.")
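
    # restore_columns() and the _strip_* helpers below share one traversal
    # shape: structural pattern matching walks down the tree and nodes are
    # rebuilt (or dropped) on the way back up. A toy, self-contained version
    # of that shape; the Leaf/Unary classes are invented for illustration.
    @staticmethod
    def _sketch_tree_rewrite() -> None:
        from dataclasses import dataclass

        @dataclass
        class Leaf:
            name: str

        @dataclass
        class Unary:
            label: str
            target: object

        def strip_labeled(node: object, label: str) -> object:
            match node:
                case Unary(label=found, target=target) if found == label:
                    return strip_labeled(target, label)  # drop this node
                case Unary(label=found, target=target):
                    return Unary(found, strip_labeled(target, label))  # rebuild
            return node  # leaves pass through unchanged

        tree = Unary("keep", Unary("drop", Leaf("base")))
        assert strip_labeled(tree, "drop") == Unary("keep", Leaf("base"))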

    def strip_postprocessing(self, relation: Relation) -> tuple[Relation, list[UnaryOperation]]:
        # Docstring inherited.
        if relation.engine != self.iteration_engine or relation.is_locked:
            return relation, []
        match relation:
            case UnaryOperationRelation(operation=operation, target=target):
                new_target, stripped = self.strip_postprocessing(target)
                stripped.append(operation)
                return new_target, stripped
            case Transfer(destination=self.iteration_engine, target=target):
                return target, []
        return relation, []

    def drop_invalidated_postprocessing(self, relation: Relation, new_columns: Set[ColumnTag]) -> Relation:
        # Docstring inherited.
        if relation.engine != self.iteration_engine or relation.is_locked:
            return relation
        match relation:
            case UnaryOperationRelation(operation=operation, target=target):
                columns_added_here = relation.columns - target.columns
                new_target = self.drop_invalidated_postprocessing(target, new_columns - columns_added_here)
                if operation.columns_required <= new_columns:
                    # This operation can stay.
                    return relation.reapply(new_target)
                else:
                    # This operation needs to be dropped, as we're about to
                    # project away one or more of the columns it needs.
                    return new_target
        return relation

    def _sql_to_iteration(self, source: Relation, materialize_as: str | None) -> iteration.RowIterable:
        """Execute a relation transfer from the SQL engine to the iteration
        engine.

        Parameters
        ----------
        source : `~lsst.daf.relation.Relation`
            Input relation, in what is assumed to be `sql_engine`.
        materialize_as : `str` or `None`
            If not `None`, the name of a persistent materialization to apply
            in the iteration engine. This fetches all rows up front.

        Returns
        -------
        destination : `~lsst.daf.relation.iteration.RowIterable`
            Iteration engine payload iterable with the same content as the
            given SQL relation.
        """
        sql_executable = self.sql_engine.to_executable(source)
        rows: iteration.RowIterable = _SqlRowIterable(
            self, sql_executable, _SqlRowTransformer(source.columns, self.sql_engine)
        )
        if materialize_as is not None:
            rows = rows.materialized()
        return rows

    def _iteration_to_sql(self, source: Relation, materialize_as: str | None) -> sql.Payload:
        """Execute a relation transfer from the iteration engine to the SQL
        engine.

        Parameters
        ----------
        source : `~lsst.daf.relation.Relation`
            Input relation, in what is assumed to be `iteration_engine`.
        materialize_as : `str` or `None`
            If not `None`, the name of a persistent materialization to apply
            in the SQL engine. This sets the name of the temporary table
            and ensures that one is used (instead of `Database.constant_rows`,
            which might otherwise be used for small row sets).

        Returns
        -------
        destination : `~lsst.daf.relation.sql.Payload`
            SQL engine payload struct with the same content as the given
            iteration relation.
        """
        iterable = self.iteration_engine.execute(source)
        table_spec = self.column_types.make_relation_table_spec(source.columns)
        row_transformer = _SqlRowTransformer(source.columns, self.sql_engine)
        sql_rows = []
        iterator = iter(iterable)
        # Iterate over the first chunk of rows and transform them to hold only
        # types recognized by SQLAlchemy (i.e. expand out
        # TimespanDatabaseRepresentations). Note that this advances
        # `iterator`.
        for relation_row in itertools.islice(iterator, self._row_chunk_size):
            sql_rows.append(row_transformer.relation_to_sql(relation_row))
        name = materialize_as if materialize_as is not None else self.sql_engine.get_relation_name("upload")
        if materialize_as is not None or len(sql_rows) > self._db.get_constant_rows_max():
            if self._exit_stack is None:
                raise RuntimeError("This operation requires the QueryContext to have been entered.")
            # Either we're being asked to insert into a temporary table with a
            # "persistent" lifetime (the duration of the QueryContext), or we
            # have enough rows that the database says we can't use its
            # "constant_rows" construct (e.g. VALUES). Note that
            # `SqlQueryContext.__init__` guarantees that `self._row_chunk_size`
            # is greater than or equal to `Database.get_constant_rows_max`.
            table = self._exit_stack.enter_context(self._db.temporary_table(table_spec, name))
            while sql_rows:
                self._db.insert(table, *sql_rows)
                sql_rows.clear()
                for relation_row in itertools.islice(iterator, self._row_chunk_size):
                    sql_rows.append(row_transformer.relation_to_sql(relation_row))
            payload = sql.Payload[LogicalColumn](table)
        else:
            # Small number of rows; use Database.constant_rows.
            payload = sql.Payload[LogicalColumn](self._db.constant_rows(table_spec.fields, *sql_rows))
        payload.columns_available = self.sql_engine.extract_mapping(
            source.columns, payload.from_clause.columns
        )
        return payload
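
    # The upload loop above works by letting itertools.islice repeatedly
    # advance one shared iterator, so rows are buffered and inserted in
    # fixed-size chunks without ever holding the full result set in memory.
    # A self-contained sketch of that chunking idiom:
    @staticmethod
    def _sketch_chunked_consumption() -> list[list[int]]:
        import itertools

        chunk_size = 3
        iterator = iter(range(8))  # stand-in for the relation's row iterable
        chunks: list[list[int]] = []
        chunk = list(itertools.islice(iterator, chunk_size))
        while chunk:
            chunks.append(chunk)  # stand-in for self._db.insert(table, *rows)
            chunk = list(itertools.islice(iterator, chunk_size))
        return chunks  # -> [[0, 1, 2], [3, 4, 5], [6, 7]]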

    def _strip_empty_invariant_operations(
        self,
        relation: Relation,
        exact: bool,
    ) -> Relation:
        """Return a modified relation tree in which operations that cannot
        change whether the relation has any rows are stripped from the end of
        the tree.

        Parameters
        ----------
        relation : `Relation`
            Original relation tree.
        exact : `bool`
            If `False`, strip all iteration-engine operations, even those that
            can remove all rows.

        Returns
        -------
        modified : `Relation`
            Modified relation tree.
        """
        if relation.payload is not None or relation.is_locked:
            return relation
        match relation:
            case UnaryOperationRelation(operation=operation, target=target):
                if operation.is_empty_invariant or (not exact and target.engine == self.iteration_engine):
                    return self._strip_empty_invariant_operations(target, exact)
            case Transfer(target=target):
                return self._strip_empty_invariant_operations(target, exact)
            case MarkerRelation(target=target):
                return relation.reapply(self._strip_empty_invariant_operations(target, exact))
        return relation

    def _strip_count_invariant_operations(
        self,
        relation: Relation,
        exact: bool,
    ) -> Relation:
        """Return a modified relation tree in which operations that cannot
        change the number of result rows are stripped from the end of the
        tree.

        Parameters
        ----------
        relation : `Relation`
            Original relation tree.
        exact : `bool`
            If `False`, strip all iteration-engine operations, even those that
            can affect the number of rows.

        Returns
        -------
        modified : `Relation`
            Modified relation tree.
        """
        if relation.payload is not None or relation.is_locked:
            return relation
        match relation:
            case UnaryOperationRelation(operation=operation, target=target):
                if operation.is_count_invariant or (not exact and target.engine == self.iteration_engine):
                    return self._strip_count_invariant_operations(target, exact)
            case Transfer(target=target):
                return self._strip_count_invariant_operations(target, exact)
            case MarkerRelation(target=target):
                return relation.reapply(self._strip_count_invariant_operations(target, exact))
        return relation


class _SqlRowIterable(iteration.RowIterable):
    """An implementation of `lsst.daf.relation.iteration.RowIterable` that
    executes a SQL query.

    Parameters
    ----------
    context : `SqlQueryContext`
        Context to execute the query with. If this context has already been
        entered, a lazy iterable will be returned that may or may not use
        server side cursors, since the context's lifetime can be used to manage
        that cursor's lifetime. If the context has not been entered, all
        results will be fetched up front in raw form, but will still be
        processed into mappings keyed by `ColumnTag` lazily.
    sql_executable : `sqlalchemy.sql.expression.SelectBase`
        SQL query to execute; assumed to have been built by
        `lsst.daf.relation.sql.Engine.to_executable` or similar.
    row_transformer : `_SqlRowTransformer`
        Object that converts SQLAlchemy result-row mappings (with `str`
        keys and possibly-unpacked timespan values) to relation row
        mappings (with `ColumnTag` keys and `LogicalColumn` values).
    """

    def __init__(
        self,
        context: SqlQueryContext,
        sql_executable: sqlalchemy.sql.expression.SelectBase,
        row_transformer: _SqlRowTransformer,
    ):
        self._context = context
        self._sql_executable = sql_executable
        self._row_transformer = row_transformer

    def __iter__(self) -> Iterator[Mapping[ColumnTag, Any]]:
        with self._context._db.query(self._sql_executable) as sql_result:
            raw_rows = sql_result.mappings()
            if self._context._exit_stack is None:
                for sql_row in raw_rows.fetchall():
                    yield self._row_transformer.sql_to_relation(sql_row)
            else:
                for sql_row in raw_rows:
                    yield self._row_transformer.sql_to_relation(sql_row)
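

# _SqlRowIterable.__iter__ above chooses between draining the cursor eagerly
# (safe after the context and its connection are gone) and streaming lazily
# (cheaper, but the session must stay open). A self-contained sketch of that
# trade-off, with a generator standing in for the server-side cursor:
def _sketch_lazy_vs_eager(eager: bool) -> Iterator[int]:
    cursor = (n * n for n in range(4))  # invented stand-in for a SQL cursor
    if eager:
        # Fetch everything now; later iteration needs no open connection.
        yield from list(cursor)
    else:
        # Stream on demand; the "cursor" must stay alive while we iterate.
        yield from cursor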


class _SqlRowTransformer:
    """Object that converts SQLAlchemy result-row mappings to relation row
    mappings.

    Parameters
    ----------
    columns : `Iterable` [ `ColumnTag` ]
        Set of columns to handle. Rows must have at least these columns, but
        may have more.
    engine : `ButlerSqlEngine`
        Relation engine; used to transform column tags into SQL identifiers and
        obtain column type information.
    """

    def __init__(self, columns: Iterable[ColumnTag], engine: ButlerSqlEngine):
        self._scalar_columns: list[tuple[str, ColumnTag]] = []
        self._timespan_columns: list[tuple[str, ColumnTag, type[TimespanDatabaseRepresentation]]] = []
        for tag in columns:
            if is_timespan_column(tag):
                self._timespan_columns.append(
                    (engine.get_identifier(tag), tag, engine.column_types.timespan_cls)
                )
            else:
                self._scalar_columns.append((engine.get_identifier(tag), tag))
        self._has_no_columns = not (self._scalar_columns or self._timespan_columns)

    __slots__ = ("_scalar_columns", "_timespan_columns", "_has_no_columns")

    def sql_to_relation(self, sql_row: Mapping[str, Any]) -> dict[ColumnTag, Any]:
        """Convert a result row from a SQLAlchemy result into the form expected
        by `lsst.daf.relation.iteration`.

        Parameters
        ----------
        sql_row : `Mapping`
            Mapping with `str` keys and possibly-unpacked values for timespan
            columns.

        Returns
        -------
        relation_row : `Mapping`
            Mapping with `ColumnTag` keys and `Timespan` objects for timespan
            columns.
        """
        relation_row = {tag: sql_row[identifier] for identifier, tag in self._scalar_columns}
        for identifier, tag, db_repr_cls in self._timespan_columns:
            relation_row[tag] = db_repr_cls.extract(sql_row, name=identifier)
        return relation_row

    def relation_to_sql(self, relation_row: Mapping[ColumnTag, Any]) -> dict[str, Any]:
        """Convert a `lsst.daf.relation.iteration` row into the mapping form
        used by SQLAlchemy.

        Parameters
        ----------
        relation_row : `Mapping`
            Mapping with `ColumnTag` keys and `Timespan` objects for timespan
            columns.

        Returns
        -------
        sql_row : `Mapping`
            Mapping with `str` keys and possibly-unpacked values for timespan
            columns.
        """
        sql_row = {identifier: relation_row[tag] for identifier, tag in self._scalar_columns}
        for identifier, tag, db_repr_cls in self._timespan_columns:
            db_repr_cls.update(relation_row[tag], result=sql_row, name=identifier)
        if self._has_no_columns:
            sql_row["IGNORED"] = None
        return sql_row
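

# _SqlRowTransformer's two methods are inverses of each other: compound values
# (timespans) are flattened into multiple SQL columns on the way in and
# reassembled on the way out. A toy round trip with an invented two-column
# encoding, purely to show the shape of the transformation:
def _sketch_row_round_trip() -> None:
    def to_sql(row: Mapping[str, Any]) -> dict[str, Any]:
        begin, end = row["span"]  # unpack the compound value
        return {"span_begin": begin, "span_end": end, "x": row["x"]}

    def to_relation(row: Mapping[str, Any]) -> dict[str, Any]:
        return {"span": (row["span_begin"], row["span_end"]), "x": row["x"]}

    original = {"span": (1, 5), "x": 42}
    assert to_relation(to_sql(original)) == original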