Coverage for python/lsst/daf/butler/registry/queries/expressions/_predicate.py: 12%
200 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-03-30 02:32 -0700
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ("make_string_expression_predicate", "ExpressionTypeError")
25import builtins
26import datetime
27import types
28import warnings
29from collections.abc import Mapping, Set
30from typing import Any, Union, cast
32import astropy.time
33import astropy.utils.exceptions
34from lsst.daf.relation import (
35 ColumnContainer,
36 ColumnExpression,
37 ColumnExpressionSequence,
38 ColumnLiteral,
39 ColumnTag,
40 Predicate,
41 sql,
42)
44# We import the timespan module rather than types within it because match
45# syntax uses qualified names with periods to distinguish literals from
46# captures.
47from ....core import (
48 DataCoordinate,
49 DatasetColumnTag,
50 Dimension,
51 DimensionGraph,
52 DimensionKeyColumnTag,
53 DimensionRecordColumnTag,
54 DimensionUniverse,
55 timespan,
56)
57from ..._exceptions import UserExpressionError, UserExpressionSyntaxError
58from .categorize import ExpressionConstant, categorizeConstant, categorizeElementId
59from .check import CheckVisitor
60from .normalForm import NormalForm, NormalFormExpression
61from .parser import Node, ParserYacc, TreeVisitor # type: ignore
63# As of astropy 4.2, the erfa interface is shipped independently and
64# ErfaWarning is no longer an AstropyWarning
65try:
66 import erfa
67except ImportError:
68 erfa = None
class ExpressionTypeError(TypeError):
    """Raised when the operand types in a query expression do not match the
    operators or other syntax applied to them.
    """
def make_string_expression_predicate(
    string: str,
    dimensions: DimensionGraph,
    *,
    bind: Mapping[str, Any] | None = None,
    data_id: DataCoordinate | None = None,
    defaults: DataCoordinate | None = None,
    dataset_type_name: str | None = None,
    allow_orphans: bool = False,
) -> tuple[Predicate | None, Mapping[str, Set[str]]]:
    """Create a predicate by parsing and analyzing a string expression.

    Parameters
    ----------
    string : `str`
        String to parse.
    dimensions : `DimensionGraph`
        The dimensions the query would include in the absence of this WHERE
        expression.
    bind : `Mapping` [ `str`, `Any` ], optional
        Literal values referenced in the expression.  Keys must not collide
        with dimension element names or dimension column identifiers.
    data_id : `DataCoordinate`, optional
        A fully-expanded data ID identifying dimensions known in advance.
        If not provided, will be set to an empty data ID.
        ``dataId.hasRecords()`` must return `True`.
    defaults : `DataCoordinate`, optional
        A data ID containing defaults for governor dimensions.
    dataset_type_name : `str` or `None`, optional
        The name of the dataset type to assume for unqualified dataset
        columns, or `None` if there are no such identifiers.
    allow_orphans : `bool`, optional
        If `True`, permit expressions to refer to dimensions without
        providing a value for their governor dimensions (e.g. referring to
        a visit without an instrument).  Should be left to default to
        `False` in essentially all new code.

    Returns
    -------
    predicate : `lsst.daf.relation.Predicate` or `None`
        New predicate derived from the string expression, or `None` if the
        string is empty.
    governor_constraints : `Mapping` [ `str` , `~collections.abc.Set` ]
        Constraints on dimension values derived from the expression and data
        ID.
    """
    governor_constraints: dict[str, Set[str]] = {}
    if data_id is None:
        data_id = DataCoordinate.makeEmpty(dimensions.universe)
    if not string:
        # With no expression, the only governor constraints come from the
        # data ID itself.
        for dimension in data_id.graph.governors:
            governor_constraints[dimension.name] = {cast(str, data_id[dimension])}
        return None, governor_constraints
    try:
        tree = ParserYacc().parse(string)
    except Exception as exc:
        raise UserExpressionSyntaxError(f"Failed to parse user expression {string!r}.") from exc
    if bind is None:
        bind = {}
    # Reject bind keys that shadow dimension elements or look like
    # dimension-record column references (e.g. "visit.name").
    element_names = dimensions.universe.getStaticElements().names
    for identifier in bind:
        if identifier in element_names:
            raise RuntimeError(f"Bind parameter key {identifier!r} conflicts with a dimension element.")
        table, _, column = identifier.partition(".")
        if column and table in element_names:
            raise RuntimeError(f"Bind parameter key {identifier!r} looks like a dimension column.")
    if defaults is None:
        defaults = DataCoordinate.makeEmpty(dimensions.universe)
    # Convert the expression to disjunctive normal form (ORs of ANDs).
    # That's potentially super expensive in the general case (where there's
    # a ton of nesting of ANDs and ORs).  That won't be the case for the
    # expressions we expect, and we actually use disjunctive normal instead
    # of conjunctive (i.e. ANDs of ORs) because the worst-case is a long
    # list of OR'd-together data IDs, which is already in or very close to
    # disjunctive normal form.
    expr = NormalFormExpression.fromTree(tree, NormalForm.DISJUNCTIVE)
    # Check the expression for consistency and completeness.
    visitor = CheckVisitor(data_id, dimensions, bind, defaults, allow_orphans=allow_orphans)
    try:
        summary = expr.visit(visitor)
    except UserExpressionError as err:
        # Report using the user's original spelling, mentioning the
        # normalized form only when it differs.
        exprOriginal = str(tree)
        exprNormal = str(expr.toTree())
        if exprNormal == exprOriginal:
            msg = f'Error in query expression "{exprOriginal}": {err}'
        else:
            msg = f'Error in query expression "{exprOriginal}" (normalized to "{exprNormal}"): {err}'
        raise UserExpressionError(msg) from None
    governor_names = dimensions.universe.getGovernorDimensions().names
    for dimension_name, values in summary.dimension_constraints.items():
        if dimension_name in governor_names:
            governor_constraints[dimension_name] = cast(Set[str], values)
    converter = PredicateConversionVisitor(bind, dataset_type_name, dimensions.universe)
    return tree.visit(converter), governor_constraints
