Coverage for python/lsst/daf/butler/registry/queries/_sql_query_context.py: 13%
238 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-05-02 10:24 +0000
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
27from __future__ import annotations
29__all__ = ("SqlQueryContext",)
31import itertools
32from collections.abc import Iterable, Iterator, Mapping, Set
33from contextlib import ExitStack
34from typing import TYPE_CHECKING, Any, cast
36import sqlalchemy
37from lsst.daf.relation import (
38 BinaryOperationRelation,
39 Calculation,
40 Chain,
41 ColumnTag,
42 Deduplication,
43 Engine,
44 EngineError,
45 Join,
46 MarkerRelation,
47 Projection,
48 Relation,
49 Transfer,
50 UnaryOperation,
51 UnaryOperationRelation,
52 iteration,
53 sql,
54)
56from ..._column_tags import is_timespan_column
57from ..._column_type_info import ColumnTypeInfo, LogicalColumn
58from ...name_shrinker import NameShrinker
59from ...timespan_database_representation import TimespanDatabaseRepresentation
60from ._query_context import QueryContext
61from .butler_sql_engine import ButlerSqlEngine
63if TYPE_CHECKING:
64 from ..interfaces import Database
class SqlQueryContext(QueryContext):
    """An implementation of `QueryContext` for `SqlRegistry`.

    Parameters
    ----------
    db : `Database`
        Object that abstracts the database engine.
    column_types : `ColumnTypeInfo`
        Information about column types that can vary with registry
        configuration.
    row_chunk_size : `int`, optional
        Number of rows to insert into temporary tables at once.  If this is
        lower than ``db.get_constant_rows_max()`` it will be set to that value.
    """

    def __init__(
        self,
        db: Database,
        column_types: ColumnTypeInfo,
        row_chunk_size: int = 1000,
    ) -> None:
        super().__init__()
        self.sql_engine = ButlerSqlEngine(
            column_types=column_types, name_shrinker=NameShrinker(db.dialect.max_identifier_length)
        )
        # Clamp the chunk size so a single chunk is never too small for
        # Database.constant_rows; _iteration_to_sql relies on this guarantee.
        self._row_chunk_size = max(row_chunk_size, db.get_constant_rows_max())
        self._db = db
        # Non-None only while the context manager is entered; owns the
        # database session and any temporary tables created while open.
        self._exit_stack: ExitStack | None = None

    def __enter__(self) -> SqlQueryContext:
        assert self._exit_stack is None, "Context manager already entered."
        self._exit_stack = ExitStack().__enter__()
        self._exit_stack.enter_context(self._db.session())
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> bool:
        assert self._exit_stack is not None, "Context manager not yet entered."
        # Delegate exception handling to the stack (which closes the session
        # and drops temporary tables), then mark ourselves closed.
        result = self._exit_stack.__exit__(exc_type, exc_value, traceback)
        self._exit_stack = None
        return result

    @property
    def is_open(self) -> bool:
        # Whether the context manager has been entered and not yet exited.
        return self._exit_stack is not None

    @property
    def column_types(self) -> ColumnTypeInfo:
        """Information about column types that depend on registry configuration
        (`ColumnTypeInfo`).
        """
        return self.sql_engine.column_types

    @property
    def preferred_engine(self) -> Engine:
        # Docstring inherited.
        return self.sql_engine

    def count(self, relation: Relation, *, exact: bool = True, discard: bool = False) -> int:
        # Docstring inherited.
        # Only the number of rows matters, so drop count-invariant operations
        # and all columns before deciding how to execute.
        relation = self._strip_count_invariant_operations(relation, exact).with_only_columns(frozenset())
        if relation.engine == self.sql_engine:
            # Let the database do the counting via SELECT COUNT(*).
            sql_executable = self.sql_engine.to_executable(
                relation, extra_columns=[sqlalchemy.sql.func.count()]
            )
            with self._db.query(sql_executable) as sql_result:
                return cast(int, sql_result.scalar_one())
        elif (rows := relation.payload) is not None:
            assert isinstance(
                rows, iteration.MaterializedRowIterable
            ), "Query guarantees that only materialized payloads are attached to its relations."
            return len(rows)
        elif discard:
            # Caller allowed us to execute the deferred Python-side
            # operations and throw the rows away while counting.
            n = 0
            for _ in self.fetch_iterable(relation):
                n += 1
            return n
        else:
            raise RuntimeError(
                f"Query with relation {relation} has deferred operations that "
                "must be executed in Python in order to obtain an exact "
                "count. Pass discard=True to run the query and discard the "
                "result rows while counting them, run the query first, or "
                "pass exact=False to obtain an upper bound."
            )

    def any(self, relation: Relation, *, execute: bool = True, exact: bool = True) -> bool:
        # Docstring inherited.
        # Existence checks need at most one row and no columns, so strip
        # count-invariant operations, drop all columns, and apply LIMIT 1.
        relation = self._strip_count_invariant_operations(relation, exact).with_only_columns(frozenset())[:1]
        if relation.engine == self.sql_engine:
            sql_executable = self.sql_engine.to_executable(relation)
            with self._db.query(sql_executable) as sql_result:
                return sql_result.one_or_none() is not None
        elif (rows := relation.payload) is not None:
            assert isinstance(
                rows, iteration.MaterializedRowIterable
            ), "Query guarantees that only materialized payloads are attached to its relations."
            return bool(rows)
        elif execute:
            # Run the deferred operations, but stop at the first row.
            for _ in self.fetch_iterable(relation):
                return True
            return False
        else:
            raise RuntimeError(
                f"Query with relation {relation} has deferred operations that "
                "must be executed in Python in order to obtain an exact "
                "check for whether any rows would be returned. Pass "
                "execute=True to run the query until at least "
                "one row is found, run the query first, or pass exact=False "
                "test only whether the query _might_ have result rows."
            )

    def transfer(self, source: Relation, destination: Engine, materialize_as: str | None) -> Any:
        # Docstring inherited from lsst.daf.relation.Processor.
        # Only SQL <-> iteration transfers are supported here.
        if source.engine == self.sql_engine and destination == self.iteration_engine:
            return self._sql_to_iteration(source, materialize_as)
        if source.engine == self.iteration_engine and destination == self.sql_engine:
            return self._iteration_to_sql(source, materialize_as)
        raise EngineError(f"Unexpected transfer for SqlQueryContext: {source.engine} -> {destination}.")

    def materialize(self, target: Relation, name: str) -> Any:
        # Docstring inherited from lsst.daf.relation.Processor.
        if target.engine == self.sql_engine:
            # Execute the relation into a temporary table whose lifetime is
            # tied to this context's exit stack.
            sql_executable = self.sql_engine.to_executable(target)
            table_spec = self.column_types.make_relation_table_spec(target.columns)
            if self._exit_stack is None:
                raise RuntimeError("This operation requires the QueryContext to have been entered.")
            table = self._exit_stack.enter_context(self._db.temporary_table(table_spec, name=name))
            self._db.insert(table, select=sql_executable)
            payload = sql.Payload[LogicalColumn](table)
            payload.columns_available = self.sql_engine.extract_mapping(target.columns, table.columns)
            return payload
        return super().materialize(target, name)

    def restore_columns(
        self,
        relation: Relation,
        columns_required: Set[ColumnTag],
    ) -> tuple[Relation, set[ColumnTag]]:
        # Docstring inherited.
        # Recursively rebuild the tree, widening Projections so that required
        # columns that exist upstream are propagated back to the tip.
        if relation.is_locked:
            # Locked relations cannot be modified; report only the required
            # columns they already provide.
            return relation, set(relation.columns & columns_required)
        match relation:
            case UnaryOperationRelation(operation=operation, target=target):
                new_target, columns_found = self.restore_columns(target, columns_required)
                if not columns_found:
                    return relation, columns_found
                match operation:
                    case Projection(columns=columns):
                        # Widen the projection to pass the found columns
                        # through.
                        new_columns = columns | columns_found
                        if new_columns != columns:
                            return (
                                Projection(new_columns).apply(new_target),
                                columns_found,
                            )
                    case Calculation(tag=tag):
                        if tag in columns_required:
                            columns_found.add(tag)
                    case Deduplication():
                        # Pulling key columns through the deduplication would
                        # fundamentally change what it does; have to limit
                        # ourselves to non-key columns here, and put a
                        # Projection back in place.
                        columns_found = {c for c in columns_found if not c.is_key}
                        new_target = new_target.with_only_columns(relation.columns | columns_found)
                return relation.reapply(new_target), columns_found
            case BinaryOperationRelation(operation=operation, lhs=lhs, rhs=rhs):
                new_lhs, columns_found_lhs = self.restore_columns(lhs, columns_required)
                new_rhs, columns_found_rhs = self.restore_columns(rhs, columns_required)
                match operation:
                    case Join():
                        return relation.reapply(new_lhs, new_rhs), columns_found_lhs | columns_found_rhs
                    case Chain():
                        # A union's branches must have identical columns.
                        if columns_found_lhs != columns_found_rhs:
                            # Got different answers for different join
                            # branches; let's try again with just columns found
                            # in both branches.
                            new_columns_required = columns_found_lhs & columns_found_rhs
                            if not new_columns_required:
                                return relation, set()
                            new_lhs, columns_found_lhs = self.restore_columns(lhs, new_columns_required)
                            new_rhs, columns_found_rhs = self.restore_columns(rhs, new_columns_required)
                            assert columns_found_lhs == columns_found_rhs
                        return relation.reapply(new_lhs, new_rhs), columns_found_lhs
            case MarkerRelation(target=target):
                new_target, columns_found = self.restore_columns(target, columns_required)
                return relation.reapply(new_target), columns_found
        raise AssertionError("Match should be exhaustive and all branches should return.")

    def strip_postprocessing(self, relation: Relation) -> tuple[Relation, list[UnaryOperation]]:
        # Docstring inherited.
        # Peel iteration-engine unary operations off the tip of the tree,
        # returning them (outermost last) along with the stripped relation.
        if relation.engine != self.iteration_engine or relation.is_locked:
            return relation, []
        match relation:
            case UnaryOperationRelation(operation=operation, target=target):
                new_target, stripped = self.strip_postprocessing(target)
                stripped.append(operation)
                return new_target, stripped
            case Transfer(destination=self.iteration_engine, target=target):
                # Strip the transfer into the iteration engine itself.
                return target, []
        return relation, []

    def drop_invalidated_postprocessing(self, relation: Relation, new_columns: Set[ColumnTag]) -> Relation:
        # Docstring inherited.
        # Remove iteration-engine operations whose required columns will no
        # longer be available once the tree is projected to ``new_columns``.
        if relation.engine != self.iteration_engine or relation.is_locked:
            return relation
        match relation:
            case UnaryOperationRelation(operation=operation, target=target):
                columns_added_here = relation.columns - target.columns
                new_target = self.drop_invalidated_postprocessing(target, new_columns - columns_added_here)
                if operation.columns_required <= new_columns:
                    # This operation can stay...
                    if new_target is target:
                        # ...and nothing has actually changed upstream of it.
                        return relation
                    else:
                        # ...but we should see if it can now be performed in
                        # the preferred engine, since something it didn't
                        # commute with may have been removed.
                        return operation.apply(new_target, preferred_engine=self.preferred_engine)
                else:
                    # This operation needs to be dropped, as we're about to
                    # project away one more more of the columns it needs.
                    return new_target
        return relation

    def _sql_to_iteration(self, source: Relation, materialize_as: str | None) -> iteration.RowIterable:
        """Execute a relation transfer from the SQL engine to the iteration
        engine.

        Parameters
        ----------
        source : `~lsst.daf.relation.Relation`
            Input relation, in what is assumed to be `sql_engine`.
        materialize_as : `str` or `None`
            If not `None`, the name of a persistent materialization to apply
            in the iteration engine.  This fetches all rows up front.

        Returns
        -------
        destination : `~lsst.daf.relation.iteration.RowIterable`
            Iteration engine payload iterable with the same content as the
            given SQL relation.
        """
        sql_executable = self.sql_engine.to_executable(source)
        rows: iteration.RowIterable = _SqlRowIterable(
            self, sql_executable, _SqlRowTransformer(source.columns, self.sql_engine)
        )
        if materialize_as is not None:
            rows = rows.materialized()
        return rows

    def _iteration_to_sql(self, source: Relation, materialize_as: str | None) -> sql.Payload:
        """Execute a relation transfer from the iteration engine to the SQL
        engine.

        Parameters
        ----------
        source : `~lsst.daf.relation.Relation`
            Input relation, in what is assumed to be `iteration_engine`.
        materialize_as : `str` or `None`
            If not `None`, the name of a persistent materialization to apply
            in the SQL engine.  This sets the name of the temporary table
            and ensures that one is used (instead of `Database.constant_rows`,
            which might otherwise be used for small row sets).

        Returns
        -------
        destination : `~lsst.daf.relation.sql.Payload`
            SQL engine payload struct with the same content as the given
            iteration relation.
        """
        iterable = self.iteration_engine.execute(source)
        table_spec = self.column_types.make_relation_table_spec(source.columns)
        row_transformer = _SqlRowTransformer(source.columns, self.sql_engine)
        sql_rows = []
        iterator = iter(iterable)
        # Iterate over the first chunk of rows and transform them to hold only
        # types recognized by SQLAlchemy (i.e. expand out
        # TimespanDatabaseRepresentations).  Note that this advances
        # `iterator`.
        for relation_row in itertools.islice(iterator, self._row_chunk_size):
            sql_rows.append(row_transformer.relation_to_sql(relation_row))
        name = materialize_as if materialize_as is not None else self.sql_engine.get_relation_name("upload")
        if materialize_as is not None or len(sql_rows) > self._db.get_constant_rows_max():
            if self._exit_stack is None:
                raise RuntimeError("This operation requires the QueryContext to have been entered.")
            # Either we're being asked to insert into a temporary table with a
            # "persistent" lifetime (the duration of the QueryContext), or we
            # have enough rows that the database says we can't use its
            # "constant_rows" construct (e.g. VALUES).  Note that
            # `QueryContext.__init__` guarantees that `self._row_chunk_size` is
            # greater than or equal to `Database.get_constant_rows_max`.
            table = self._exit_stack.enter_context(self._db.temporary_table(table_spec, name))
            # Insert chunk by chunk so we never hold more than one chunk of
            # transformed rows in memory at a time.
            while sql_rows:
                self._db.insert(table, *sql_rows)
                sql_rows.clear()
                for relation_row in itertools.islice(iterator, self._row_chunk_size):
                    sql_rows.append(row_transformer.relation_to_sql(relation_row))
            payload = sql.Payload[LogicalColumn](table)
        else:
            # Small number of rows; use Database.constant_rows.
            payload = sql.Payload[LogicalColumn](self._db.constant_rows(table_spec.fields, *sql_rows))
        payload.columns_available = self.sql_engine.extract_mapping(
            source.columns, payload.from_clause.columns
        )
        return payload

    def _strip_empty_invariant_operations(
        self,
        relation: Relation,
        exact: bool,
    ) -> Relation:
        """Return a modified relation tree that strips out relation operations
        that do not affect whether there are any rows from the end of the tree.

        Parameters
        ----------
        relation : `Relation`
            Original relation tree.
        exact : `bool`
            If `False` strip all iteration-engine operations, even those that
            can remove all rows.

        Returns
        -------
        modified : `Relation`
            Modified relation tree.
        """
        # Payloads and locks mark results we must not discard or rewrite.
        if relation.payload is not None or relation.is_locked:
            return relation
        match relation:
            case UnaryOperationRelation(operation=operation, target=target):
                if operation.is_empty_invariant or (not exact and target.engine == self.iteration_engine):
                    return self._strip_empty_invariant_operations(target, exact)
            case Transfer(target=target):
                # Transfers never change row content; always strip them.
                return self._strip_empty_invariant_operations(target, exact)
            case MarkerRelation(target=target):
                return relation.reapply(self._strip_empty_invariant_operations(target, exact))
        return relation

    def _strip_count_invariant_operations(
        self,
        relation: Relation,
        exact: bool,
    ) -> Relation:
        """Return a modified relation tree that strips out relation operations
        that do not affect the number of rows from the end of the tree.

        Parameters
        ----------
        relation : `Relation`
            Original relation tree.
        exact : `bool`
            If `False` strip all iteration-engine operations, even those that
            can affect the number of rows.

        Returns
        -------
        modified : `Relation`
            Modified relation tree.
        """
        # Payloads and locks mark results we must not discard or rewrite.
        if relation.payload is not None or relation.is_locked:
            return relation
        match relation:
            case UnaryOperationRelation(operation=operation, target=target):
                if operation.is_count_invariant or (not exact and target.engine == self.iteration_engine):
                    return self._strip_count_invariant_operations(target, exact)
            case Transfer(target=target):
                # Transfers never change row content; always strip them.
                return self._strip_count_invariant_operations(target, exact)
            case MarkerRelation(target=target):
                return relation.reapply(self._strip_count_invariant_operations(target, exact))
        return relation
