Coverage for python/lsst/daf/butler/registry/queries/_sql_query_context.py: 13%
236 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-10-02 08:00 +0000
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
27from __future__ import annotations
29__all__ = ("SqlQueryContext",)
31import itertools
32from collections.abc import Iterable, Iterator, Mapping, Set
33from contextlib import ExitStack
34from typing import TYPE_CHECKING, Any, cast
36import sqlalchemy
37from lsst.daf.relation import (
38 BinaryOperationRelation,
39 Calculation,
40 Chain,
41 ColumnTag,
42 Deduplication,
43 Engine,
44 EngineError,
45 Join,
46 MarkerRelation,
47 Projection,
48 Relation,
49 Transfer,
50 UnaryOperation,
51 UnaryOperationRelation,
52 iteration,
53 sql,
54)
56from ...core import ColumnTypeInfo, LogicalColumn, TimespanDatabaseRepresentation, is_timespan_column
57from ..nameShrinker import NameShrinker
58from ._query_context import QueryContext
59from .butler_sql_engine import ButlerSqlEngine
61if TYPE_CHECKING:
62 from ..interfaces import Database
class SqlQueryContext(QueryContext):
    """An implementation of `sql.QueryContext` for `SqlRegistry`.

    Parameters
    ----------
    db : `Database`
        Object that abstracts the database engine.
    column_types : `ColumnTypeInfo`
        Information about column types that can vary with registry
        configuration.
    row_chunk_size : `int`, optional
        Number of rows to insert into temporary tables at once. If this is
        lower than ``db.get_constant_rows_max()`` it will be set to that value.
    """

    def __init__(
        self,
        db: Database,
        column_types: ColumnTypeInfo,
        row_chunk_size: int = 1000,
    ):
        super().__init__()
        self.sql_engine = ButlerSqlEngine(
            column_types=column_types, name_shrinker=NameShrinker(db.dialect.max_identifier_length)
        )
        # Chunks must be at least as large as the biggest row set the
        # database can accept via its "constant rows" construct;
        # _iteration_to_sql relies on this invariant when deciding whether it
        # needs a temporary table.
        self._row_chunk_size = max(row_chunk_size, db.get_constant_rows_max())
        self._db = db
        # Owns the database session and any temporary tables created while
        # the context is entered; None when the context is not entered.
        self._exit_stack: ExitStack | None = None

    def __enter__(self) -> SqlQueryContext:
        assert self._exit_stack is None, "Context manager already entered."
        self._exit_stack = ExitStack().__enter__()
        self._exit_stack.enter_context(self._db.session())
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> bool:
        assert self._exit_stack is not None, "Context manager not yet entered."
        # Closing the stack closes the session and any temporary tables that
        # were registered on it.
        result = self._exit_stack.__exit__(exc_type, exc_value, traceback)
        self._exit_stack = None
        return result

    @property
    def is_open(self) -> bool:
        """Whether this context manager has been entered (`bool`)."""
        return self._exit_stack is not None

    @property
    def column_types(self) -> ColumnTypeInfo:
        """Information about column types that depend on registry configuration
        (`ColumnTypeInfo`).
        """
        return self.sql_engine.column_types

    @property
    def preferred_engine(self) -> Engine:
        # Docstring inherited.
        return self.sql_engine

    def count(self, relation: Relation, *, exact: bool = True, discard: bool = False) -> int:
        # Docstring inherited.
        # Operations that cannot change the number of rows may be stripped,
        # and no columns are needed just to count rows.
        relation = self._strip_count_invariant_operations(relation, exact).with_only_columns(frozenset())
        if relation.engine == self.sql_engine:
            # Everything left runs in SQL; let the database count via
            # SELECT COUNT(*).
            sql_executable = self.sql_engine.to_executable(
                relation, extra_columns=[sqlalchemy.sql.func.count()]
            )
            with self._db.query(sql_executable) as sql_result:
                return cast(int, sql_result.scalar_one())
        elif (rows := relation.payload) is not None:
            assert isinstance(
                rows, iteration.MaterializedRowIterable
            ), "Query guarantees that only materialized payloads are attached to its relations."
            return len(rows)
        elif discard:
            # Caller allowed us to execute the full query and throw the rows
            # away while counting them.
            n = 0
            for _ in self.fetch_iterable(relation):
                n += 1
            return n
        else:
            raise RuntimeError(
                f"Query with relation {relation} has deferred operations that "
                "must be executed in Python in order to obtain an exact "
                "count. Pass discard=True to run the query and discard the "
                "result rows while counting them, run the query first, or "
                "pass exact=False to obtain an upper bound."
            )

    def any(self, relation: Relation, *, execute: bool = True, exact: bool = True) -> bool:
        # Docstring inherited.
        # Only existence matters here: strip operations that cannot change
        # the row count, drop all columns, and limit to a single row.
        # NOTE(review): `_strip_empty_invariant_operations` (defined below) is
        # never called anywhere in this file; it looks like it may have been
        # intended for use here, since `any` only needs emptiness preserved.
        # Using the count-invariant variant is still correct (count-invariant
        # implies empty-invariant), just conservative — confirm upstream.
        relation = self._strip_count_invariant_operations(relation, exact).with_only_columns(frozenset())[:1]
        if relation.engine == self.sql_engine:
            sql_executable = self.sql_engine.to_executable(relation)
            with self._db.query(sql_executable) as sql_result:
                return sql_result.one_or_none() is not None
        elif (rows := relation.payload) is not None:
            assert isinstance(
                rows, iteration.MaterializedRowIterable
            ), "Query guarantees that only materialized payloads are attached to its relations."
            return bool(rows)
        elif execute:
            # Run the deferred Python operations, stopping at the first row.
            for _ in self.fetch_iterable(relation):
                return True
            return False
        else:
            raise RuntimeError(
                f"Query with relation {relation} has deferred operations that "
                "must be executed in Python in order to obtain an exact "
                "check for whether any rows would be returned. Pass "
                "execute=True to run the query until at least "
                "one row is found, run the query first, or pass exact=False "
                "test only whether the query _might_ have result rows."
            )

    def transfer(self, source: Relation, destination: Engine, materialize_as: str | None) -> Any:
        # Docstring inherited from lsst.daf.relation.Processor.
        # Only SQL <-> iteration transfers are supported, one helper each way.
        if source.engine == self.sql_engine and destination == self.iteration_engine:
            return self._sql_to_iteration(source, materialize_as)
        if source.engine == self.iteration_engine and destination == self.sql_engine:
            return self._iteration_to_sql(source, materialize_as)
        raise EngineError(f"Unexpected transfer for SqlQueryContext: {source.engine} -> {destination}.")

    def materialize(self, target: Relation, name: str) -> Any:
        # Docstring inherited from lsst.daf.relation.Processor.
        if target.engine == self.sql_engine:
            # Execute the query into a temporary table whose lifetime is tied
            # to this context's exit stack.
            sql_executable = self.sql_engine.to_executable(target)
            table_spec = self.column_types.make_relation_table_spec(target.columns)
            if self._exit_stack is None:
                raise RuntimeError("This operation requires the QueryContext to have been entered.")
            table = self._exit_stack.enter_context(self._db.temporary_table(table_spec, name=name))
            self._db.insert(table, select=sql_executable)
            payload = sql.Payload[LogicalColumn](table)
            payload.columns_available = self.sql_engine.extract_mapping(target.columns, table.columns)
            return payload
        # Non-SQL engines are handled by the base class.
        return super().materialize(target, name)

    def restore_columns(
        self,
        relation: Relation,
        columns_required: Set[ColumnTag],
    ) -> tuple[Relation, set[ColumnTag]]:
        # Docstring inherited.
        if relation.is_locked:
            # Locked relations cannot be rewritten; report what they already
            # provide.
            return relation, set(relation.columns & columns_required)
        match relation:
            case UnaryOperationRelation(operation=operation, target=target):
                new_target, columns_found = self.restore_columns(target, columns_required)
                if not columns_found:
                    return relation, columns_found
                match operation:
                    case Projection(columns=columns):
                        # Widen the projection so restored columns survive it.
                        new_columns = columns | columns_found
                        if new_columns != columns:
                            return (
                                Projection(new_columns).apply(new_target),
                                columns_found,
                            )
                    case Calculation(tag=tag):
                        if tag in columns_required:
                            columns_found.add(tag)
                    case Deduplication():
                        # Pulling key columns through the deduplication would
                        # fundamentally change what it does; have to limit
                        # ourselves to non-key columns here, and put a
                        # Projection back in place.
                        columns_found = {c for c in columns_found if not c.is_key}
                        new_target = new_target.with_only_columns(relation.columns | columns_found)
                return relation.reapply(new_target), columns_found
            case BinaryOperationRelation(operation=operation, lhs=lhs, rhs=rhs):
                new_lhs, columns_found_lhs = self.restore_columns(lhs, columns_required)
                new_rhs, columns_found_rhs = self.restore_columns(rhs, columns_required)
                match operation:
                    case Join():
                        return relation.reapply(new_lhs, new_rhs), columns_found_lhs | columns_found_rhs
                    case Chain():
                        # A union requires both branches to have the same
                        # columns.
                        if columns_found_lhs != columns_found_rhs:
                            # Got different answers for different join
                            # branches; let's try again with just columns found
                            # in both branches.
                            new_columns_required = columns_found_lhs & columns_found_rhs
                            if not new_columns_required:
                                return relation, set()
                            new_lhs, columns_found_lhs = self.restore_columns(lhs, new_columns_required)
                            new_rhs, columns_found_rhs = self.restore_columns(rhs, new_columns_required)
                            assert columns_found_lhs == columns_found_rhs
                        return relation.reapply(new_lhs, new_rhs), columns_found_lhs
            case MarkerRelation(target=target):
                new_target, columns_found = self.restore_columns(target, columns_required)
                return relation.reapply(new_target), columns_found
        raise AssertionError("Match should be exhaustive and all branches should return.")

    def strip_postprocessing(self, relation: Relation) -> tuple[Relation, list[UnaryOperation]]:
        # Docstring inherited.
        if relation.engine != self.iteration_engine or relation.is_locked:
            return relation, []
        match relation:
            case UnaryOperationRelation(operation=operation, target=target):
                new_target, stripped = self.strip_postprocessing(target)
                # Appending after the recursive call yields operations in
                # innermost-first order.
                stripped.append(operation)
                return new_target, stripped
            case Transfer(destination=self.iteration_engine, target=target):
                # Reached the transfer into the iteration engine; everything
                # below it stays as-is.
                return target, []
        return relation, []

    def drop_invalidated_postprocessing(self, relation: Relation, new_columns: Set[ColumnTag]) -> Relation:
        # Docstring inherited.
        if relation.engine != self.iteration_engine or relation.is_locked:
            return relation
        match relation:
            case UnaryOperationRelation(operation=operation, target=target):
                columns_added_here = relation.columns - target.columns
                new_target = self.drop_invalidated_postprocessing(target, new_columns - columns_added_here)
                if operation.columns_required <= new_columns:
                    # This operation can stay...
                    if new_target is target:
                        # ...and nothing has actually changed upstream of it.
                        return relation
                    else:
                        # ...but we should see if it can now be performed in
                        # the preferred engine, since something it didn't
                        # commute with may have been removed.
                        return operation.apply(new_target, preferred_engine=self.preferred_engine)
                else:
                    # This operation needs to be dropped, as we're about to
                    # project away one or more of the columns it needs.
                    return new_target
        return relation

    def _sql_to_iteration(self, source: Relation, materialize_as: str | None) -> iteration.RowIterable:
        """Execute a relation transfer from the SQL engine to the iteration
        engine.

        Parameters
        ----------
        source : `~lsst.daf.relation.Relation`
            Input relation, in what is assumed to be `sql_engine`.
        materialize_as : `str` or `None`
            If not `None`, the name of a persistent materialization to apply
            in the iteration engine. This fetches all rows up front.

        Returns
        -------
        destination : `~lsst.daf.relation.iteration.RowIterable`
            Iteration engine payload iterable with the same content as the
            given SQL relation.
        """
        sql_executable = self.sql_engine.to_executable(source)
        rows: iteration.RowIterable = _SqlRowIterable(
            self, sql_executable, _SqlRowTransformer(source.columns, self.sql_engine)
        )
        if materialize_as is not None:
            rows = rows.materialized()
        return rows

    def _iteration_to_sql(self, source: Relation, materialize_as: str | None) -> sql.Payload:
        """Execute a relation transfer from the iteration engine to the SQL
        engine.

        Parameters
        ----------
        source : `~lsst.daf.relation.Relation`
            Input relation, in what is assumed to be `iteration_engine`.
        materialize_as : `str` or `None`
            If not `None`, the name of a persistent materialization to apply
            in the SQL engine. This sets the name of the temporary table
            and ensures that one is used (instead of `Database.constant_rows`,
            which might otherwise be used for small row sets).

        Returns
        -------
        destination : `~lsst.daf.relation.sql.Payload`
            SQL engine payload struct with the same content as the given
            iteration relation.
        """
        iterable = self.iteration_engine.execute(source)
        table_spec = self.column_types.make_relation_table_spec(source.columns)
        row_transformer = _SqlRowTransformer(source.columns, self.sql_engine)
        sql_rows = []
        iterator = iter(iterable)
        # Iterate over the first chunk of rows and transform them to hold only
        # types recognized by SQLAlchemy (i.e. expand out
        # TimespanDatabaseRepresentations).  Note that this advances
        # `iterator`.
        for relation_row in itertools.islice(iterator, self._row_chunk_size):
            sql_rows.append(row_transformer.relation_to_sql(relation_row))
        name = materialize_as if materialize_as is not None else self.sql_engine.get_relation_name("upload")
        if materialize_as is not None or len(sql_rows) > self._db.get_constant_rows_max():
            if self._exit_stack is None:
                raise RuntimeError("This operation requires the QueryContext to have been entered.")
            # Either we're being asked to insert into a temporary table with a
            # "persistent" lifetime (the duration of the QueryContext), or we
            # have enough rows that the database says we can't use its
            # "constant_rows" construct (e.g. VALUES).  Note that
            # `QueryContext.__init__` guarantees that `self._row_chunk_size`
            # is greater than or equal to `Database.get_constant_rows_max`.
            table = self._exit_stack.enter_context(self._db.temporary_table(table_spec, name))
            # Insert chunk by chunk, pulling the next chunk from `iterator`
            # only after the previous one has been written.
            while sql_rows:
                self._db.insert(table, *sql_rows)
                sql_rows.clear()
                for relation_row in itertools.islice(iterator, self._row_chunk_size):
                    sql_rows.append(row_transformer.relation_to_sql(relation_row))
            payload = sql.Payload[LogicalColumn](table)
        else:
            # Small number of rows; use Database.constant_rows.
            payload = sql.Payload[LogicalColumn](self._db.constant_rows(table_spec.fields, *sql_rows))
        payload.columns_available = self.sql_engine.extract_mapping(
            source.columns, payload.from_clause.columns
        )
        return payload

    def _strip_empty_invariant_operations(
        self,
        relation: Relation,
        exact: bool,
    ) -> Relation:
        """Return a modified relation tree that strips out relation operations
        that do not affect whether there are any rows from the end of the tree.

        Parameters
        ----------
        relation : `Relation`
            Original relation tree.
        exact : `bool`
            If `False` strip all iteration-engine operations, even those that
            can remove all rows.

        Returns
        -------
        modified : `Relation`
            Modified relation tree.
        """
        if relation.payload is not None or relation.is_locked:
            # Materialized or locked relations are kept as-is.
            return relation
        match relation:
            case UnaryOperationRelation(operation=operation, target=target):
                if operation.is_empty_invariant or (not exact and target.engine == self.iteration_engine):
                    return self._strip_empty_invariant_operations(target, exact)
            # `Transfer` must be matched before the general `MarkerRelation`
            # case — presumably it is a `MarkerRelation` subclass, so order
            # matters here (TODO confirm against lsst.daf.relation).
            case Transfer(target=target):
                return self._strip_empty_invariant_operations(target, exact)
            case MarkerRelation(target=target):
                return relation.reapply(self._strip_empty_invariant_operations(target, exact))
        return relation

    def _strip_count_invariant_operations(
        self,
        relation: Relation,
        exact: bool,
    ) -> Relation:
        """Return a modified relation tree that strips out relation operations
        that do not affect the number of rows from the end of the tree.

        Parameters
        ----------
        relation : `Relation`
            Original relation tree.
        exact : `bool`
            If `False` strip all iteration-engine operations, even those that
            can affect the number of rows.

        Returns
        -------
        modified : `Relation`
            Modified relation tree.
        """
        if relation.payload is not None or relation.is_locked:
            # Materialized or locked relations are kept as-is.
            return relation
        match relation:
            case UnaryOperationRelation(operation=operation, target=target):
                if operation.is_count_invariant or (not exact and target.engine == self.iteration_engine):
                    return self._strip_count_invariant_operations(target, exact)
            # As above: `Transfer` is matched before `MarkerRelation`.
            case Transfer(target=target):
                return self._strip_count_invariant_operations(target, exact)
            case MarkerRelation(target=target):
                return relation.reapply(self._strip_count_invariant_operations(target, exact))
        return relation
