Coverage for python/lsst/daf/butler/registry/queries/butler_sql_engine.py: 27%
76 statements
coverage.py v7.3.1, created at 2023-10-02 08:00 +0000
# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

__all__ = ("ButlerSqlEngine",)

import dataclasses
from collections.abc import Iterable, Set
from typing import Any, cast

import astropy.time
import sqlalchemy
from lsst.daf.relation import ColumnTag, Relation, Sort, UnaryOperation, UnaryOperationRelation, sql

from ...core import (
    ColumnTypeInfo,
    LogicalColumn,
    Timespan,
    TimespanDatabaseRepresentation,
    ddl,
    is_timespan_column,
)
from ..nameShrinker import NameShrinker
from .find_first_dataset import FindFirstDataset


@dataclasses.dataclass(repr=False, eq=False, kw_only=True)
class ButlerSqlEngine(sql.Engine[LogicalColumn]):
    """An extension of the `lsst.daf.relation.sql.Engine` class to add timespan
    and `FindFirstDataset` operation support.
    """

    column_types: ColumnTypeInfo
    """Struct containing information about column types that depend on registry
    configuration.
    """

    name_shrinker: NameShrinker
    """Object used to shrink the SQL identifiers that represent a ColumnTag to
    fit within the database's limit (`NameShrinker`).
    """

    def __str__(self) -> str:
        return self.name

    def __repr__(self) -> str:
        return f"ButlerSqlEngine({self.name!r})@{id(self):0x}"

    def _append_unary_to_select(self, operation: UnaryOperation, target: sql.Select) -> sql.Select:
        # Docstring inherited.
        # This override exists to add support for the custom FindFirstDataset
        # operation.
        match operation:
            case FindFirstDataset():
                if target.has_sort and not target.has_slice:
                    # Existing target is sorted, but not sliced. We want to
                    # move that sort outside (i.e. after) the FindFirstDataset,
                    # since otherwise the FindFirstDataset would put the Sort
                    # into a CTE where it will do nothing.
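                    # Schematically (a rough sketch, not exact relation API
                    # calls):
                    #
                    #     before: FindFirstDataset(Select[sort=S](...))
                    #             -- S would be buried in the CTE and lost
                    #     after:  Select[sort=S](FindFirstDataset(Select(...)))
                    #             -- S is applied to the final query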
                    inner = target.reapply_skip(sort=Sort())
                    return sql.Select.apply_skip(operation._finish_apply(inner), sort=target.sort)
                else:
                    # Apply the FindFirstDataset directly to the existing
                    # target, which we've already asserted starts with a
                    # Select. That existing Select will be used for the CTE
                    # that starts the FindFirstDataset implementation (see
                    # the to_payload override).
                    return sql.Select.apply_skip(operation._finish_apply(target))
            case _:
                return super()._append_unary_to_select(operation, target)

    def get_identifier(self, tag: ColumnTag) -> str:
        # Docstring inherited.
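        # Sketch of the behavior: the parent class derives an identifier from
        # the tag, and shrink() rewrites it only when it exceeds the backend's
        # identifier limit (e.g. 63 characters for PostgreSQL), producing a
        # deterministic, collision-free short form.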
        return self.name_shrinker.shrink(super().get_identifier(tag))

    def extract_mapping(
        self, tags: Iterable[ColumnTag], sql_columns: sqlalchemy.sql.ColumnCollection
    ) -> dict[ColumnTag, LogicalColumn]:
        # Docstring inherited.
        # This override exists to add support for Timespan columns.
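        # Sketch: with a compound timespan representation that stores a
        # timespan as two columns (hypothetical "_begin"/"_end" suffixes
        # shown; the real names are a detail of the configured
        # representation), a timespan tag with identifier "t" is rebuilt from
        # sql_columns["t_begin"] and sql_columns["t_end"] into one logical
        # column; every other tag maps directly to its SQL column.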
        result: dict[ColumnTag, LogicalColumn] = {}
        for tag in tags:
            if is_timespan_column(tag):
                result[tag] = self.column_types.timespan_cls.from_columns(
                    sql_columns, name=self.get_identifier(tag)
                )
            else:
                result[tag] = sql_columns[self.get_identifier(tag)]
        return result

    def select_items(
        self,
        items: Iterable[tuple[ColumnTag, LogicalColumn]],
        sql_from: sqlalchemy.sql.FromClause,
        *extra: sqlalchemy.sql.ColumnElement,
    ) -> sqlalchemy.sql.Select:
        # Docstring inherited.
        # This override exists to add support for Timespan columns.
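        # Sketch of the emitted SELECT list (column names hypothetical,
        # assuming a two-column timespan representation): a timespan column
        # named "t" is flattened into its constituent columns, while an
        # ordinary column contributes a single labeled element, e.g.
        #
        #     SELECT t_begin, t_end, detector_id AS detector_id, ...
        #     FROM {sql_from}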
        select_columns: list[sqlalchemy.sql.ColumnElement] = []
        for tag, logical_column in items:
            if is_timespan_column(tag):
                select_columns.extend(
                    cast(TimespanDatabaseRepresentation, logical_column).flatten(
                        name=self.get_identifier(tag)
                    )
                )
            else:
                select_columns.append(
                    cast(sqlalchemy.sql.ColumnElement, logical_column).label(self.get_identifier(tag))
                )
        select_columns.extend(extra)
        self.handle_empty_columns(select_columns)
        return sqlalchemy.sql.select(*select_columns).select_from(sql_from)

    def make_zero_select(self, tags: Set[ColumnTag]) -> sqlalchemy.sql.Select:
        # Docstring inherited.
        # This override exists to add support for Timespan columns.
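        # The result has the requested columns but yields no rows; for a
        # single non-timespan tag it compiles to roughly (dialect details
        # aside):
        #
        #     SELECT NULL AS {identifier} WHERE false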
        select_columns: list[sqlalchemy.sql.ColumnElement] = []
        for tag in tags:
            if is_timespan_column(tag):
                select_columns.extend(
                    self.column_types.timespan_cls.fromLiteral(None).flatten(name=self.get_identifier(tag))
                )
            else:
                select_columns.append(sqlalchemy.sql.literal(None).label(self.get_identifier(tag)))
        self.handle_empty_columns(select_columns)
        return sqlalchemy.sql.select(*select_columns).where(sqlalchemy.sql.literal(False))

    def convert_column_literal(self, value: Any) -> LogicalColumn:
        # Docstring inherited.
        # This override exists to add support for Timespan columns.
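        # Illustrative calls (argument values hypothetical):
        #
        #     engine.convert_column_literal(Timespan(begin=t1, end=t2))
        #         -> a timespan database representation, so timespan
        #            operators can be applied to it in SQL
        #     engine.convert_column_literal(astropy.time.Time(...))
        #         -> a SQLAlchemy literal with the nanosecond-precision
        #            TAI time type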
        if isinstance(value, Timespan):
            return self.column_types.timespan_cls.fromLiteral(value)
        elif isinstance(value, astropy.time.Time):
            return sqlalchemy.sql.literal(value, type_=ddl.AstropyTimeNsecTai)
        else:
            return super().convert_column_literal(value)

    def to_payload(self, relation: Relation) -> sql.Payload[LogicalColumn]:
        # Docstring inherited.
        # This override exists to add support for the custom FindFirstDataset
        # operation.
        match relation:
            case UnaryOperationRelation(operation=FindFirstDataset() as operation, target=target):
                # We build a subquery of the form below to search the
                # collections in order.
                #
                # WITH {dst}_search AS (
                #     {target}
                #     ...
                # )
                # SELECT
                #     {dst}_window.*
                # FROM (
                #     SELECT
                #         {dst}_search.*,
                #         ROW_NUMBER() OVER (
                #             PARTITION BY {dst}_search.{operation.dimensions}
                #             ORDER BY {operation.rank}
                #         ) AS rownum
                #     FROM {dst}_search
                # ) {dst}_window
                # WHERE
                #     {dst}_window.rownum = 1;
                #
                # We'll start with the Common Table Expression (CTE) at the
                # top, which we mostly get from the target relation.
                search = self.to_executable(target).cte(f"{operation.rank.dataset_type}_search")
                # Now we fill out the SELECT from the CTE and, at the same
                # time, the subquery it contains, since they have the same
                # columns aside from the special 'rownum' window-function
                # column.
                search_columns = self.extract_mapping(target.columns, search.columns)
                partition_by = [search_columns[tag] for tag in operation.dimensions]
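                # 'rank' holds each row's collection position in the search
                # path, so ordering by it within each dimension partition puts
                # the dataset from the highest-priority collection first;
                # keeping only rownum == 1 below selects exactly that row.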
                rownum_column = sqlalchemy.sql.func.row_number()
                if partition_by:
                    rownum_column = rownum_column.over(
                        partition_by=partition_by, order_by=search_columns[operation.rank]
                    )
                else:
                    rownum_column = rownum_column.over(order_by=search_columns[operation.rank])
                window = self.select_items(
                    search_columns.items(), search, rownum_column.label("rownum")
                ).subquery(f"{operation.rank.dataset_type}_window")
                return sql.Payload(
                    from_clause=window,
                    columns_available=self.extract_mapping(target.columns, window.columns),
                    where=[window.columns["rownum"] == 1],
                )
            case _:
                return super().to_payload(relation)