class _SqlRowIterable(iteration.RowIterable):
    """A `lsst.daf.relation.iteration.RowIterable` backed by a SQL query.

    Parameters
    ----------
    context : `SqlQueryContext`
        Context to execute the query with.  If this context has already been
        entered, a lazy iterable will be returned that may or may not use
        server side cursors, since the context's lifetime can be used to manage
        that cursor's lifetime.  If the context has not been entered, all
        results will be fetched up front in raw form, but will still be
        processed into mappings keyed by `ColumnTag` lazily.
    sql_executable : `sqlalchemy.sql.expression.SelectBase`
        SQL query to execute; assumed to have been built by
        `lsst.daf.relation.sql.Engine.to_executable` or similar.
    row_transformer : `_SqlRowTransformer`
        Object that converts SQLAlchemy result-row mappings (with `str`
        keys and possibly-unpacked timespan values) to relation row
        mappings (with `ColumnTag` keys and `LogicalColumn` values).
    """

    def __init__(
        self,
        context: SqlQueryContext,
        sql_executable: sqlalchemy.sql.expression.SelectBase,
        row_transformer: _SqlRowTransformer,
    ):
        self._context = context
        self._sql_executable = sql_executable
        self._row_transformer = row_transformer

    def __iter__(self) -> Iterator[Mapping[ColumnTag, Any]]:
        # Hoist the per-row conversion to a local for the loops below.
        transform = self._row_transformer.sql_to_relation
        if self._context._exit_stack is not None:
            # Context is open: its lifetime manages the cursor, so rows can
            # be streamed lazily straight out of the result.
            with self._context._db.query(self._sql_executable) as sql_result:
                for raw_row in sql_result.mappings():
                    yield transform(raw_row)
        else:
            # No open context: pull every raw row into memory and release the
            # database connection before converting and yielding anything.
            with self._context._db.query(self._sql_executable) as sql_result:
                fetched = sql_result.mappings().fetchall()
            yield from map(transform, fetched)