174VisitorResult = Union[Predicate, ColumnExpression, ColumnContainer]
class PredicateConversionVisitor(TreeVisitor[VisitorResult]):
    """Visitor that converts a parsed expression tree into a
    `~lsst.daf.relation.Predicate` tree (or column expression/container for
    subexpressions).

    Parameters
    ----------
    bind : `Mapping` [ `str`, `Any` ]
        Literal values referenced in the expression by identifier.
    dataset_type_name : `str` or `None`
        Name of the dataset type to assume for unqualified dataset columns,
        or `None` if there are no such identifiers.
    universe : `DimensionUniverse`
        Object describing all known dimensions.
    """

    def __init__(
        self,
        bind: Mapping[str, Any],
        dataset_type_name: str | None,
        universe: DimensionUniverse,
    ):
        self.bind = bind
        self.dataset_type_name = dataset_type_name
        self.universe = universe

    # Mapping from expression-language operator token to the name of the
    # Python special method that implements it.  visitBinaryOp looks up the
    # comparison operators and the "+", "-", "*" arithmetic operators here;
    # "/" and "%" are handled explicitly there and are deliberately absent.
    # BUG FIX: the map previously had '"/": "__mul__"' and no "*" entry,
    # which made the '"+" | "-" | "*"' branch in visitBinaryOp raise KeyError
    # for multiplication (division never consults this map).
    OPERATOR_MAP = {
        "=": "__eq__",
        "!=": "__ne__",
        "<": "__lt__",
        ">": "__gt__",
        "<=": "__le__",
        ">=": "__ge__",
        "+": "__add__",
        "-": "__sub__",
        "*": "__mul__",
    }
