Coverage for python/lsst/daf/butler/direct_query_driver/_sql_column_visitor.py: 16%
123 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-05-03 02:48 -0700
« prev ^ index » next coverage.py v7.5.0, created at 2024-05-03 02:48 -0700
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
28from __future__ import annotations
30__all__ = ("SqlColumnVisitor",)
32from typing import TYPE_CHECKING, Any
34import sqlalchemy
36from .. import ddl
37from ..queries import tree as qt
38from ..queries.visitors import ColumnExpressionVisitor, PredicateVisitFlags, PredicateVisitor
39from ..timespan_database_representation import TimespanDatabaseRepresentation
41if TYPE_CHECKING:
42 from ._driver import DirectQueryDriver
43 from ._query_builder import QueryJoiner
46class SqlColumnVisitor(
47 ColumnExpressionVisitor[sqlalchemy.ColumnElement[Any] | TimespanDatabaseRepresentation],
48 PredicateVisitor[
49 sqlalchemy.ColumnElement[bool], sqlalchemy.ColumnElement[bool], sqlalchemy.ColumnElement[bool]
50 ],
51):
52 """A column expression visitor that constructs a `sqlalchemy.ColumnElement`
53 expression tree.
55 Parameters
56 ----------
57 joiner : `QueryJoiner`
58 `QueryJoiner` that provides SQL columns for column-reference
59 expressions.
60 driver : `QueryDriver`
61 Driver used to construct nested queries for "in query" predicates.
62 """
64 def __init__(self, joiner: QueryJoiner, driver: DirectQueryDriver):
65 self._driver = driver
66 self._joiner = joiner
68 def visit_literal(
69 self, expression: qt.ColumnLiteral
70 ) -> sqlalchemy.ColumnElement[Any] | TimespanDatabaseRepresentation:
71 # Docstring inherited.
72 if expression.column_type == "timespan":
73 return self._driver.db.getTimespanRepresentation().fromLiteral(expression.get_literal_value())
74 return sqlalchemy.literal(
75 expression.get_literal_value(), type_=ddl.VALID_CONFIG_COLUMN_TYPES[expression.column_type]
76 )
78 def visit_dimension_key_reference(
79 self, expression: qt.DimensionKeyReference
80 ) -> sqlalchemy.ColumnElement[int | str]:
81 # Docstring inherited.
82 return self._joiner.dimension_keys[expression.dimension.name][0]
84 def visit_dimension_field_reference(
85 self, expression: qt.DimensionFieldReference
86 ) -> sqlalchemy.ColumnElement[Any] | TimespanDatabaseRepresentation:
87 # Docstring inherited.
88 if expression.column_type == "timespan":
89 return self._joiner.timespans[expression.element.name]
90 return self._joiner.fields[expression.element.name][expression.field]
92 def visit_dataset_field_reference(
93 self, expression: qt.DatasetFieldReference
94 ) -> sqlalchemy.ColumnElement[Any] | TimespanDatabaseRepresentation:
95 # Docstring inherited.
96 if expression.column_type == "timespan":
97 return self._joiner.timespans[expression.dataset_type]
98 return self._joiner.fields[expression.dataset_type][expression.field]
100 def visit_unary_expression(self, expression: qt.UnaryExpression) -> sqlalchemy.ColumnElement[Any]:
101 # Docstring inherited.
102 match expression.operator:
103 case "-":
104 return -self.expect_scalar(expression.operand)
105 case "begin_of":
106 return self.expect_timespan(expression.operand).lower()
107 case "end_of":
108 return self.expect_timespan(expression.operand).upper()
109 raise AssertionError(f"Invalid unary expression operator {expression.operator!r}.")
111 def visit_binary_expression(self, expression: qt.BinaryExpression) -> sqlalchemy.ColumnElement[Any]:
112 # Docstring inherited.
113 a = self.expect_scalar(expression.a)
114 b = self.expect_scalar(expression.b)
115 match expression.operator:
116 case "+":
117 return a + b
118 case "-":
119 return a - b
120 case "*":
121 return a * b
122 case "/":
123 return a / b
124 case "%":
125 return a % b
126 raise AssertionError(f"Invalid binary expression operator {expression.operator!r}.")
128 def visit_reversed(self, expression: qt.Reversed) -> sqlalchemy.ColumnElement[Any]:
129 # Docstring inherited.
130 return self.expect_scalar(expression.operand).desc()
132 def visit_comparison(
133 self,
134 a: qt.ColumnExpression,
135 operator: qt.ComparisonOperator,
136 b: qt.ColumnExpression,
137 flags: PredicateVisitFlags,
138 ) -> sqlalchemy.ColumnElement[bool]:
139 # Docstring inherited.
140 if operator == "overlaps":
141 assert a.column_type == "timespan", "Spatial overlaps should be transformed away by now."
142 return self.expect_timespan(a).overlaps(self.expect_timespan(b))
143 lhs = self.expect_scalar(a)
144 rhs = self.expect_scalar(b)
145 match operator:
146 case "==":
147 return lhs == rhs
148 case "!=":
149 return lhs != rhs
150 case "<":
151 return lhs < rhs
152 case ">":
153 return lhs > rhs
154 case "<=":
155 return lhs <= rhs
156 case ">=":
157 return lhs >= rhs
158 raise AssertionError(f"Invalid comparison operator {operator!r}.")
160 def visit_is_null(
161 self, operand: qt.ColumnExpression, flags: PredicateVisitFlags
162 ) -> sqlalchemy.ColumnElement[bool]:
163 # Docstring inherited.
164 if operand.column_type == "timespan":
165 return self.expect_timespan(operand).isNull()
166 return self.expect_scalar(operand) == sqlalchemy.null()
168 def visit_in_container(
169 self,
170 member: qt.ColumnExpression,
171 container: tuple[qt.ColumnExpression, ...],
172 flags: PredicateVisitFlags,
173 ) -> sqlalchemy.ColumnElement[bool]:
174 # Docstring inherited.
175 return self.expect_scalar(member).in_([self.expect_scalar(item) for item in container])
177 def visit_in_range(
178 self, member: qt.ColumnExpression, start: int, stop: int | None, step: int, flags: PredicateVisitFlags
179 ) -> sqlalchemy.ColumnElement[bool]:
180 # Docstring inherited.
181 sql_member = self.expect_scalar(member)
182 if stop is None:
183 target = sql_member >= sqlalchemy.literal(start)
184 else:
185 stop_inclusive = stop - 1
186 if start == stop_inclusive:
187 return sql_member == sqlalchemy.literal(start)
188 else:
189 target = sqlalchemy.sql.between(
190 sql_member,
191 sqlalchemy.literal(start),
192 sqlalchemy.literal(stop_inclusive),
193 )
194 if step != 1:
195 return sqlalchemy.sql.and_(
196 *[
197 target,
198 sql_member % sqlalchemy.literal(step) == sqlalchemy.literal(start % step),
199 ]
200 )
201 else:
202 return target
204 def visit_in_query_tree(
205 self,
206 member: qt.ColumnExpression,
207 column: qt.ColumnExpression,
208 query_tree: qt.QueryTree,
209 flags: PredicateVisitFlags,
210 ) -> sqlalchemy.ColumnElement[bool]:
211 # Docstring inherited.
212 columns = qt.ColumnSet(self._driver.universe.empty.as_group())
213 column.gather_required_columns(columns)
214 _, builder = self._driver.build_query(query_tree, columns)
215 if builder.postprocessing:
216 raise NotImplementedError(
217 "Right-hand side subquery in IN expression would require postprocessing."
218 )
219 subquery_visitor = SqlColumnVisitor(builder.joiner, self._driver)
220 builder.joiner.special["_MEMBER"] = subquery_visitor.expect_scalar(column)
221 builder.columns = qt.ColumnSet(self._driver.universe.empty.as_group())
222 subquery_select = builder.select()
223 sql_member = self.expect_scalar(member)
224 return sql_member.in_(subquery_select)
226 def apply_logical_and(
227 self, originals: qt.PredicateOperands, results: tuple[sqlalchemy.ColumnElement[bool], ...]
228 ) -> sqlalchemy.ColumnElement[bool]:
229 # Docstring inherited.
230 match len(results):
231 case 0:
232 return sqlalchemy.true()
233 case 1:
234 return results[0]
235 case _:
236 return sqlalchemy.and_(*results)
238 def apply_logical_or(
239 self,
240 originals: tuple[qt.PredicateLeaf, ...],
241 results: tuple[sqlalchemy.ColumnElement[bool], ...],
242 flags: PredicateVisitFlags,
243 ) -> sqlalchemy.ColumnElement[bool]:
244 # Docstring inherited.
245 match len(results):
246 case 0:
247 return sqlalchemy.false()
248 case 1:
249 return results[0]
250 case _:
251 return sqlalchemy.or_(*results)
253 def apply_logical_not(
254 self, original: qt.PredicateLeaf, result: sqlalchemy.ColumnElement[bool], flags: PredicateVisitFlags
255 ) -> sqlalchemy.ColumnElement[bool]:
256 # Docstring inherited.
257 return sqlalchemy.not_(result)
259 def expect_scalar(self, expression: qt.OrderExpression) -> sqlalchemy.ColumnElement[Any]:
260 result = expression.visit(self)
261 assert isinstance(result, sqlalchemy.ColumnElement)
262 return result
264 def expect_timespan(self, expression: qt.ColumnExpression) -> TimespanDatabaseRepresentation:
265 result = expression.visit(self)
266 assert isinstance(result, TimespanDatabaseRepresentation)
267 return result