class _SqlRowIterable(iteration.RowIterable):
    """An implementation of `lsst.daf.relation.iteration.RowIterable` that
    executes a SQL query.

    Parameters
    ----------
    context : `SqlQueryContext`
        Context to execute the query with.  If this context has already been
        entered, a lazy iterable will be returned that may or may not use
        server side cursors, since the context's lifetime can be used to
        manage that cursor's lifetime.  If the context has not been entered,
        all results will be fetched up front in raw form, but will still be
        processed into mappings keyed by `ColumnTag` lazily.
    sql_executable : `sqlalchemy.sql.expression.SelectBase`
        SQL query to execute; assumed to have been built by
        `lsst.daf.relation.sql.Engine.to_executable` or similar.
    row_transformer : `_SqlRowTransformer`
        Object that converts SQLAlchemy result-row mappings (with `str` keys
        and possibly-unpacked timespan values) to relation row mappings (with
        `ColumnTag` keys and `LogicalColumn` values).
    """

    def __init__(
        self,
        context: SqlQueryContext,
        sql_executable: sqlalchemy.sql.expression.SelectBase,
        row_transformer: _SqlRowTransformer,
    ):
        self._context = context
        self._sql_executable = sql_executable
        self._row_transformer = row_transformer

    def __iter__(self) -> Iterator[Mapping[ColumnTag, Any]]:
        transform = self._row_transformer.sql_to_relation
        if self._context._exit_stack is not None:
            # The context owns a session whose lifetime covers ours, so we
            # can stream rows lazily while the cursor stays open.
            with self._context._db.query(self._sql_executable) as sql_result:
                for raw_row in sql_result.mappings():
                    yield transform(raw_row)
        else:
            # No session to keep a cursor alive: read all raw rows into
            # memory and close the database connection, then transform each
            # row lazily.
            with self._context._db.query(self._sql_executable) as sql_result:
                fetched = sql_result.mappings().fetchall()
            for raw_row in fetched:
                yield transform(raw_row)