200 def to_datetime(self, time: astropy.time.Time) -> datetime.datetime:
201 with warnings.catch_warnings():
202 warnings.simplefilter("ignore", category=astropy.utils.exceptions.AstropyWarning)
203 if erfa is not None:
204 warnings.simplefilter("ignore", category=erfa.ErfaWarning)
205 return time.to_datetime()
    def visitBinaryOp(
        self, operator: str, lhs: VisitorResult, rhs: VisitorResult, node: Node
    ) -> VisitorResult:
        # Docstring inherited.
        #
        # Dispatch on (operator, lhs-type, rhs-type) via structural pattern
        # matching; each case returns a new Predicate or ColumnExpression.
        # If no case matches, fall through to the ExpressionTypeError at the
        # bottom.  The alias below lets builtin types appear as dotted value
        # patterns (a bare `int` in a pattern would be a capture, not a
        # comparison).
        b = builtins
        match (operator, lhs, rhs):
            # Logical combinators operate on fully-formed predicates.
            case ["OR", Predicate() as lhs, Predicate() as rhs]:
                return lhs.logical_or(rhs)
            case ["AND", Predicate() as lhs, Predicate() as rhs]:
                return lhs.logical_and(rhs)
            # Allow all comparisons between expressions of the same type for
            # sortable types.  The guard requires the dtypes to be identical;
            # OPERATOR_MAP must supply an entry for each comparison token.
            case [
                "=" | "!=" | "<" | ">" | "<=" | ">=",
                ColumnExpression(
                    dtype=b.int | b.float | b.str | astropy.time.Time | datetime.datetime
                ) as lhs,
                ColumnExpression() as rhs,
            ] if lhs.dtype is rhs.dtype:
                return lhs.predicate_method(self.OPERATOR_MAP[operator], rhs)
            # Allow comparisons between datetime expressions and
            # astropy.time.Time literals/binds (only), by coercing the
            # astropy.time.Time side to datetime (see to_datetime above).
            case [
                "=" | "!=" | "<" | ">" | "<=" | ">=",
                ColumnLiteral(dtype=astropy.time.Time) as lhs,
                ColumnExpression(dtype=datetime.datetime) as rhs,
            ]:
                lhs = ColumnLiteral(self.to_datetime(lhs.value), datetime.datetime)
                return lhs.predicate_method(self.OPERATOR_MAP[operator], rhs)
            case [
                "=" | "!=" | "<" | ">" | "<=" | ">=",
                ColumnExpression(dtype=datetime.datetime) as lhs,
                ColumnLiteral(dtype=astropy.time.Time) as rhs,
            ]:
                rhs = ColumnLiteral(self.to_datetime(rhs.value), datetime.datetime)
                return lhs.predicate_method(self.OPERATOR_MAP[operator], rhs)
            # Allow equality comparisons with None/NULL.  We don't have an
            # 'IS' operator.
            case ["=" | "!=", ColumnExpression(dtype=types.NoneType) as lhs, ColumnExpression() as rhs]:
                return lhs.predicate_method(self.OPERATOR_MAP[operator], rhs)
            case ["=" | "!=", ColumnExpression() as lhs, ColumnExpression(dtype=types.NoneType) as rhs]:
                return lhs.predicate_method(self.OPERATOR_MAP[operator], rhs)
            # Comparisons between Time and Timespan need to have the Timespan
            # on the lhs, since that (actually TimespanDatabaseRepresentation)
            # is what actually has the methods, so flip the operands and the
            # operator direction.
            case [
                "<",
                ColumnExpression(dtype=astropy.time.Time) as lhs,
                ColumnExpression(dtype=timespan.Timespan) as rhs,
            ]:
                return rhs.predicate_method(self.OPERATOR_MAP[">"], lhs)
            case [
                ">",
                ColumnExpression(dtype=astropy.time.Time) as lhs,
                ColumnExpression(dtype=timespan.Timespan) as rhs,
            ]:
                return rhs.predicate_method(self.OPERATOR_MAP["<"], lhs)
            # Enable other comparisons between times and Timespans (many of the
            # combinations matched by this branch will have already been
            # covered by a preceding branch).
            case [
                "<" | ">",
                ColumnExpression(dtype=timespan.Timespan | astropy.time.Time) as lhs,
                ColumnExpression(dtype=timespan.Timespan | astropy.time.Time) as rhs,
            ]:
                return lhs.predicate_method(self.OPERATOR_MAP[operator], rhs)
            # Enable "overlaps" operations between timespans, and between times
            # and timespans.  The latter resolve to the `Timespan.contains` or
            # `TimespanDatabaseRepresentation.contains` methods, but we use
            # OVERLAPS in the string expression language to keep that simple.
            case [
                "OVERLAPS",
                ColumnExpression(dtype=timespan.Timespan) as lhs,
                ColumnExpression(dtype=timespan.Timespan) as rhs,
            ]:
                return lhs.predicate_method("overlaps", rhs)
            case [
                "OVERLAPS",
                ColumnExpression(dtype=timespan.Timespan) as lhs,
                ColumnExpression(dtype=astropy.time.Time) as rhs,
            ]:
                return lhs.predicate_method("overlaps", rhs)
            case [
                "OVERLAPS",
                ColumnExpression(dtype=astropy.time.Time) as lhs,
                ColumnExpression(dtype=timespan.Timespan) as rhs,
            ]:
                # Swap so the Timespan side provides the method.
                return rhs.predicate_method("overlaps", lhs)
            # Enable arithmetic operators on numeric types, without any type
            # coercion or broadening (the guard requires identical dtypes).
            case [
                "+" | "-" | "*",
                ColumnExpression(dtype=b.int | b.float) as lhs,
                ColumnExpression() as rhs,
            ] if lhs.dtype is rhs.dtype:
                return lhs.method(self.OPERATOR_MAP[operator], rhs, dtype=lhs.dtype)
            case ["/", ColumnExpression(dtype=b.float) as lhs, ColumnExpression(dtype=b.float) as rhs]:
                return lhs.method("__truediv__", rhs, dtype=b.float)
            case ["/", ColumnExpression(dtype=b.int) as lhs, ColumnExpression(dtype=b.int) as rhs]:
                # SQLAlchemy maps Python's '/' (__truediv__) operator directly
                # to SQL's '/', despite those being defined differently for
                # integers.  Our expression language uses the SQL definition,
                # and we only care about these expressions being evaluated in
                # SQL right now, but we still want to guard against it being
                # evaluated in Python and producing a surprising answer, so we
                # mark it as being supported only by a SQL engine.
                return lhs.method(
                    "__truediv__",
                    rhs,
                    dtype=b.int,
                    supporting_engine_types={sql.Engine},
                )
            case ["%", ColumnExpression(dtype=b.int) as lhs, ColumnExpression(dtype=b.int) as rhs]:
                return lhs.method("__mod__", rhs, dtype=b.int)
        # Nothing matched: report the operand types.  Untyped nodes indicate a
        # converter bug, not a user error.
        assert (
            lhs.dtype is not None and rhs.dtype is not None
        ), "Expression converter should not yield untyped nodes."
        raise ExpressionTypeError(
            f"Invalid types {lhs.dtype.__name__}, {rhs.dtype.__name__} for binary operator {operator!r} "
            f"in expression {node!s}."
        )
