Coverage for python/lsst/daf/butler/registry/queries/expressions/_predicate.py: 13%
210 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-16 10:44 +0000
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
27from __future__ import annotations
29__all__ = ("make_string_expression_predicate", "ExpressionTypeError")
31import builtins
32import datetime
33import types
34import warnings
35from collections.abc import Mapping, Set
36from typing import Any, cast
38import astropy.time
39import astropy.utils.exceptions
40from lsst.daf.relation import (
41 ColumnContainer,
42 ColumnExpression,
43 ColumnExpressionSequence,
44 ColumnLiteral,
45 ColumnTag,
46 Predicate,
47 sql,
48)
50# We import the some modules rather than types within them because match syntax
51# uses qualified names with periods to distinguish literals from captures.
52from .... import _timespan, timespan_database_representation
53from ...._column_tags import DatasetColumnTag, DimensionKeyColumnTag, DimensionRecordColumnTag
54from ...._column_type_info import ColumnTypeInfo
55from ....dimensions import DataCoordinate, Dimension, DimensionGroup, DimensionUniverse
56from ..._exceptions import UserExpressionError, UserExpressionSyntaxError
57from .categorize import ExpressionConstant, categorizeConstant, categorizeElementId
58from .check import CheckVisitor
59from .normalForm import NormalForm, NormalFormExpression
60from .parser import Node, ParserYacc, TreeVisitor # type: ignore
62# As of astropy 4.2, the erfa interface is shipped independently and
63# ErfaWarning is no longer an AstropyWarning
64try:
65 import erfa
66except ImportError:
67 erfa = None
class ExpressionTypeError(TypeError):
    """Raised when the operand types appearing in a query expression are
    incompatible with the operators or other syntax applied to them.
    """
def make_string_expression_predicate(
    string: str,
    dimensions: DimensionGroup,
    *,
    column_types: ColumnTypeInfo,
    bind: Mapping[str, Any] | None = None,
    data_id: DataCoordinate | None = None,
    defaults: DataCoordinate | None = None,
    dataset_type_name: str | None = None,
    allow_orphans: bool = False,
) -> tuple[Predicate | None, Mapping[str, Set[str]]]:
    """Parse and analyze a string expression, producing a predicate.

    Parameters
    ----------
    string : `str`
        String to parse.
    dimensions : `DimensionGroup`
        The dimensions the query would include in the absence of this WHERE
        expression.
    column_types : `ColumnTypeInfo`
        Information about column types.
    bind : `~collections.abc.Mapping` [ `str`, `Any` ], optional
        Literal values referenced in the expression.
    data_id : `DataCoordinate`, optional
        A fully-expanded data ID identifying dimensions known in advance.
        If not provided, will be set to an empty data ID.
        ``dataId.hasRecords()`` must return `True`.
    defaults : `DataCoordinate`, optional
        A data ID containing default for governor dimensions.  Ignored
        unless ``check=True``.
    dataset_type_name : `str` or `None`, optional
        The name of the dataset type to assume for unqualified dataset
        columns, or `None` if there are no such identifiers.
    allow_orphans : `bool`, optional
        If `True`, permit expressions to refer to dimensions without
        providing a value for their governor dimensions (e.g. referring to
        a visit without an instrument).  Should be left to default to
        `False` in essentially all new code.

    Returns
    -------
    predicate : `lsst.daf.relation.colum_expressions.Predicate` or `None`
        New predicate derived from the string expression, or `None` if the
        string is empty.
    governor_constraints : `~collections.abc.Mapping` [ `str` , \
            `~collections.abc.Set` ]
        Constraints on dimension values derived from the expression and data
        ID.
    """
    if data_id is None:
        data_id = DataCoordinate.make_empty(dimensions.universe)
    if not string:
        # No expression at all: the only governor constraints are those
        # implied by the data ID itself.
        return None, {
            governor: {cast(str, data_id[governor])} for governor in data_id.dimensions.governors
        }
    try:
        tree = ParserYacc().parse(string)
    except Exception as exc:
        raise UserExpressionSyntaxError(f"Failed to parse user expression {string!r}.") from exc
    bind = {} if bind is None else bind
    # Reject bind keys that shadow dimension elements or look like
    # dimension-record columns; those must always refer to the dimensions.
    element_names = dimensions.universe.elements.names
    for identifier in bind:
        if identifier in element_names:
            raise RuntimeError(f"Bind parameter key {identifier!r} conflicts with a dimension element.")
        table, _, column = identifier.partition(".")
        if column and table in element_names:
            raise RuntimeError(f"Bind parameter key {identifier!r} looks like a dimension column.")
    if defaults is None:
        defaults = DataCoordinate.make_empty(dimensions.universe)
    # Convert the expression to disjunctive normal form (ORs of ANDs).
    # That's potentially super expensive in the general case (where there's
    # a ton of nesting of ANDs and ORs).  That won't be the case for the
    # expressions we expect, and we actually use disjunctive normal instead
    # of conjunctive (i.e. ANDs of ORs) because I think the worst-case is
    # a long list of OR'd-together data IDs, which is already in or very
    # close to disjunctive normal form.
    normalized = NormalFormExpression.fromTree(tree, NormalForm.DISJUNCTIVE)
    # Check the expression for consistency and completeness.
    checker = CheckVisitor(data_id, dimensions, bind, defaults, allow_orphans=allow_orphans)
    try:
        summary = normalized.visit(checker)
    except UserExpressionError as err:
        original_text = str(tree)
        normalized_text = str(normalized.toTree())
        # Only mention the normalized form when it differs from the input.
        if normalized_text == original_text:
            msg = f'Error in query expression "{original_text}": {err}'
        else:
            msg = f'Error in query expression "{original_text}" (normalized to "{normalized_text}"): {err}'
        raise UserExpressionError(msg) from None
    governor_names = dimensions.universe.governor_dimensions.names
    governor_constraints: dict[str, Set[str]] = {
        dimension_name: cast(Set[str], values)
        for dimension_name, values in summary.dimension_constraints.items()
        if dimension_name in governor_names
    }
    converter = PredicateConversionVisitor(bind, dataset_type_name, dimensions.universe, column_types)
    return tree.visit(converter), governor_constraints
