Coverage for python/lsst/daf/butler/registry/queries/_sql_query_context.py: 13% of 236 statements
# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

__all__ = ("SqlQueryContext",)

import itertools
from collections.abc import Iterable, Iterator, Mapping, Set
from contextlib import ExitStack
from typing import TYPE_CHECKING, Any, cast

import sqlalchemy
from lsst.daf.relation import (
    BinaryOperationRelation,
    Calculation,
    Chain,
    ColumnTag,
    Deduplication,
    Engine,
    EngineError,
    Join,
    MarkerRelation,
    Projection,
    Relation,
    Transfer,
    UnaryOperation,
    UnaryOperationRelation,
    iteration,
    sql,
)

from ...core import ColumnTypeInfo, LogicalColumn, TimespanDatabaseRepresentation, is_timespan_column
from ..nameShrinker import NameShrinker
from ._query_context import QueryContext
from .butler_sql_engine import ButlerSqlEngine

if TYPE_CHECKING:
    from ..interfaces import Database


class SqlQueryContext(QueryContext):
    """An implementation of `QueryContext` for `SqlRegistry`.

    Parameters
    ----------
    db : `Database`
        Object that abstracts the database engine.
    column_types : `ColumnTypeInfo`
        Information about column types that can vary with registry
        configuration.
    row_chunk_size : `int`, optional
        Number of rows to insert into temporary tables at once. If this is
        lower than ``db.get_constant_rows_max()`` it will be set to that
        value.
    """

    def __init__(
        self,
        db: Database,
        column_types: ColumnTypeInfo,
        row_chunk_size: int = 1000,
    ):
        super().__init__()
        self.sql_engine = ButlerSqlEngine(
            column_types=column_types, name_shrinker=NameShrinker(db.dialect.max_identifier_length)
        )
        self._row_chunk_size = max(row_chunk_size, db.get_constant_rows_max())
        self._db = db
        self._exit_stack: ExitStack | None = None

    def __enter__(self) -> SqlQueryContext:
        assert self._exit_stack is None, "Context manager already entered."
        self._exit_stack = ExitStack().__enter__()
        self._exit_stack.enter_context(self._db.session())
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> bool:
        assert self._exit_stack is not None, "Context manager not yet entered."
        result = self._exit_stack.__exit__(exc_type, exc_value, traceback)
        self._exit_stack = None
        return result
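    # Usage sketch (illustrative only; ``db``, ``column_types``, and
    # ``relation`` are hypothetical stand-ins, not defined in this module):
    #
    #     context = SqlQueryContext(db, column_types)
    #     with context:
    #         assert context.is_open
    #         n_rows = context.count(relation, discard=True)
    #     assert not context.is_open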
    @property
    def is_open(self) -> bool:
        return self._exit_stack is not None

    @property
    def column_types(self) -> ColumnTypeInfo:
        """Information about column types that depend on registry
        configuration (`ColumnTypeInfo`).
        """
        return self.sql_engine.column_types

    @property
    def preferred_engine(self) -> Engine:
        # Docstring inherited.
        return self.sql_engine

    def count(self, relation: Relation, *, exact: bool = True, discard: bool = False) -> int:
        # Docstring inherited.
        relation = self._strip_count_invariant_operations(relation, exact).with_only_columns(frozenset())
        if relation.engine == self.sql_engine:
            sql_executable = self.sql_engine.to_executable(
                relation, extra_columns=[sqlalchemy.sql.func.count()]
            )
            with self._db.query(sql_executable) as sql_result:
                return cast(int, sql_result.scalar_one())
        elif (rows := relation.payload) is not None:
            assert isinstance(
                rows, iteration.MaterializedRowIterable
            ), "Query guarantees that only materialized payloads are attached to its relations."
            return len(rows)
        elif discard:
            n = 0
            for _ in self.fetch_iterable(relation):
                n += 1
            return n
        else:
            raise RuntimeError(
                f"Query with relation {relation} has deferred operations that "
                "must be executed in Python in order to obtain an exact "
                "count. Pass discard=True to run the query and discard the "
                "result rows while counting them, run the query first, or "
                "pass exact=False to obtain an upper bound."
            )

    def any(self, relation: Relation, *, execute: bool = True, exact: bool = True) -> bool:
        # Docstring inherited.
        relation = self._strip_count_invariant_operations(relation, exact).with_only_columns(frozenset())[:1]
        if relation.engine == self.sql_engine:
            sql_executable = self.sql_engine.to_executable(relation)
            with self._db.query(sql_executable) as sql_result:
                return sql_result.one_or_none() is not None
        elif (rows := relation.payload) is not None:
            assert isinstance(
                rows, iteration.MaterializedRowIterable
            ), "Query guarantees that only materialized payloads are attached to its relations."
            return bool(rows)
        elif execute:
            for _ in self.fetch_iterable(relation):
                return True
            return False
        else:
            raise RuntimeError(
                f"Query with relation {relation} has deferred operations that "
                "must be executed in Python in order to obtain an exact "
                "check for whether any rows would be returned. Pass "
                "execute=True to run the query until at least "
                "one row is found, run the query first, or pass exact=False "
                "to test only whether the query _might_ have result rows."
            )

    def transfer(self, source: Relation, destination: Engine, materialize_as: str | None) -> Any:
        # Docstring inherited from lsst.daf.relation.Processor.
        if source.engine == self.sql_engine and destination == self.iteration_engine:
            return self._sql_to_iteration(source, materialize_as)
        if source.engine == self.iteration_engine and destination == self.sql_engine:
            return self._iteration_to_sql(source, materialize_as)
        raise EngineError(f"Unexpected transfer for SqlQueryContext: {source.engine} -> {destination}.")

    def materialize(self, target: Relation, name: str) -> Any:
        # Docstring inherited from lsst.daf.relation.Processor.
        if target.engine == self.sql_engine:
            sql_executable = self.sql_engine.to_executable(target)
            table_spec = self.column_types.make_relation_table_spec(target.columns)
            if self._exit_stack is None:
                raise RuntimeError("This operation requires the QueryContext to have been entered.")
            table = self._exit_stack.enter_context(self._db.temporary_table(table_spec, name=name))
            self._db.insert(table, select=sql_executable)
            payload = sql.Payload[LogicalColumn](table)
            payload.columns_available = self.sql_engine.extract_mapping(target.columns, table.columns)
            return payload
        return super().materialize(target, name)
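    # Sketch: ``materialize`` is normally invoked on our behalf by the
    # `lsst.daf.relation.Processor` machinery rather than called directly.
    # A direct call (inside an entered context, with a hypothetical
    # ``relation`` in the SQL engine) might look like:
    #
    #     payload = context.materialize(relation, "tmp_results")
    #
    # which writes the query results to a temporary table that lives until
    # the context exits.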
    def restore_columns(
        self,
        relation: Relation,
        columns_required: Set[ColumnTag],
    ) -> tuple[Relation, set[ColumnTag]]:
        # Docstring inherited.
        if relation.is_locked:
            return relation, set(relation.columns & columns_required)
        match relation:
            case UnaryOperationRelation(operation=operation, target=target):
                new_target, columns_found = self.restore_columns(target, columns_required)
                if not columns_found:
                    return relation, columns_found
                match operation:
                    case Projection(columns=columns):
                        new_columns = columns | columns_found
                        if new_columns != columns:
                            return (
                                Projection(new_columns).apply(new_target),
                                columns_found,
                            )
                    case Calculation(tag=tag):
                        if tag in columns_required:
                            columns_found.add(tag)
                    case Deduplication():
                        # Pulling key columns through the deduplication would
                        # fundamentally change what it does; have to limit
                        # ourselves to non-key columns here, and put a
                        # Projection back in place.
                        columns_found = {c for c in columns_found if not c.is_key}
                        new_target = new_target.with_only_columns(relation.columns | columns_found)
                return relation.reapply(new_target), columns_found
            case BinaryOperationRelation(operation=operation, lhs=lhs, rhs=rhs):
                new_lhs, columns_found_lhs = self.restore_columns(lhs, columns_required)
                new_rhs, columns_found_rhs = self.restore_columns(rhs, columns_required)
                match operation:
                    case Join():
                        return relation.reapply(new_lhs, new_rhs), columns_found_lhs | columns_found_rhs
                    case Chain():
                        if columns_found_lhs != columns_found_rhs:
                            # Got different answers for the two chain
                            # branches; try again with just the columns
                            # found in both branches.
                            new_columns_required = columns_found_lhs & columns_found_rhs
                            if not new_columns_required:
                                return relation, set()
                            new_lhs, columns_found_lhs = self.restore_columns(lhs, new_columns_required)
                            new_rhs, columns_found_rhs = self.restore_columns(rhs, new_columns_required)
                            assert columns_found_lhs == columns_found_rhs
                        return relation.reapply(new_lhs, new_rhs), columns_found_lhs
            case MarkerRelation(target=target):
                new_target, columns_found = self.restore_columns(target, columns_required)
                return relation.reapply(new_target), columns_found
        raise AssertionError("Match should be exhaustive and all branches should return.")
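    # Conceptual sketch (``relation`` and the tag ``x`` are hypothetical):
    # if a Projection upstream dropped a column ``x`` that a caller later
    # needs, then
    #
    #     restored, found = context.restore_columns(relation, {x})
    #
    # widens that Projection to keep ``x`` and returns ``found == {x}``;
    # columns that cannot be recovered are simply absent from ``found``.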
    def strip_postprocessing(self, relation: Relation) -> tuple[Relation, list[UnaryOperation]]:
        # Docstring inherited.
        if relation.engine != self.iteration_engine or relation.is_locked:
            return relation, []
        match relation:
            case UnaryOperationRelation(operation=operation, target=target):
                new_target, stripped = self.strip_postprocessing(target)
                stripped.append(operation)
                return new_target, stripped
            case Transfer(destination=self.iteration_engine, target=target):
                return target, []
        return relation, []

    def drop_invalidated_postprocessing(self, relation: Relation, new_columns: Set[ColumnTag]) -> Relation:
        # Docstring inherited.
        if relation.engine != self.iteration_engine or relation.is_locked:
            return relation
        match relation:
            case UnaryOperationRelation(operation=operation, target=target):
                columns_added_here = relation.columns - target.columns
                new_target = self.drop_invalidated_postprocessing(target, new_columns - columns_added_here)
                if operation.columns_required <= new_columns:
                    # This operation can stay...
                    if new_target is target:
                        # ...and nothing has actually changed upstream of it.
                        return relation
                    else:
                        # ...but we should see if it can now be performed in
                        # the preferred engine, since something it didn't
                        # commute with may have been removed.
                        return operation.apply(new_target, preferred_engine=self.preferred_engine)
                else:
                    # This operation needs to be dropped, as we're about to
                    # project away one or more of the columns it needs.
                    return new_target
        return relation
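    # Intent sketch (hypothetical columns): if an iteration-engine filter
    # requires a column that an upcoming projection will remove (say a
    # region-overlap check that needs "region"), that filter is dropped
    # here instead of being left to fail when the column disappears.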
    def _sql_to_iteration(self, source: Relation, materialize_as: str | None) -> iteration.RowIterable:
        """Execute a relation transfer from the SQL engine to the iteration
        engine.

        Parameters
        ----------
        source : `~lsst.daf.relation.Relation`
            Input relation, in what is assumed to be `sql_engine`.
        materialize_as : `str` or `None`
            If not `None`, the name of a persistent materialization to apply
            in the iteration engine. This fetches all rows up front.

        Returns
        -------
        destination : `~lsst.daf.relation.iteration.RowIterable`
            Iteration engine payload iterable with the same content as the
            given SQL relation.
        """
        sql_executable = self.sql_engine.to_executable(source)
        rows: iteration.RowIterable = _SqlRowIterable(
            self, sql_executable, _SqlRowTransformer(source.columns, self.sql_engine)
        )
        if materialize_as is not None:
            rows = rows.materialized()
        return rows

    def _iteration_to_sql(self, source: Relation, materialize_as: str | None) -> sql.Payload:
        """Execute a relation transfer from the iteration engine to the SQL
        engine.

        Parameters
        ----------
        source : `~lsst.daf.relation.Relation`
            Input relation, in what is assumed to be `iteration_engine`.
        materialize_as : `str` or `None`
            If not `None`, the name of a persistent materialization to apply
            in the SQL engine. This sets the name of the temporary table
            and ensures that one is used (instead of `Database.constant_rows`,
            which might otherwise be used for small row sets).

        Returns
        -------
        destination : `~lsst.daf.relation.sql.Payload`
            SQL engine payload struct with the same content as the given
            iteration relation.
        """
        iterable = self.iteration_engine.execute(source)
        table_spec = self.column_types.make_relation_table_spec(source.columns)
        row_transformer = _SqlRowTransformer(source.columns, self.sql_engine)
        sql_rows = []
        iterator = iter(iterable)
        # Iterate over the first chunk of rows and transform them to hold only
        # types recognized by SQLAlchemy (i.e. expand out
        # TimespanDatabaseRepresentations). Note that this advances
        # `iterator`.
        for relation_row in itertools.islice(iterator, self._row_chunk_size):
            sql_rows.append(row_transformer.relation_to_sql(relation_row))
        name = materialize_as if materialize_as is not None else self.sql_engine.get_relation_name("upload")
        if materialize_as is not None or len(sql_rows) > self._db.get_constant_rows_max():
            if self._exit_stack is None:
                raise RuntimeError("This operation requires the QueryContext to have been entered.")
            # Either we're being asked to insert into a temporary table with a
            # "persistent" lifetime (the duration of the QueryContext), or we
            # have enough rows that the database says we can't use its
            # "constant_rows" construct (e.g. VALUES). Note that
            # `SqlQueryContext.__init__` guarantees that `self._row_chunk_size`
            # is greater than or equal to `Database.get_constant_rows_max`.
            table = self._exit_stack.enter_context(self._db.temporary_table(table_spec, name))
            while sql_rows:
                self._db.insert(table, *sql_rows)
                sql_rows.clear()
                for relation_row in itertools.islice(iterator, self._row_chunk_size):
                    sql_rows.append(row_transformer.relation_to_sql(relation_row))
            payload = sql.Payload[LogicalColumn](table)
        else:
            # Small number of rows; use Database.constant_rows.
            payload = sql.Payload[LogicalColumn](self._db.constant_rows(table_spec.fields, *sql_rows))
        payload.columns_available = self.sql_engine.extract_mapping(
            source.columns, payload.from_clause.columns
        )
        return payload
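    # Worked illustration of the branch above (the numbers are hypothetical):
    # with row_chunk_size=1000 and a backend whose get_constant_rows_max()
    # is 100, an upload of 50 rows goes through Database.constant_rows (a
    # VALUES-like construct), while 5000 rows go into a temporary table,
    # inserted in chunks of at most 1000 rows.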
    def _strip_empty_invariant_operations(
        self,
        relation: Relation,
        exact: bool,
    ) -> Relation:
        """Return a modified relation tree that strips from the end of the
        tree any relation operations that cannot affect whether the relation
        has any rows at all.

        Parameters
        ----------
        relation : `Relation`
            Original relation tree.
        exact : `bool`
            If `False`, strip all iteration-engine operations, even those
            that can remove all rows.

        Returns
        -------
        modified : `Relation`
            Modified relation tree.
        """
        if relation.payload is not None or relation.is_locked:
            return relation
        match relation:
            case UnaryOperationRelation(operation=operation, target=target):
                if operation.is_empty_invariant or (not exact and target.engine == self.iteration_engine):
                    return self._strip_empty_invariant_operations(target, exact)
            case Transfer(target=target):
                return self._strip_empty_invariant_operations(target, exact)
            case MarkerRelation(target=target):
                return relation.reapply(self._strip_empty_invariant_operations(target, exact))
        return relation

    def _strip_count_invariant_operations(
        self,
        relation: Relation,
        exact: bool,
    ) -> Relation:
        """Return a modified relation tree that strips from the end of the
        tree any relation operations that cannot affect the number of rows.

        Parameters
        ----------
        relation : `Relation`
            Original relation tree.
        exact : `bool`
            If `False`, strip all iteration-engine operations, even those
            that can affect the number of rows.

        Returns
        -------
        modified : `Relation`
            Modified relation tree.
        """
        if relation.payload is not None or relation.is_locked:
            return relation
        match relation:
            case UnaryOperationRelation(operation=operation, target=target):
                if operation.is_count_invariant or (not exact and target.engine == self.iteration_engine):
                    return self._strip_count_invariant_operations(target, exact)
            case Transfer(target=target):
                return self._strip_count_invariant_operations(target, exact)
            case MarkerRelation(target=target):
                return relation.reapply(self._strip_count_invariant_operations(target, exact))
        return relation
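    # Example of the distinction (an assumption about lsst.daf.relation
    # operation flags, not stated in this module): a Projection changes
    # neither emptiness nor row count and may be stripped by both methods,
    # while a Deduplication preserves emptiness but not the count, so only
    # _strip_empty_invariant_operations may strip it.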


class _SqlRowIterable(iteration.RowIterable):
    """An implementation of `lsst.daf.relation.iteration.RowIterable` that
    executes a SQL query.

    Parameters
    ----------
    context : `SqlQueryContext`
        Context to execute the query with. If this context has already been
        entered, a lazy iterable will be returned that may or may not use
        server side cursors, since the context's lifetime can be used to
        manage that cursor's lifetime. If the context has not been entered,
        all results will be fetched up front in raw form, but will still be
        processed into mappings keyed by `ColumnTag` lazily.
    sql_executable : `sqlalchemy.sql.expression.SelectBase`
        SQL query to execute; assumed to have been built by
        `lsst.daf.relation.sql.Engine.to_executable` or similar.
    row_transformer : `_SqlRowTransformer`
        Object that converts SQLAlchemy result-row mappings (with `str`
        keys and possibly-unpacked timespan values) to relation row
        mappings (with `ColumnTag` keys and `LogicalColumn` values).
    """

    def __init__(
        self,
        context: SqlQueryContext,
        sql_executable: sqlalchemy.sql.expression.SelectBase,
        row_transformer: _SqlRowTransformer,
    ):
        self._context = context
        self._sql_executable = sql_executable
        self._row_transformer = row_transformer

    def __iter__(self) -> Iterator[Mapping[ColumnTag, Any]]:
        if self._context._exit_stack is None:
            # Have to read results into memory and close database connection.
            with self._context._db.query(self._sql_executable) as sql_result:
                rows = sql_result.mappings().fetchall()
            for sql_row in rows:
                yield self._row_transformer.sql_to_relation(sql_row)
        else:
            with self._context._db.query(self._sql_executable) as sql_result:
                raw_rows = sql_result.mappings()
                for sql_row in raw_rows:
                    yield self._row_transformer.sql_to_relation(sql_row)
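    # Behavior note: the first branch fetches everything eagerly so the
    # connection can be released before iteration begins; the second streams
    # rows from the still-open cursor for the lifetime of the entered
    # context.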


class _SqlRowTransformer:
    """Object that converts SQLAlchemy result-row mappings to relation row
    mappings.

    Parameters
    ----------
    columns : `~collections.abc.Iterable` [ `ColumnTag` ]
        Set of columns to handle. Rows must have at least these columns, but
        may have more.
    engine : `ButlerSqlEngine`
        Relation engine; used to transform column tags into SQL identifiers
        and obtain column type information.
    """

    def __init__(self, columns: Iterable[ColumnTag], engine: ButlerSqlEngine):
        self._scalar_columns: list[tuple[str, ColumnTag]] = []
        self._timespan_columns: list[tuple[str, ColumnTag, type[TimespanDatabaseRepresentation]]] = []
        for tag in columns:
            if is_timespan_column(tag):
                self._timespan_columns.append(
                    (engine.get_identifier(tag), tag, engine.column_types.timespan_cls)
                )
            else:
                self._scalar_columns.append((engine.get_identifier(tag), tag))
        self._has_no_columns = not (self._scalar_columns or self._timespan_columns)

    __slots__ = ("_scalar_columns", "_timespan_columns", "_has_no_columns")

    def sql_to_relation(self, sql_row: sqlalchemy.RowMapping) -> dict[ColumnTag, Any]:
        """Convert a result row from a SQLAlchemy result into the form
        expected by `lsst.daf.relation.iteration`.

        Parameters
        ----------
        sql_row : `~collections.abc.Mapping`
            Mapping with `str` keys and possibly-unpacked values for timespan
            columns.

        Returns
        -------
        relation_row : `~collections.abc.Mapping`
            Mapping with `ColumnTag` keys and `Timespan` objects for timespan
            columns.
        """
        relation_row = {tag: sql_row[identifier] for identifier, tag in self._scalar_columns}
        for identifier, tag, db_repr_cls in self._timespan_columns:
            relation_row[tag] = db_repr_cls.extract(sql_row, name=identifier)
        return relation_row

    def relation_to_sql(self, relation_row: Mapping[ColumnTag, Any]) -> dict[str, Any]:
        """Convert a `lsst.daf.relation.iteration` row into the mapping form
        used by SQLAlchemy.

        Parameters
        ----------
        relation_row : `~collections.abc.Mapping`
            Mapping with `ColumnTag` keys and `Timespan` objects for timespan
            columns.

        Returns
        -------
        sql_row : `~collections.abc.Mapping`
            Mapping with `str` keys and possibly-unpacked values for timespan
            columns.
        """
        sql_row = {identifier: relation_row[tag] for identifier, tag in self._scalar_columns}
        for identifier, tag, db_repr_cls in self._timespan_columns:
            db_repr_cls.update(relation_row[tag], result=sql_row, name=identifier)
        if self._has_no_columns:
            sql_row["IGNORED"] = None
        return sql_row
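# Round-trip sketch for _SqlRowTransformer (hedged; ``engine``, ``columns``,
# and ``relation_row`` are hypothetical, and sql_to_relation is typed for a
# SQLAlchemy RowMapping even though any mapping with the right keys would
# index the same way):
#
#     transformer = _SqlRowTransformer(columns, engine)
#     sql_row = transformer.relation_to_sql(relation_row)
#     assert transformer.sql_to_relation(sql_row) == relation_row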