330 def visitIdentifier(self, name: str, node: Node) -> VisitorResult:
331 # Docstring inherited.
332 if name in self.bind:
333 value = self.bind[name]
334 if isinstance(value, (list, tuple, Set)):
335 elements = []
336 all_dtypes = set()
337 for item in value:
338 dtype = type(item)
339 all_dtypes.add(dtype)
340 elements.append(ColumnExpression.literal(item, dtype=dtype))
341 if len(all_dtypes) > 1:
342 raise ExpressionTypeError(
343 f"Mismatched types in bind iterable: {value} has a mix of {all_dtypes}."
344 )
345 elif not elements:
346 # Empty container
347 return ColumnContainer.sequence([])
348 else:
349 (dtype,) = all_dtypes
350 return ColumnContainer.sequence(elements, dtype=dtype)
351 return ColumnExpression.literal(value, dtype=type(value))
352 tag: ColumnTag
353 match categorizeConstant(name):
354 case ExpressionConstant.INGEST_DATE:
355 assert self.dataset_type_name is not None
356 tag = DatasetColumnTag(self.dataset_type_name, "ingest_date")
357 return ColumnExpression.reference(tag, datetime.datetime)
358 case ExpressionConstant.NULL:
359 return ColumnExpression.literal(None, type(None))
360 case None:
361 pass
362 case _:
363 raise AssertionError("Check for enum values should be exhaustive.")
364 element, column = categorizeElementId(self.universe, name)
365 if column is not None:
366 tag = DimensionRecordColumnTag(element.name, column)
367 dtype = (
368 timespan.Timespan
369 if column == timespan.TimespanDatabaseRepresentation.NAME
370 else element.RecordClass.fields.standard[column].getPythonType()
371 )
372 return ColumnExpression.reference(tag, dtype)
373 else:
374 tag = DimensionKeyColumnTag(element.name)
375 assert isinstance(element, Dimension)
376 return ColumnExpression.reference(tag, element.primaryKey.getPythonType())
378 def visitIsIn(
379 self, lhs: VisitorResult, values: list[VisitorResult], not_in: bool, node: Node
380 ) -> VisitorResult:
381 # Docstring inherited.
382 clauses: list[Predicate] = []
383 items: list[ColumnExpression] = []
384 assert isinstance(lhs, ColumnExpression), "LHS of IN guaranteed to be scalar by parser."
385 for rhs_item in values:
386 match rhs_item:
387 case ColumnExpressionSequence(
388 items=rhs_items, dtype=rhs_dtype
389 ) if rhs_dtype is None or rhs_dtype == lhs.dtype:
390 items.extend(rhs_items)
391 case ColumnContainer(dtype=lhs.dtype):
392 clauses.append(rhs_item.contains(lhs))
393 case ColumnExpression(dtype=lhs.dtype):
394 items.append(rhs_item)
395 case _:
396 raise ExpressionTypeError(
397 f"Invalid type {rhs_item.dtype} for element in {lhs.dtype} IN expression '{node}'."
398 )
399 if items:
400 clauses.append(ColumnContainer.sequence(items, dtype=lhs.dtype).contains(lhs))
401 result = Predicate.logical_or(*clauses)
402 if not_in:
403 result = result.logical_not()
404 return result
406 def visitNumericLiteral(self, value: str, node: Node) -> VisitorResult:
407 # Docstring inherited.
408 try:
409 return ColumnExpression.literal(int(value), dtype=int)
410 except ValueError:
411 return ColumnExpression.literal(float(value), dtype=float)
413 def visitParens(self, expression: VisitorResult, node: Node) -> VisitorResult:
414 # Docstring inherited.
415 return expression
417 def visitPointNode(self, ra: VisitorResult, dec: VisitorResult, node: Node) -> VisitorResult:
418 # Docstring inherited.
420 # this is a placeholder for future extension, we enabled syntax but
421 # do not support actual use just yet.
422 raise NotImplementedError("POINT() function is not supported yet")
424 def visitRangeLiteral(self, start: int, stop: int, stride: int | None, node: Node) -> VisitorResult:
425 # Docstring inherited.
426 return ColumnContainer.range_literal(range(start, stop + 1, stride or 1))
428 def visitStringLiteral(self, value: str, node: Node) -> VisitorResult:
429 # Docstring inherited.
430 return ColumnExpression.literal(value, dtype=str)
432 def visitTimeLiteral(self, value: astropy.time.Time, node: Node) -> VisitorResult:
433 # Docstring inherited.
434 return ColumnExpression.literal(value, dtype=astropy.time.Time)
436 def visitTupleNode(self, items: tuple[VisitorResult, ...], node: Node) -> VisitorResult:
437 # Docstring inherited.
438 match items:
439 case [
440 ColumnLiteral(value=begin, dtype=astropy.time.Time | types.NoneType),
441 ColumnLiteral(value=end, dtype=astropy.time.Time | types.NoneType),
442 ]:
443 return ColumnExpression.literal(timespan.Timespan(begin, end), dtype=timespan.Timespan)
444 raise ExpressionTypeError(
445 f'Invalid type(s) ({items[0].dtype}, {items[1].dtype}) in timespan tuple "{node}" '
446 '(Note that date/time strings must be preceded by "T" to be recognized).'
447 )
449 def visitUnaryOp(self, operator: str, operand: VisitorResult, node: Node) -> VisitorResult:
450 # Docstring inherited.
451 match (operator, operand):
452 case ["NOT", Predicate() as operand]:
453 return operand.logical_not()
454 case ["+", ColumnExpression(dtype=builtins.int | builtins.float) as operand]:
455 return operand.method("__pos__")
456 case ["-", ColumnExpression(dtype=builtins.int | builtins.float) as operand]:
457 return operand.method("__neg__")
458 raise ExpressionTypeError(
459 f"Unary operator {operator!r} is not valid for operand of type {operand.dtype!s} in {node!s}."
460 )