class _SqlRowTransformer:
    """Object that converts SQLAlchemy result-row mappings to relation row
    mappings.

    Parameters
    ----------
    columns : `~collections.abc.Iterable` [ `ColumnTag` ]
        Set of columns to handle.  Rows must have at least these columns, but
        may have more.
    engine : `ButlerSqlEngine`
        Relation engine; used to transform column tags into SQL identifiers
        and obtain column type information.
    """

    __slots__ = ("_scalar_columns", "_timespan_columns", "_has_no_columns")

    def __init__(self, columns: Iterable[ColumnTag], engine: ButlerSqlEngine):
        scalar: list[tuple[str, ColumnTag]] = []
        timespan: list[tuple[str, ColumnTag, type[TimespanDatabaseRepresentation]]] = []
        for tag in columns:
            identifier = engine.get_identifier(tag)
            if is_timespan_column(tag):
                # Timespans are unpacked into multiple SQL columns, so they
                # need the database representation class to (de)serialize.
                timespan.append((identifier, tag, engine.column_types.timespan_cls))
            else:
                scalar.append((identifier, tag))
        self._scalar_columns = scalar
        self._timespan_columns = timespan
        self._has_no_columns = not (scalar or timespan)

    def sql_to_relation(self, sql_row: sqlalchemy.RowMapping) -> dict[ColumnTag, Any]:
        """Convert a result row from a SQLAlchemy result into the form expected
        by `lsst.daf.relation.iteration`.

        Parameters
        ----------
        sql_row : `~collections.abc.Mapping`
            Mapping with `str` keys and possibly-unpacked values for timespan
            columns.

        Returns
        -------
        relation_row : `~collections.abc.Mapping`
            Mapping with `ColumnTag` keys and `Timespan` objects for timespan
            columns.
        """
        relation_row: dict[ColumnTag, Any] = {}
        for identifier, tag in self._scalar_columns:
            relation_row[tag] = sql_row[identifier]
        for identifier, tag, db_repr_cls in self._timespan_columns:
            # Re-pack the unpacked timespan columns into a Timespan object.
            relation_row[tag] = db_repr_cls.extract(sql_row, name=identifier)
        return relation_row

    def relation_to_sql(self, relation_row: Mapping[ColumnTag, Any]) -> dict[str, Any]:
        """Convert a `lsst.daf.relation.iteration` row into the mapping form
        used by SQLAlchemy.

        Parameters
        ----------
        relation_row : `~collections.abc.Mapping`
            Mapping with `ColumnTag` keys and `Timespan` objects for timespan
            columns.

        Returns
        -------
        sql_row : `~collections.abc.Mapping`
            Mapping with `str` keys and possibly-unpacked values for timespan
            columns.
        """
        sql_row: dict[str, Any] = {}
        for identifier, tag in self._scalar_columns:
            sql_row[identifier] = relation_row[tag]
        for identifier, tag, db_repr_cls in self._timespan_columns:
            # Unpack the Timespan object into its SQL column(s) in place.
            db_repr_cls.update(relation_row[tag], result=sql_row, name=identifier)
        if self._has_no_columns:
            # SQL rows cannot be completely empty; add a placeholder column.
            sql_row["IGNORED"] = None
        return sql_row