Coverage for python / lsst / daf / butler / direct_query_driver / _sql_column_visitor.py: 15%
158 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-30 08:41 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-30 08:41 +0000
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
28from __future__ import annotations
30__all__ = ("SqlColumnVisitor",)
32import warnings
33from typing import TYPE_CHECKING, Any
35import erfa
36import sqlalchemy
38from .. import ddl
39from .._exceptions import InvalidQueryError
40from ..queries import tree as qt
41from ..queries.visitors import ColumnExpressionVisitor, PredicateVisitFlags, PredicateVisitor
42from ..timespan_database_representation import TimespanDatabaseRepresentation
44if TYPE_CHECKING:
45 from ._driver import DirectQueryDriver
46 from ._sql_builders import SqlColumns
class SqlColumnVisitor(
    ColumnExpressionVisitor[sqlalchemy.ColumnElement[Any] | TimespanDatabaseRepresentation],
    PredicateVisitor[
        sqlalchemy.ColumnElement[bool], sqlalchemy.ColumnElement[bool], sqlalchemy.ColumnElement[bool]
    ],
):
    """Visitor that converts query-tree column expressions and predicates
    into `sqlalchemy.ColumnElement` expression trees.

    Parameters
    ----------
    columns : `SqlColumns`
        Source of the SQL columns that column-reference expressions resolve
        to.
    driver : `DirectQueryDriver`
        Driver used to build nested queries for "in query" predicates and to
        reach database-specific helpers.
    """

    def __init__(self, columns: SqlColumns, driver: DirectQueryDriver):
        self._columns = columns
        self._driver = driver

    def visit_literal(
        self, expression: qt.ColumnLiteral
    ) -> sqlalchemy.ColumnElement[Any] | TimespanDatabaseRepresentation:
        # Docstring inherited.
        if expression.column_type != "timespan":
            value = expression.get_literal_value()
            sql_type = ddl.VALID_CONFIG_COLUMN_TYPES[expression.column_type]
            return sqlalchemy.literal(value, type_=sql_type)
        # Timespan literals are built by the database's own timespan
        # representation rather than as a plain SQL literal.
        timespan_repr = self._driver.db.getTimespanRepresentation()
        return timespan_repr.fromLiteral(expression.get_literal_value())

    def visit_dimension_key_reference(
        self, expression: qt.DimensionKeyReference
    ) -> sqlalchemy.ColumnElement[int | str]:
        # Docstring inherited.
        key_columns = self._columns.dimension_keys[expression.dimension.name]
        # Only the first column registered for the dimension is used.
        return key_columns[0]

    def visit_dimension_field_reference(
        self, expression: qt.DimensionFieldReference
    ) -> sqlalchemy.ColumnElement[Any] | TimespanDatabaseRepresentation:
        # Docstring inherited.
        if expression.column_type != "timespan":
            element_fields = self._columns.fields[expression.element.name]
            return element_fields[expression.field]
        # Timespans are stored separately from regular fields.
        return self._columns.timespans[expression.element.name]

    def visit_dataset_field_reference(
        self, expression: qt.DatasetFieldReference
    ) -> sqlalchemy.ColumnElement[Any] | TimespanDatabaseRepresentation:
        # Docstring inherited.
        if expression.column_type != "timespan":
            dataset_fields = self._columns.fields[expression.dataset_type]
            return dataset_fields[expression.field]
        # Timespans are stored separately from regular fields.
        return self._columns.timespans[expression.dataset_type]

    def visit_unary_expression(self, expression: qt.UnaryExpression) -> sqlalchemy.ColumnElement[Any]:
        # Docstring inherited.
        op = expression.operator
        if op == "-":
            return -self.expect_scalar(expression.operand)
        if op == "begin_of":
            return self.expect_timespan(expression.operand).lower()
        if op == "end_of":
            return self.expect_timespan(expression.operand).upper()
        raise AssertionError(f"Invalid unary expression operator {expression.operator!r}.")

    def visit_binary_expression(self, expression: qt.BinaryExpression) -> sqlalchemy.ColumnElement[Any]:
        # Docstring inherited.
        # Both operands are converted up front, in left-to-right order.
        left = self.expect_scalar(expression.a)
        right = self.expect_scalar(expression.b)
        op = expression.operator
        if op == "+":
            return left + right
        if op == "-":
            return left - right
        if op == "*":
            return left * right
        if op == "/":
            return left / right
        if op == "%":
            return left % right
        raise AssertionError(f"Invalid binary expression operator {expression.operator!r}.")

    def visit_reversed(self, expression: qt.Reversed) -> sqlalchemy.ColumnElement[Any]:
        # Docstring inherited.
        operand = self.expect_scalar(expression.operand)
        return operand.desc()

    def visit_boolean_wrapper(
        self, value: qt.ColumnExpression, flags: PredicateVisitFlags
    ) -> sqlalchemy.ColumnElement[bool]:
        # The wrapped expression is returned as-is once converted to SQL.
        return self.expect_scalar(value)

    def visit_comparison(
        self,
        a: qt.ColumnExpression,
        operator: qt.ComparisonOperator,
        b: qt.ColumnExpression,
        flags: PredicateVisitFlags,
    ) -> sqlalchemy.ColumnElement[bool]:
        # Docstring inherited.
        if operator == "overlaps":
            # A timespan paired with a single point in time is a containment
            # test; two timespans are a true overlap test.
            with_datetime = qt.is_one_timespan_and_one_datetime(a, b)
            if with_datetime:
                point = self.expect_scalar(with_datetime.datetime)
                return self.expect_timespan(with_datetime.timespan).contains(point)
            with_ingest_date = qt.is_one_timespan_and_one_ingest_date(a, b)
            if with_ingest_date:
                return self.expect_timespan(with_ingest_date.timespan).contains(
                    self._convert_ingest_date_to_datetime(with_ingest_date.datetime)
                )
            if a.column_type != "timespan" or b.column_type != "timespan":
                # Spatial overlaps should be transformed away by now.
                raise AssertionError(
                    f"Unexpected types {a.column_type},{b.column_type} in overlaps operator."
                )
            return self.expect_timespan(a).overlaps(self.expect_timespan(b))

        if operator == "glob":
            # Second operand must be a string literal.
            if not isinstance(b, qt._column_literal.StringColumnLiteral):
                raise AssertionError(f"Unexpected pattern type ({type(b)}) in glob operator.")
            glob_pattern = b.value
            target = self.expect_scalar(a)
            return self._driver.db.glob_expression(target, glob_pattern)

        left = self.expect_scalar(a)
        right = self.expect_scalar(b)
        # ingest_date does not always share a type with other datetime
        # columns, so when one side is ingest_date the datetime side is
        # converted to match it.
        if qt.is_one_datetime_and_one_ingest_date(a, b):
            if a.column_type == "datetime":
                left = self._convert_datetime_to_ingest_date(a)
            elif b.column_type == "datetime":
                right = self._convert_datetime_to_ingest_date(b)

        if operator == "==":
            return left == right
        if operator == "!=":
            return left != right
        if operator == "<":
            return left < right
        if operator == ">":
            return left > right
        if operator == "<=":
            return left <= right
        if operator == ">=":
            return left >= right
        raise AssertionError(f"Invalid comparison operator {operator!r}.")

    def visit_is_null(
        self, operand: qt.ColumnExpression, flags: PredicateVisitFlags
    ) -> sqlalchemy.ColumnElement[bool]:
        # Docstring inherited.
        if operand.column_type != "timespan":
            return self.expect_scalar(operand) == sqlalchemy.null()
        # Timespans provide their own null test.
        return self.expect_timespan(operand).isNull()

    def visit_in_container(
        self,
        member: qt.ColumnExpression,
        container: tuple[qt.ColumnExpression, ...],
        flags: PredicateVisitFlags,
    ) -> sqlalchemy.ColumnElement[bool]:
        # Docstring inherited.
        sql_member = self.expect_scalar(member)
        sql_items = [self.expect_scalar(item) for item in container]
        return sql_member.in_(sql_items)

    def visit_in_range(
        self, member: qt.ColumnExpression, start: int, stop: int | None, step: int, flags: PredicateVisitFlags
    ) -> sqlalchemy.ColumnElement[bool]:
        # Docstring inherited.
        column = self.expect_scalar(member)
        lower = sqlalchemy.literal(start)
        if stop is None:
            # Open-ended range: only the lower bound applies.
            bounds = column >= lower
        else:
            # ``stop`` is exclusive, while the SQL comparisons below are
            # inclusive.
            last = stop - 1
            if last == start:
                # Single-value range; the step is not applied.
                return column == lower
            bounds = sqlalchemy.sql.between(column, lower, sqlalchemy.literal(last))
        if step == 1:
            return bounds
        # Keep only values that share the start value's remainder.
        on_stride = column % sqlalchemy.literal(step) == sqlalchemy.literal(start % step)
        return sqlalchemy.sql.and_(bounds, on_stride)

    def visit_in_query_tree(
        self,
        member: qt.ColumnExpression,
        column: qt.ColumnExpression,
        query_tree: qt.QueryTree,
        flags: PredicateVisitFlags,
    ) -> sqlalchemy.ColumnElement[bool]:
        # Docstring inherited.
        # Build the right-hand-side query using only the columns that the
        # target expression reports as required.
        required = qt.ColumnSet(self._driver.universe.empty)
        column.gather_required_columns(required)
        builder = self._driver.build_query(query_tree, required)
        if builder.postprocessing:
            raise NotImplementedError(
                "Right-hand side subquery in IN expression would require postprocessing."
            )
        nested = builder.finish_nested()
        # The target expression is evaluated against the nested query's own
        # columns and registered under the special "_MEMBER" key; the regular
        # column set is then replaced with an empty one (presumably so the
        # subquery selects only that expression).
        nested_visitor = SqlColumnVisitor(nested.joins, self._driver)
        nested.joins.special["_MEMBER"] = nested_visitor.expect_scalar(column)
        nested.columns = qt.ColumnSet(self._driver.universe.empty)
        subquery = nested.select(postprocessing=None)
        return self.expect_scalar(member).in_(subquery)

    def apply_logical_and(
        self, originals: qt.PredicateOperands, results: tuple[sqlalchemy.ColumnElement[bool], ...]
    ) -> sqlalchemy.ColumnElement[bool]:
        # Docstring inherited.
        if not results:
            # An empty conjunction is true.
            return sqlalchemy.true()
        if len(results) == 1:
            return results[0]
        return sqlalchemy.and_(*results)

    def apply_logical_or(
        self,
        originals: tuple[qt.PredicateLeaf, ...],
        results: tuple[sqlalchemy.ColumnElement[bool], ...],
        flags: PredicateVisitFlags,
    ) -> sqlalchemy.ColumnElement[bool]:
        # Docstring inherited.
        if not results:
            # An empty disjunction is false.
            return sqlalchemy.false()
        if len(results) == 1:
            return results[0]
        return sqlalchemy.or_(*results)

    def apply_logical_not(
        self, original: qt.PredicateLeaf, result: sqlalchemy.ColumnElement[bool], flags: PredicateVisitFlags
    ) -> sqlalchemy.ColumnElement[bool]:
        # Docstring inherited.
        return sqlalchemy.not_(result)

    def expect_scalar(self, expression: qt.OrderExpression) -> sqlalchemy.ColumnElement[Any]:
        """Visit an expression that must yield a `sqlalchemy.ColumnElement`.

        Parameters
        ----------
        expression : `qt.OrderExpression`
            Expression to visit with this visitor.

        Returns
        -------
        result : `sqlalchemy.ColumnElement`
            The value returned by visiting ``expression``; asserted to be a
            `sqlalchemy.ColumnElement`.
        """
        visited = expression.visit(self)
        assert isinstance(visited, sqlalchemy.ColumnElement)
        return visited

    def _convert_datetime_to_ingest_date(self, expression: qt.ColumnExpression) -> sqlalchemy.ColumnElement:
        """Convert a datetime expression into a form comparable with the
        ``ingest_date`` column.

        Parameters
        ----------
        expression : `qt.ColumnExpression`
            Expression whose ``column_type`` is ``"datetime"``.

        Returns
        -------
        result : `sqlalchemy.ColumnElement`
            SQL expression to compare against ``ingest_date``.

        Raises
        ------
        InvalidQueryError
            Raised if ``ingest_date`` is stored as TIMESTAMP and
            ``expression`` is not a literal date-time value.
        """
        assert expression.column_type == "datetime"
        uses_timestamp = self._driver.managers.datasets.ingest_date_dtype() == sqlalchemy.TIMESTAMP
        if not uses_timestamp:
            # For v2 schema, ingest_date uses TAI nanoseconds like everything
            # else, so no conversion is required.
            return self.expect_scalar(expression)
        # Datasets manager v1 schema stores the "datasets" table's
        # "ingest_date" column as TIMESTAMP, while the rest of the database
        # schema and query system uses integer TAI nanoseconds, so the value
        # has to be converted to a timestamp. Note that this loses precision.
        if expression.expression_type != "datetime":
            # Conversion between TAI and UTC can't be done in the database,
            # so only literal values can be handled here.
            raise InvalidQueryError("Only literal date-time values can be compared with ingest date.")
        # The conversion from TAI to UTC can trigger a warning for dates in
        # the future, so those warnings are suppressed.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=erfa.ErfaWarning)
            utc_datetime = expression.value.utc.to_datetime()
        return sqlalchemy.literal(utc_datetime)

    def _convert_ingest_date_to_datetime(self, expression: qt.ColumnExpression) -> sqlalchemy.ColumnElement:
        """Convert an ``ingest_date`` expression into a form usable with
        other datetime values.

        Parameters
        ----------
        expression : `qt.ColumnExpression`
            Expression whose ``column_type`` is ``"ingest_date"``.

        Returns
        -------
        result : `sqlalchemy.ColumnElement`
            SQL expression for the ingest date.

        Raises
        ------
        InvalidQueryError
            Raised if ``ingest_date`` is stored as TIMESTAMP.
        """
        assert expression.column_type == "ingest_date"
        uses_timestamp = self._driver.managers.datasets.ingest_date_dtype() == sqlalchemy.TIMESTAMP
        if uses_timestamp:
            # Datasets manager v1 schema stores the "datasets" table's
            # "ingest_date" column as TIMESTAMP, while the rest of the
            # database schema and query system uses integer TAI nanoseconds.
            raise InvalidQueryError(
                "Expressions involving ingest date are not supported by this database schema."
            )
        # For v2 schema, ingest_date uses TAI nanoseconds like everything
        # else, so no conversion is required.
        return self.expect_scalar(expression)

    def expect_timespan(self, expression: qt.ColumnExpression) -> TimespanDatabaseRepresentation:
        """Visit an expression that must yield a timespan representation.

        Parameters
        ----------
        expression : `qt.ColumnExpression`
            Expression to visit with this visitor.

        Returns
        -------
        result : `TimespanDatabaseRepresentation`
            The value returned by visiting ``expression``; asserted to be a
            `TimespanDatabaseRepresentation`.
        """
        visited = expression.visit(self)
        assert isinstance(visited, TimespanDatabaseRepresentation)
        return visited