Coverage for python/lsst/daf/butler/registry/queries/butler_sql_engine.py: 29%
85 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-04 02:55 -0700
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
27from __future__ import annotations
29from ... import ddl
31__all__ = ("ButlerSqlEngine",)
33import dataclasses
34from collections.abc import Iterable, Set
35from typing import Any, cast
37import astropy.time
38import sqlalchemy
39from lsst.daf.relation import ColumnTag, Relation, Sort, UnaryOperation, UnaryOperationRelation, sql
41from ..._column_tags import is_timespan_column
42from ..._column_type_info import ColumnTypeInfo, LogicalColumn
43from ..._timespan import Timespan
44from ...timespan_database_representation import TimespanDatabaseRepresentation
45from ..nameShrinker import NameShrinker
46from .find_first_dataset import FindFirstDataset
@dataclasses.dataclass(repr=False, eq=False, kw_only=True)
class ButlerSqlEngine(sql.Engine[LogicalColumn]):
    """An extension of the `lsst.daf.relation.sql.Engine` class to add timespan
    and `FindFirstDataset` operation support.
    """

    column_types: ColumnTypeInfo
    """Struct containing information about column types that depend on registry
    configuration.
    """

    name_shrinker: NameShrinker
    """Object used to shrink the SQL identifiers that represent a ColumnTag to
    fit within the database's limit (`NameShrinker`).
    """

    def __str__(self) -> str:
        # Human-readable form is just the engine's name (inherited attribute).
        return self.name

    def __repr__(self) -> str:
        # Include the object id (hex) so distinct engine instances with the
        # same name can be told apart in debug output.
        return f"ButlerSqlEngine({self.name!r})@{id(self):0x}"

    def _append_unary_to_select(self, operation: UnaryOperation, target: sql.Select) -> sql.Select:
        # Docstring inherited.
        # This override exists to add support for the custom FindFirstDataset
        # operation.
        match operation:
            case FindFirstDataset():
                if target.has_sort and not target.has_slice:
                    # Existing target is sorted, but not sliced. We want to
                    # move that sort outside (i.e. after) the FindFirstDataset,
                    # since otherwise the FindFirstDataset would put the Sort
                    # into a CTE where it will do nothing.
                    inner = target.reapply_skip(sort=Sort())
                    return sql.Select.apply_skip(operation._finish_apply(inner), sort=target.sort)
                else:
                    # Apply the FindFirstDataset directly to the existing
                    # target, which we've already asserted starts with a
                    # Select. That existing Select will be used for the CTE
                    # that starts the FindFirstDataset implementation (see
                    # to_payload override).
                    return sql.Select.apply_skip(operation._finish_apply(target))
            case _:
                # All other unary operations are handled by the base engine.
                return super()._append_unary_to_select(operation, target)

    def get_identifier(self, tag: ColumnTag) -> str:
        # Docstring inherited.
        # Run the base-class identifier through the name shrinker so it fits
        # within the database's identifier-length limit.
        return self.name_shrinker.shrink(super().get_identifier(tag))

    def extract_mapping(
        self, tags: Iterable[ColumnTag], sql_columns: sqlalchemy.sql.ColumnCollection
    ) -> dict[ColumnTag, LogicalColumn]:
        # Docstring inherited.
        # This override exists to add support for Timespan columns.
        result: dict[ColumnTag, LogicalColumn] = {}
        for tag in tags:
            if is_timespan_column(tag):
                # A timespan is represented by one or more real SQL columns;
                # reassemble them into a single logical column object.
                result[tag] = self.column_types.timespan_cls.from_columns(
                    sql_columns, name=self.get_identifier(tag)
                )
            else:
                result[tag] = sql_columns[self.get_identifier(tag)]
        return result

    def select_items(
        self,
        items: Iterable[tuple[ColumnTag, LogicalColumn]],
        sql_from: sqlalchemy.sql.FromClause,
        *extra: sqlalchemy.sql.ColumnElement,
    ) -> sqlalchemy.sql.Select:
        # Docstring inherited.
        # This override exists to add support for Timespan columns.
        select_columns: list[sqlalchemy.sql.ColumnElement] = []
        for tag, logical_column in items:
            if is_timespan_column(tag):
                # Flatten the logical timespan column back into the real SQL
                # column(s) that represent it.
                select_columns.extend(
                    cast(TimespanDatabaseRepresentation, logical_column).flatten(
                        name=self.get_identifier(tag)
                    )
                )
            else:
                select_columns.append(
                    cast(sqlalchemy.sql.ColumnElement, logical_column).label(self.get_identifier(tag))
                )
        select_columns.extend(extra)
        # Base-class hook that guards against a SELECT with no columns.
        self.handle_empty_columns(select_columns)
        return sqlalchemy.sql.select(*select_columns).select_from(sql_from)

    def make_zero_select(self, tags: Set[ColumnTag]) -> sqlalchemy.sql.Select:
        # Docstring inherited.
        # This override exists to add support for Timespan columns.
        select_columns: list[sqlalchemy.sql.ColumnElement] = []
        for tag in tags:
            if is_timespan_column(tag):
                # NULL timespan literal, flattened into its real column(s).
                select_columns.extend(
                    self.column_types.timespan_cls.fromLiteral(None).flatten(name=self.get_identifier(tag))
                )
            else:
                select_columns.append(sqlalchemy.sql.literal(None).label(self.get_identifier(tag)))
        self.handle_empty_columns(select_columns)
        # WHERE FALSE guarantees the select yields zero rows while still
        # declaring the full set of columns.
        return sqlalchemy.sql.select(*select_columns).where(sqlalchemy.sql.literal(False))

    def convert_column_literal(self, value: Any) -> LogicalColumn:
        # Docstring inherited.
        # This override exists to add support for Timespan columns.
        if isinstance(value, Timespan):
            return self.column_types.timespan_cls.fromLiteral(value)
        elif isinstance(value, astropy.time.Time):
            # Astropy times get the custom column type (presumably stored as
            # TAI nanoseconds, per the type's name — see ddl.AstropyTimeNsecTai).
            return sqlalchemy.sql.literal(value, type_=ddl.AstropyTimeNsecTai)
        else:
            return super().convert_column_literal(value)

    def to_payload(self, relation: Relation) -> sql.Payload[LogicalColumn]:
        # Docstring inherited.
        # This override exists to add support for the custom FindFirstDataset
        # operation.
        match relation:
            case UnaryOperationRelation(operation=FindFirstDataset() as operation, target=target):
                # We build a subquery of the form below to search the
                # collections in order.
                #
                # WITH {dst}_search AS (
                #     {target}
                #     ...
                # )
                # SELECT
                #     {dst}_window.*,
                # FROM (
                #     SELECT
                #         {dst}_search.*,
                #         ROW_NUMBER() OVER (
                #             PARTITION BY {dst_search}.{operation.dimensions}
                #             ORDER BY {operation.rank}
                #         ) AS rownum
                #     ) {dst}_window
                # WHERE
                #     {dst}_window.rownum = 1;
                #
                # We'll start with the Common Table Expression (CTE) at the
                # top, which we mostly get from the target relation.
                search = self.to_executable(target).cte(f"{operation.rank.dataset_type}_search")
                # Now we fill out the SELECT from the CTE, and the subquery it
                # contains (at the same time, since they have the same columns,
                # aside from the special 'rownum' window-function column).
                search_columns = self.extract_mapping(target.columns, search.columns)
                # PARTITION BY columns must be plain SQLAlchemy columns;
                # timespans are not expected among the dimensions here.
                partition_by = [
                    _assert_column_is_directly_usable_by_sqlalchemy(search_columns[tag])
                    for tag in operation.dimensions
                ]
                row_number = sqlalchemy.sql.func.row_number()
                rank_column = _assert_column_is_directly_usable_by_sqlalchemy(search_columns[operation.rank])
                if partition_by:
                    rownum_column = row_number.over(partition_by=partition_by, order_by=rank_column)
                else:
                    # No dimensions to partition on: a single global ranking.
                    rownum_column = row_number.over(order_by=rank_column)
                window = self.select_items(
                    search_columns.items(), search, rownum_column.label("rownum")
                ).subquery(f"{operation.rank.dataset_type}_window")
                # Keep only the best-ranked row per partition.
                return sql.Payload(
                    from_clause=window,
                    columns_available=self.extract_mapping(target.columns, window.columns),
                    where=[window.columns["rownum"] == 1],
                )
            case _:
                return super().to_payload(relation)
def _assert_column_is_directly_usable_by_sqlalchemy(column: LogicalColumn) -> sqlalchemy.sql.ColumnElement:
    """Narrow a `LogicalColumn` to a SqlAlchemy ColumnElement, to satisfy the
    typechecker in cases where no timespans are expected.
    """
    # Anything that is not the multi-column timespan representation is a
    # plain SQLAlchemy column expression and can be passed through as-is.
    if not isinstance(column, TimespanDatabaseRepresentation):
        return column
    raise TypeError("Timespans not expected here.")