class _SqlRowTransformer:
    """Converter between SQLAlchemy result-row mappings and relation rows.

    Parameters
    ----------
    columns : `~collections.abc.Iterable` [ `ColumnTag` ]
        Set of columns to handle.  Rows must have at least these columns, but
        may have more.
    engine : `ButlerSqlEngine`
        Relation engine; used to transform column tags into SQL identifiers
        and obtain column type information.
    """

    __slots__ = ("_scalar_columns", "_timespan_columns", "_has_no_columns")

    def __init__(self, columns: Iterable[ColumnTag], engine: ButlerSqlEngine):
        # Partition the columns once, pairing each tag with its (possibly
        # shrunken) SQL identifier; timespans also carry their DB
        # representation class so they can be packed/unpacked per row.
        self._scalar_columns: list[tuple[str, ColumnTag]] = []
        self._timespan_columns: list[tuple[str, ColumnTag, type[TimespanDatabaseRepresentation]]] = []
        for tag in columns:
            identifier = engine.get_identifier(tag)
            if is_timespan_column(tag):
                self._timespan_columns.append((identifier, tag, engine.column_types.timespan_cls))
            else:
                self._scalar_columns.append((identifier, tag))
        self._has_no_columns = not self._scalar_columns and not self._timespan_columns

    def sql_to_relation(self, sql_row: sqlalchemy.RowMapping) -> dict[ColumnTag, Any]:
        """Convert a result row from a SQLAlchemy result into the form expected
        by `lsst.daf.relation.iteration`.

        Parameters
        ----------
        sql_row : `~collections.abc.Mapping`
            Mapping with `str` keys and possibly-unpacked values for timespan
            columns.

        Returns
        -------
        relation_row : `~collections.abc.Mapping`
            Mapping with `ColumnTag` keys and `Timespan` objects for timespan
            columns.
        """
        relation_row: dict[ColumnTag, Any] = {}
        for identifier, tag in self._scalar_columns:
            relation_row[tag] = sql_row[identifier]
        # Reassemble each timespan from its unpacked database columns.
        for identifier, tag, db_repr_cls in self._timespan_columns:
            relation_row[tag] = db_repr_cls.extract(sql_row, name=identifier)
        return relation_row

    def relation_to_sql(self, relation_row: Mapping[ColumnTag, Any]) -> dict[str, Any]:
        """Convert a `lsst.daf.relation.iteration` row into the mapping form
        used by SQLAlchemy.

        Parameters
        ----------
        relation_row : `~collections.abc.Mapping`
            Mapping with `ColumnTag` keys and `Timespan` objects for timespan
            columns.

        Returns
        -------
        sql_row : `~collections.abc.Mapping`
            Mapping with `str` keys and possibly-unpacked values for timespan
            columns.
        """
        sql_row: dict[str, Any] = {}
        for identifier, tag in self._scalar_columns:
            sql_row[identifier] = relation_row[tag]
        # Unpack each timespan into its database column(s) in place.
        for identifier, tag, db_repr_cls in self._timespan_columns:
            db_repr_cls.update(relation_row[tag], result=sql_row, name=identifier)
        if self._has_no_columns:
            # A row must have at least one key for SQLAlchemy inserts.
            sql_row["IGNORED"] = None
        return sql_row