# Result type produced by PredicateConversionVisitor's visit methods: a
# boolean predicate, a scalar column expression, or a container of values
# (e.g. the RHS of an IN clause).
VisitorResult = Predicate | ColumnExpression | ColumnContainer
180class PredicateConversionVisitor(TreeVisitor[VisitorResult]):
181 def __init__(
182 self,
183 bind: Mapping[str, Any],
184 dataset_type_name: str | None,
185 universe: DimensionUniverse,
186 column_types: ColumnTypeInfo,
187 ):
188 self.bind = bind
189 self.dataset_type_name = dataset_type_name
190 self.universe = universe
191 self.column_types = column_types
193 OPERATOR_MAP = {
194 "=": "__eq__",
195 "!=": "__ne__",
196 "<": "__lt__",
197 ">": "__gt__",
198 "<=": "__le__",
199 ">=": "__ge__",
200 "+": "__add__",
201 "-": "__sub__",
202 "/": "__mul__",
203 }
205 def to_datetime(self, time: astropy.time.Time) -> datetime.datetime:
206 with warnings.catch_warnings():
207 warnings.simplefilter("ignore", category=astropy.utils.exceptions.AstropyWarning)
208 if erfa is not None:
209 warnings.simplefilter("ignore", category=erfa.ErfaWarning)
210 return time.to_datetime()
212 def visitBinaryOp(
213 self, operator: str, lhs: VisitorResult, rhs: VisitorResult, node: Node
214 ) -> VisitorResult:
215 # Docstring inherited.
216 b = builtins
217 match (operator, lhs, rhs):
218 case ["OR", Predicate() as lhs, Predicate() as rhs]:
219 return lhs.logical_or(rhs)
220 case ["AND", Predicate() as lhs, Predicate() as rhs]:
221 return lhs.logical_and(rhs)
222 # Allow all comparisons between expressions of the same type for
223 # sortable types.
224 case [
225 "=" | "!=" | "<" | ">" | "<=" | ">=",
226 ColumnExpression(
227 dtype=b.int | b.float | b.str | astropy.time.Time | datetime.datetime
228 ) as lhs,
229 ColumnExpression() as rhs,
230 ] if lhs.dtype is rhs.dtype:
231 return lhs.predicate_method(self.OPERATOR_MAP[operator], rhs)
232 # Allow comparisons between datetime expressions and
233 # astropy.time.Time literals/binds (only), by coercing the
234 # astropy.time.Time version to datetime.
235 case [
236 "=" | "!=" | "<" | ">" | "<=" | ">=",
237 ColumnLiteral(dtype=astropy.time.Time) as lhs,
238 ColumnExpression(dtype=datetime.datetime) as rhs,
239 ]:
240 lhs = ColumnLiteral(self.to_datetime(lhs.value), datetime.datetime)
241 return lhs.predicate_method(self.OPERATOR_MAP[operator], rhs)
242 case [
243 "=" | "!=" | "<" | ">" | "<=" | ">=",
244 ColumnExpression(dtype=datetime.datetime) as lhs,
245 ColumnLiteral(dtype=astropy.time.Time) as rhs,
246 ]:
247 rhs = ColumnLiteral(self.to_datetime(rhs.value), datetime.datetime)
248 return lhs.predicate_method(self.OPERATOR_MAP[operator], rhs)
249 # Allow comparisons between astropy.time.Time expressions and
250 # datetime literals/binds, by coercing the
251 # datetime literals to astropy.time.Time (in UTC scale).
252 case [
253 "=" | "!=" | "<" | ">" | "<=" | ">=",
254 ColumnLiteral(dtype=datetime.datetime) as lhs,
255 ColumnExpression(dtype=astropy.time.Time) as rhs,
256 ]:
257 lhs = ColumnLiteral(astropy.time.Time(lhs.value, scale="utc"), astropy.time.Time)
258 return lhs.predicate_method(self.OPERATOR_MAP[operator], rhs)
259 case [
260 "=" | "!=" | "<" | ">" | "<=" | ">=",
261 ColumnExpression(dtype=astropy.time.Time) as lhs,
262 ColumnLiteral(dtype=datetime.datetime) as rhs,
263 ]:
264 rhs = ColumnLiteral(astropy.time.Time(rhs.value, scale="utc"), astropy.time.Time)
265 return lhs.predicate_method(self.OPERATOR_MAP[operator], rhs)
266 # Allow equality comparisons with None/NULL. We don't have an 'IS'
267 # operator.
268 case ["=" | "!=", ColumnExpression(dtype=types.NoneType) as lhs, ColumnExpression() as rhs]:
269 return lhs.predicate_method(self.OPERATOR_MAP[operator], rhs)
270 case ["=" | "!=", ColumnExpression() as lhs, ColumnExpression(dtype=types.NoneType) as rhs]:
271 return lhs.predicate_method(self.OPERATOR_MAP[operator], rhs)
272 # Comparisions between Time and Timespan need have the Timespan on
273 # the lhs, since that (actually TimespanDatabaseRepresentation) is
274 # what actually has the methods.
275 case [
276 "<",
277 ColumnExpression(dtype=astropy.time.Time) as lhs,
278 ColumnExpression(dtype=_timespan.Timespan) as rhs,
279 ]:
280 return rhs.predicate_method(self.OPERATOR_MAP[">"], lhs)
281 case [
282 ">",
283 ColumnExpression(dtype=astropy.time.Time) as lhs,
284 ColumnExpression(dtype=_timespan.Timespan) as rhs,
285 ]:
286 return rhs.predicate_method(self.OPERATOR_MAP["<"], lhs)
287 # Enable other comparisons between times and Timespans (many of the
288 # combinations matched by this branch will have already been
289 # covered by a preceding branch).
290 case [
291 "<" | ">",
292 ColumnExpression(dtype=_timespan.Timespan | astropy.time.Time) as lhs,
293 ColumnExpression(dtype=_timespan.Timespan | astropy.time.Time) as rhs,
294 ]:
295 return lhs.predicate_method(self.OPERATOR_MAP[operator], rhs)
296 # Enable "overlaps" operations between timespans, and between times
297 # and timespans. The latter resolve to the `Timespan.contains` or
298 # `TimespanDatabaseRepresentation.contains` methods, but we use
299 # OVERLAPS in the string expression language to keep that simple.
300 case [
301 "OVERLAPS",
302 ColumnExpression(dtype=_timespan.Timespan) as lhs,
303 ColumnExpression(dtype=_timespan.Timespan) as rhs,
304 ]:
305 return lhs.predicate_method("overlaps", rhs)
306 case [
307 "OVERLAPS",
308 ColumnExpression(dtype=_timespan.Timespan) as lhs,
309 ColumnExpression(dtype=astropy.time.Time) as rhs,
310 ]:
311 return lhs.predicate_method("overlaps", rhs)
312 case [
313 "OVERLAPS",
314 ColumnExpression(dtype=astropy.time.Time) as lhs,
315 ColumnExpression(dtype=_timespan.Timespan) as rhs,
316 ]:
317 return rhs.predicate_method("overlaps", lhs)
318 # Enable arithmetic operators on numeric types, without any type
319 # coercion or broadening.
320 case [
321 "+" | "-" | "*",
322 ColumnExpression(dtype=b.int | b.float) as lhs,
323 ColumnExpression() as rhs,
324 ] if lhs.dtype is rhs.dtype:
325 return lhs.method(self.OPERATOR_MAP[operator], rhs, dtype=lhs.dtype)
326 case ["/", ColumnExpression(dtype=b.float) as lhs, ColumnExpression(dtype=b.float) as rhs]:
327 return lhs.method("__truediv__", rhs, dtype=b.float)
328 case ["/", ColumnExpression(dtype=b.int) as lhs, ColumnExpression(dtype=b.int) as rhs]:
329 # SQLAlchemy maps Python's '/' (__truediv__) operator directly
330 # to SQL's '/', despite those being defined differently for
331 # integers. Our expression language uses the SQL definition,
332 # and we only care about these expressions being evaluated in
333 # SQL right now, but we still want to guard against it being
334 # evaluated in Python and producing a surprising answer, so we
335 # mark it as being supported only by a SQL engine.
336 return lhs.method(
337 "__truediv__",
338 rhs,
339 dtype=b.int,
340 supporting_engine_types={sql.Engine},
341 )
342 case ["%", ColumnExpression(dtype=b.int) as lhs, ColumnExpression(dtype=b.int) as rhs]:
343 return lhs.method("__mod__", rhs, dtype=b.int)
344 assert (
345 lhs.dtype is not None and rhs.dtype is not None
346 ), "Expression converter should not yield untyped nodes."
347 raise ExpressionTypeError(
348 f"Invalid types {lhs.dtype.__name__}, {rhs.dtype.__name__} for binary operator {operator!r} "
349 f"in expression {node!s}."
350 )
352 def visitIdentifier(self, name: str, node: Node) -> VisitorResult:
353 # Docstring inherited.
354 if name in self.bind:
355 value = self.bind[name]
356 if isinstance(value, list | tuple | Set):
357 elements = []
358 all_dtypes = set()
359 for item in value:
360 dtype = type(item)
361 all_dtypes.add(dtype)
362 elements.append(ColumnExpression.literal(item, dtype=dtype))
363 if len(all_dtypes) > 1:
364 raise ExpressionTypeError(
365 f"Mismatched types in bind iterable: {value} has a mix of {all_dtypes}."
366 )
367 elif not elements:
368 # Empty container
369 return ColumnContainer.sequence([])
370 else:
371 (dtype,) = all_dtypes
372 return ColumnContainer.sequence(elements, dtype=dtype)
373 return ColumnExpression.literal(value, dtype=type(value))
374 tag: ColumnTag
375 match categorizeConstant(name):
376 case ExpressionConstant.INGEST_DATE:
377 assert self.dataset_type_name is not None
378 tag = DatasetColumnTag(self.dataset_type_name, "ingest_date")
379 return ColumnExpression.reference(tag, self.column_types.ingest_date_pytype)
380 case ExpressionConstant.NULL:
381 return ColumnExpression.literal(None, type(None))
382 case None:
383 pass
384 case _:
385 raise AssertionError("Check for enum values should be exhaustive.")
386 element, column = categorizeElementId(self.universe, name)
387 if column is not None:
388 tag = DimensionRecordColumnTag(element.name, column)
389 dtype = (
390 _timespan.Timespan
391 if column == timespan_database_representation.TimespanDatabaseRepresentation.NAME
392 else element.RecordClass.fields.standard[column].getPythonType()
393 )
394 return ColumnExpression.reference(tag, dtype)
395 else:
396 tag = DimensionKeyColumnTag(element.name)
397 assert isinstance(element, Dimension)
398 return ColumnExpression.reference(tag, element.primaryKey.getPythonType())
400 def visitIsIn(
401 self, lhs: VisitorResult, values: list[VisitorResult], not_in: bool, node: Node
402 ) -> VisitorResult:
403 # Docstring inherited.
404 clauses: list[Predicate] = []
405 items: list[ColumnExpression] = []
406 assert isinstance(lhs, ColumnExpression), "LHS of IN guaranteed to be scalar by parser."
407 for rhs_item in values:
408 match rhs_item:
409 case ColumnExpressionSequence(
410 items=rhs_items, dtype=rhs_dtype
411 ) if rhs_dtype is None or rhs_dtype == lhs.dtype:
412 items.extend(rhs_items)
413 case ColumnContainer(dtype=lhs.dtype):
414 clauses.append(rhs_item.contains(lhs))
415 case ColumnExpression(dtype=lhs.dtype):
416 items.append(rhs_item)
417 case _:
418 raise ExpressionTypeError(
419 f"Invalid type {rhs_item.dtype} for element in {lhs.dtype} IN expression '{node}'."
420 )
421 if items:
422 clauses.append(ColumnContainer.sequence(items, dtype=lhs.dtype).contains(lhs))
423 result = Predicate.logical_or(*clauses)
424 if not_in:
425 result = result.logical_not()
426 return result
428 def visitNumericLiteral(self, value: str, node: Node) -> VisitorResult:
429 # Docstring inherited.
430 try:
431 return ColumnExpression.literal(int(value), dtype=int)
432 except ValueError:
433 return ColumnExpression.literal(float(value), dtype=float)
435 def visitParens(self, expression: VisitorResult, node: Node) -> VisitorResult:
436 # Docstring inherited.
437 return expression
439 def visitPointNode(self, ra: VisitorResult, dec: VisitorResult, node: Node) -> VisitorResult:
440 # Docstring inherited.
442 # this is a placeholder for future extension, we enabled syntax but
443 # do not support actual use just yet.
444 raise NotImplementedError("POINT() function is not supported yet")
446 def visitRangeLiteral(self, start: int, stop: int, stride: int | None, node: Node) -> VisitorResult:
447 # Docstring inherited.
448 return ColumnContainer.range_literal(range(start, stop + 1, stride or 1))
450 def visitStringLiteral(self, value: str, node: Node) -> VisitorResult:
451 # Docstring inherited.
452 return ColumnExpression.literal(value, dtype=str)
454 def visitTimeLiteral(self, value: astropy.time.Time, node: Node) -> VisitorResult:
455 # Docstring inherited.
456 return ColumnExpression.literal(value, dtype=astropy.time.Time)
458 def visitTupleNode(self, items: tuple[VisitorResult, ...], node: Node) -> VisitorResult:
459 # Docstring inherited.
460 match items:
461 case [
462 ColumnLiteral(value=begin, dtype=astropy.time.Time | types.NoneType),
463 ColumnLiteral(value=end, dtype=astropy.time.Time | types.NoneType),
464 ]:
465 return ColumnExpression.literal(_timespan.Timespan(begin, end), dtype=_timespan.Timespan)
466 raise ExpressionTypeError(
467 f'Invalid type(s) ({items[0].dtype}, {items[1].dtype}) in timespan tuple "{node}" '
468 '(Note that date/time strings must be preceded by "T" to be recognized).'
469 )
471 def visitUnaryOp(self, operator: str, operand: VisitorResult, node: Node) -> VisitorResult:
472 # Docstring inherited.
473 match (operator, operand):
474 case ["NOT", Predicate() as operand]:
475 return operand.logical_not()
476 case ["+", ColumnExpression(dtype=builtins.int | builtins.float) as operand]:
477 return operand.method("__pos__")
478 case ["-", ColumnExpression(dtype=builtins.int | builtins.float) as operand]:
479 return operand.method("__neg__")
480 raise ExpressionTypeError(
481 f"Unary operator {operator!r} is not valid for operand of type {operand.dtype!s} in {node!s}."
482 )