Coverage for python/lsst/daf/butler/registry/queries/butler_sql_engine.py: 29% (79 statements)
# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

from ... import ddl

__all__ = ("ButlerSqlEngine",)

import dataclasses
from collections.abc import Iterable, Set
from typing import Any, cast

import astropy.time
import sqlalchemy
from lsst.daf.relation import ColumnTag, Relation, Sort, UnaryOperation, UnaryOperationRelation, sql

from ..._column_tags import is_timespan_column
from ..._column_type_info import ColumnTypeInfo, LogicalColumn
from ..._timespan import Timespan, TimespanDatabaseRepresentation
from ..nameShrinker import NameShrinker
from .find_first_dataset import FindFirstDataset


@dataclasses.dataclass(repr=False, eq=False, kw_only=True)
class ButlerSqlEngine(sql.Engine[LogicalColumn]):
    """An extension of the `lsst.daf.relation.sql.Engine` class that adds
    timespan and `FindFirstDataset` operation support.
    """

    column_types: ColumnTypeInfo
    """Struct containing information about column types that depend on
    registry configuration.
    """

    name_shrinker: NameShrinker
    """Object used to shrink the SQL identifiers that represent a `ColumnTag`
    to fit within the database's identifier-length limit (`NameShrinker`).
    """
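
    # Because the dataclass is declared with kw_only=True, both fields above
    # must be supplied as keyword arguments; a hypothetical construction
    # sketch (real instances are assembled by registry internals):
    #
    #     engine = ButlerSqlEngine(
    #         name="butler", column_types=..., name_shrinker=...
    #     )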

    def __str__(self) -> str:
        return self.name

    def __repr__(self) -> str:
        return f"ButlerSqlEngine({self.name!r})@{id(self):0x}"

    def _append_unary_to_select(self, operation: UnaryOperation, target: sql.Select) -> sql.Select:
        # Docstring inherited.
        # This override exists to add support for the custom FindFirstDataset
        # operation.
        match operation:
            case FindFirstDataset():
                if target.has_sort and not target.has_slice:
                    # The existing target is sorted but not sliced. We want
                    # to move that sort outside (i.e. after) the
                    # FindFirstDataset, since otherwise the FindFirstDataset
                    # would put the Sort into a CTE where it would do
                    # nothing.
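                    # (In standard SQL, an ORDER BY inside a CTE or subquery
                    # does not constrain the row order of the enclosing
                    # query, so a Sort buried there would be silently lost.)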
                    inner = target.reapply_skip(sort=Sort())
                    return sql.Select.apply_skip(operation._finish_apply(inner), sort=target.sort)
                else:
                    # Apply the FindFirstDataset directly to the existing
                    # target, which we've already asserted starts with a
                    # Select. That existing Select will be used for the CTE
                    # that starts the FindFirstDataset implementation (see
                    # the to_payload override).
                    return sql.Select.apply_skip(operation._finish_apply(target))
            case _:
                return super()._append_unary_to_select(operation, target)

    def get_identifier(self, tag: ColumnTag) -> str:
        # Docstring inherited.
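        # SQL identifiers derived from ColumnTags (such as flattened
        # timespan column names for long dataset type names) can exceed
        # backend identifier limits like PostgreSQL's 63-byte maximum, so
        # every identifier this engine emits is routed through the shrinker.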
        return self.name_shrinker.shrink(super().get_identifier(tag))

    def extract_mapping(
        self, tags: Iterable[ColumnTag], sql_columns: sqlalchemy.sql.ColumnCollection
    ) -> dict[ColumnTag, LogicalColumn]:
        # Docstring inherited.
        # This override exists to add support for Timespan columns.
        result: dict[ColumnTag, LogicalColumn] = {}
        for tag in tags:
            if is_timespan_column(tag):
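                # A timespan is stored as one or more real database columns
                # (e.g. a begin/end pair of TAI nanosecond integers,
                # depending on the configured representation), so it is
                # rebuilt into a single logical column here rather than
                # looked up by name.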
                result[tag] = self.column_types.timespan_cls.from_columns(
                    sql_columns, name=self.get_identifier(tag)
                )
            else:
                result[tag] = sql_columns[self.get_identifier(tag)]
        return result

    def select_items(
        self,
        items: Iterable[tuple[ColumnTag, LogicalColumn]],
        sql_from: sqlalchemy.sql.FromClause,
        *extra: sqlalchemy.sql.ColumnElement,
    ) -> sqlalchemy.sql.Select:
        # Docstring inherited.
        # This override exists to add support for Timespan columns.
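        # Here `flatten` is the inverse of `from_columns` in extract_mapping:
        # a logical timespan column expands back into its underlying SQL
        # columns so each can appear in the SELECT list under a shrunken,
        # deterministic name.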
        select_columns: list[sqlalchemy.sql.ColumnElement] = []
        for tag, logical_column in items:
            if is_timespan_column(tag):
                select_columns.extend(
                    cast(TimespanDatabaseRepresentation, logical_column).flatten(
                        name=self.get_identifier(tag)
                    )
                )
            else:
                select_columns.append(
                    cast(sqlalchemy.sql.ColumnElement, logical_column).label(self.get_identifier(tag))
                )
        select_columns.extend(extra)
        self.handle_empty_columns(select_columns)
        return sqlalchemy.sql.select(*select_columns).select_from(sql_from)

    def make_zero_select(self, tags: Set[ColumnTag]) -> sqlalchemy.sql.Select:
        # Docstring inherited.
        # This override exists to add support for Timespan columns.
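        # A "zero select" yields no rows (the WHERE clause below is a
        # literal False) but still exposes the full set of requested typed
        # columns, so it can stand in for an empty relation wherever one is
        # composed into a larger query.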
        select_columns: list[sqlalchemy.sql.ColumnElement] = []
        for tag in tags:
            if is_timespan_column(tag):
                select_columns.extend(
                    self.column_types.timespan_cls.fromLiteral(None).flatten(name=self.get_identifier(tag))
                )
            else:
                select_columns.append(sqlalchemy.sql.literal(None).label(self.get_identifier(tag)))
        self.handle_empty_columns(select_columns)
        return sqlalchemy.sql.select(*select_columns).where(sqlalchemy.sql.literal(False))

    def convert_column_literal(self, value: Any) -> LogicalColumn:
        # Docstring inherited.
        # This override exists to add support for Timespan columns.
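        # Timespan literals are converted to the configured database
        # representation; bare astropy times are bound with the custom
        # AstropyTimeNsecTai type, which stores them as integer TAI
        # nanoseconds.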
        if isinstance(value, Timespan):
            return self.column_types.timespan_cls.fromLiteral(value)
        elif isinstance(value, astropy.time.Time):
            return sqlalchemy.sql.literal(value, type_=ddl.AstropyTimeNsecTai)
        else:
            return super().convert_column_literal(value)

    def to_payload(self, relation: Relation) -> sql.Payload[LogicalColumn]:
        # Docstring inherited.
        # This override exists to add support for the custom FindFirstDataset
        # operation.
        match relation:
            case UnaryOperationRelation(operation=FindFirstDataset() as operation, target=target):
                # We build a subquery of the form below to search the
                # collections in order.
                #
                # WITH {dst}_search AS (
                #     {target}
                #     ...
                # )
                # SELECT
                #     {dst}_window.*
                # FROM (
                #     SELECT
                #         {dst}_search.*,
                #         ROW_NUMBER() OVER (
                #             PARTITION BY {dst}_search.{operation.dimensions}
                #             ORDER BY {operation.rank}
                #         ) AS rownum
                #     FROM {dst}_search
                # ) {dst}_window
                # WHERE
                #     {dst}_window.rownum = 1;
                #
                # We'll start with the Common Table Expression (CTE) at the
                # top, which we mostly get from the target relation.
                search = self.to_executable(target).cte(f"{operation.rank.dataset_type}_search")
                # Now we fill out the SELECT from the CTE and the subquery it
                # contains (at the same time, since they have the same
                # columns, aside from the special 'rownum' window-function
                # column).
                search_columns = self.extract_mapping(target.columns, search.columns)
                partition_by = [search_columns[tag] for tag in operation.dimensions]
                rownum_column = sqlalchemy.sql.func.row_number()
                if partition_by:
                    rownum_column = rownum_column.over(
                        partition_by=partition_by, order_by=search_columns[operation.rank]
                    )
                else:
                    rownum_column = rownum_column.over(order_by=search_columns[operation.rank])
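                # Keeping only rownum == 1 below selects, for each distinct
                # combination of dimension values, the row whose rank is
                # smallest, i.e. the dataset from the first collection in
                # the ordered search path that contains it.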
                window = self.select_items(
                    search_columns.items(), search, rownum_column.label("rownum")
                ).subquery(f"{operation.rank.dataset_type}_window")
                return sql.Payload(
                    from_clause=window,
                    columns_available=self.extract_mapping(target.columns, window.columns),
                    where=[window.columns["rownum"] == 1],
                )
            case _:
                return super().to_payload(relation)