Coverage for python/lsst/daf/butler/registry/queries/expressions/_predicate.py: 12%
207 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-07-12 10:56 -0700
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ("make_string_expression_predicate", "ExpressionTypeError")
25import builtins
26import datetime
27import types
28import warnings
29from collections.abc import Mapping, Set
30from typing import Any, cast
32import astropy.time
33import astropy.utils.exceptions
34from lsst.daf.relation import (
35 ColumnContainer,
36 ColumnExpression,
37 ColumnExpressionSequence,
38 ColumnLiteral,
39 ColumnTag,
40 Predicate,
41 sql,
42)
44# We import the timespan module rather than types within it because match
45# syntax uses qualified names with periods to distinguish literals from
46# captures.
47from ....core import (
48 ColumnTypeInfo,
49 DataCoordinate,
50 DatasetColumnTag,
51 Dimension,
52 DimensionGraph,
53 DimensionKeyColumnTag,
54 DimensionRecordColumnTag,
55 DimensionUniverse,
56 timespan,
57)
58from ..._exceptions import UserExpressionError, UserExpressionSyntaxError
59from .categorize import ExpressionConstant, categorizeConstant, categorizeElementId
60from .check import CheckVisitor
61from .normalForm import NormalForm, NormalFormExpression
62from .parser import Node, ParserYacc, TreeVisitor # type: ignore
64# As of astropy 4.2, the erfa interface is shipped independently and
65# ErfaWarning is no longer an AstropyWarning
66try:
67 import erfa
68except ImportError:
69 erfa = None
class ExpressionTypeError(TypeError):
    """Error raised when operand types in a user query expression do not
    match what an operator (or other piece of expression syntax) accepts.
    """
def make_string_expression_predicate(
    string: str,
    dimensions: DimensionGraph,
    *,
    column_types: ColumnTypeInfo,
    bind: Mapping[str, Any] | None = None,
    data_id: DataCoordinate | None = None,
    defaults: DataCoordinate | None = None,
    dataset_type_name: str | None = None,
    allow_orphans: bool = False,
) -> tuple[Predicate | None, Mapping[str, Set[str]]]:
    """Create a predicate by parsing and analyzing a string expression.

    Parameters
    ----------
    string : `str`
        String to parse.
    dimensions : `DimensionGraph`
        The dimensions the query would include in the absence of this WHERE
        expression.
    column_types : `ColumnTypeInfo`
        Information about column types.
    bind : `~collections.abc.Mapping` [ `str`, `Any` ], optional
        Literal values referenced in the expression.
    data_id : `DataCoordinate`, optional
        A fully-expanded data ID identifying dimensions known in advance.
        If not provided, will be set to an empty data ID.
        ``dataId.hasRecords()`` must return `True`.
    defaults : `DataCoordinate`, optional
        A data ID containing defaults for governor dimensions.
    dataset_type_name : `str` or `None`, optional
        The name of the dataset type to assume for unqualified dataset
        columns, or `None` if there are no such identifiers.
    allow_orphans : `bool`, optional
        If `True`, permit expressions to refer to dimensions without
        providing a value for their governor dimensions (e.g. referring to
        a visit without an instrument). Should be left to default to
        `False` in essentially all new code.

    Returns
    -------
    predicate : `lsst.daf.relation.column_expressions.Predicate` or `None`
        New predicate derived from the string expression, or `None` if the
        string is empty.
    governor_constraints : `~collections.abc.Mapping` [ `str` , \
            `~collections.abc.Set` ]
        Constraints on dimension values derived from the expression and data
        ID.
    """
    governor_constraints: dict[str, Set[str]] = {}
    if data_id is None:
        data_id = DataCoordinate.makeEmpty(dimensions.universe)
    if not string:
        # No expression: the only constraints come from the governor values
        # already present in the given data ID.
        for dimension in data_id.graph.governors:
            governor_constraints[dimension.name] = {cast(str, data_id[dimension])}
        return None, governor_constraints
    try:
        parser = ParserYacc()
        tree = parser.parse(string)
    except Exception as exc:
        # Wrap any parser failure (yacc errors are not a stable hierarchy)
        # in our public syntax-error type, keeping the cause chained.
        raise UserExpressionSyntaxError(f"Failed to parse user expression {string!r}.") from exc
    if bind is None:
        bind = {}
    if bind:
        # Reject bind keys that would shadow dimension elements or look
        # like dimension-record columns, since identifier lookup would be
        # ambiguous otherwise.
        for identifier in bind:
            if identifier in dimensions.universe.getStaticElements().names:
                raise RuntimeError(f"Bind parameter key {identifier!r} conflicts with a dimension element.")
            table, _, column = identifier.partition(".")
            if column and table in dimensions.universe.getStaticElements().names:
                raise RuntimeError(f"Bind parameter key {identifier!r} looks like a dimension column.")
    if defaults is None:
        defaults = DataCoordinate.makeEmpty(dimensions.universe)
    # Convert the expression to disjunctive normal form (ORs of ANDs).
    # That's potentially super expensive in the general case (where there's
    # a ton of nesting of ANDs and ORs).  That won't be the case for the
    # expressions we expect, and we actually use disjunctive normal instead
    # of conjunctive (i.e. ANDs of ORs) because I think the worst-case is
    # a long list of OR'd-together data IDs, which is already in or very
    # close to disjunctive normal form.
    expr = NormalFormExpression.fromTree(tree, NormalForm.DISJUNCTIVE)
    # Check the expression for consistency and completeness.
    visitor = CheckVisitor(data_id, dimensions, bind, defaults, allow_orphans=allow_orphans)
    try:
        summary = expr.visit(visitor)
    except UserExpressionError as err:
        # Report the expression both as written and (when different) as
        # normalized, so users can see what was actually analyzed.
        exprOriginal = str(tree)
        exprNormal = str(expr.toTree())
        if exprNormal == exprOriginal:
            msg = f'Error in query expression "{exprOriginal}": {err}'
        else:
            msg = f'Error in query expression "{exprOriginal}" (normalized to "{exprNormal}"): {err}'
        raise UserExpressionError(msg) from None
    # Keep only the constraints on governor dimensions; constraints on
    # other dimensions are not returned to the caller.
    for dimension_name, values in summary.dimension_constraints.items():
        if dimension_name in dimensions.universe.getGovernorDimensions().names:
            governor_constraints[dimension_name] = cast(Set[str], values)
    # Convert the original (non-normalized) tree to a relation Predicate.
    converter = PredicateConversionVisitor(bind, dataset_type_name, dimensions.universe, column_types)
    predicate = tree.visit(converter)
    return predicate, governor_constraints
# Union of everything a PredicateConversionVisitor method may return: a
# complete boolean Predicate, a scalar column expression, or a container
# (sequence/range) expression used on the right-hand side of IN.
VisitorResult = Predicate | ColumnExpression | ColumnContainer
182class PredicateConversionVisitor(TreeVisitor[VisitorResult]):
183 def __init__(
184 self,
185 bind: Mapping[str, Any],
186 dataset_type_name: str | None,
187 universe: DimensionUniverse,
188 column_types: ColumnTypeInfo,
189 ):
190 self.bind = bind
191 self.dataset_type_name = dataset_type_name
192 self.universe = universe
193 self.column_types = column_types
195 OPERATOR_MAP = {
196 "=": "__eq__",
197 "!=": "__ne__",
198 "<": "__lt__",
199 ">": "__gt__",
200 "<=": "__le__",
201 ">=": "__ge__",
202 "+": "__add__",
203 "-": "__sub__",
204 "/": "__mul__",
205 }
207 def to_datetime(self, time: astropy.time.Time) -> datetime.datetime:
208 with warnings.catch_warnings():
209 warnings.simplefilter("ignore", category=astropy.utils.exceptions.AstropyWarning)
210 if erfa is not None:
211 warnings.simplefilter("ignore", category=erfa.ErfaWarning)
212 return time.to_datetime()
214 def visitBinaryOp(
215 self, operator: str, lhs: VisitorResult, rhs: VisitorResult, node: Node
216 ) -> VisitorResult:
217 # Docstring inherited.
218 b = builtins
219 match (operator, lhs, rhs):
220 case ["OR", Predicate() as lhs, Predicate() as rhs]:
221 return lhs.logical_or(rhs)
222 case ["AND", Predicate() as lhs, Predicate() as rhs]:
223 return lhs.logical_and(rhs)
224 # Allow all comparisons between expressions of the same type for
225 # sortable types.
226 case [
227 "=" | "!=" | "<" | ">" | "<=" | ">=",
228 ColumnExpression(
229 dtype=b.int | b.float | b.str | astropy.time.Time | datetime.datetime
230 ) as lhs,
231 ColumnExpression() as rhs,
232 ] if lhs.dtype is rhs.dtype:
233 return lhs.predicate_method(self.OPERATOR_MAP[operator], rhs)
234 # Allow comparisons between datetime expressions and
235 # astropy.time.Time literals/binds (only), by coercing the
236 # astropy.time.Time version to datetime.
237 case [
238 "=" | "!=" | "<" | ">" | "<=" | ">=",
239 ColumnLiteral(dtype=astropy.time.Time) as lhs,
240 ColumnExpression(dtype=datetime.datetime) as rhs,
241 ]:
242 lhs = ColumnLiteral(self.to_datetime(lhs.value), datetime.datetime)
243 return lhs.predicate_method(self.OPERATOR_MAP[operator], rhs)
244 case [
245 "=" | "!=" | "<" | ">" | "<=" | ">=",
246 ColumnExpression(dtype=datetime.datetime) as lhs,
247 ColumnLiteral(dtype=astropy.time.Time) as rhs,
248 ]:
249 rhs = ColumnLiteral(self.to_datetime(rhs.value), datetime.datetime)
250 return lhs.predicate_method(self.OPERATOR_MAP[operator], rhs)
251 # Allow comparisons between astropy.time.Time expressions and
252 # datetime literals/binds, by coercing the
253 # datetime literals to astropy.time.Time (in UTC scale).
254 case [
255 "=" | "!=" | "<" | ">" | "<=" | ">=",
256 ColumnLiteral(dtype=datetime.datetime) as lhs,
257 ColumnExpression(dtype=astropy.time.Time) as rhs,
258 ]:
259 lhs = ColumnLiteral(astropy.time.Time(lhs.value, scale="utc"), astropy.time.Time)
260 return lhs.predicate_method(self.OPERATOR_MAP[operator], rhs)
261 case [
262 "=" | "!=" | "<" | ">" | "<=" | ">=",
263 ColumnExpression(dtype=astropy.time.Time) as lhs,
264 ColumnLiteral(dtype=datetime.datetime) as rhs,
265 ]:
266 rhs = ColumnLiteral(astropy.time.Time(rhs.value, scale="utc"), astropy.time.Time)
267 return lhs.predicate_method(self.OPERATOR_MAP[operator], rhs)
268 # Allow equality comparisons with None/NULL. We don't have an 'IS'
269 # operator.
270 case ["=" | "!=", ColumnExpression(dtype=types.NoneType) as lhs, ColumnExpression() as rhs]:
271 return lhs.predicate_method(self.OPERATOR_MAP[operator], rhs)
272 case ["=" | "!=", ColumnExpression() as lhs, ColumnExpression(dtype=types.NoneType) as rhs]:
273 return lhs.predicate_method(self.OPERATOR_MAP[operator], rhs)
274 # Comparisions between Time and Timespan need have the Timespan on
275 # the lhs, since that (actually TimespanDatabaseRepresentation) is
276 # what actually has the methods.
277 case [
278 "<",
279 ColumnExpression(dtype=astropy.time.Time) as lhs,
280 ColumnExpression(dtype=timespan.Timespan) as rhs,
281 ]:
282 return rhs.predicate_method(self.OPERATOR_MAP[">"], lhs)
283 case [
284 ">",
285 ColumnExpression(dtype=astropy.time.Time) as lhs,
286 ColumnExpression(dtype=timespan.Timespan) as rhs,
287 ]:
288 return rhs.predicate_method(self.OPERATOR_MAP["<"], lhs)
289 # Enable other comparisons between times and Timespans (many of the
290 # combinations matched by this branch will have already been
291 # covered by a preceding branch).
292 case [
293 "<" | ">",
294 ColumnExpression(dtype=timespan.Timespan | astropy.time.Time) as lhs,
295 ColumnExpression(dtype=timespan.Timespan | astropy.time.Time) as rhs,
296 ]:
297 return lhs.predicate_method(self.OPERATOR_MAP[operator], rhs)
298 # Enable "overlaps" operations between timespans, and between times
299 # and timespans. The latter resolve to the `Timespan.contains` or
300 # `TimespanDatabaseRepresentation.contains` methods, but we use
301 # OVERLAPS in the string expression language to keep that simple.
302 case [
303 "OVERLAPS",
304 ColumnExpression(dtype=timespan.Timespan) as lhs,
305 ColumnExpression(dtype=timespan.Timespan) as rhs,
306 ]:
307 return lhs.predicate_method("overlaps", rhs)
308 case [
309 "OVERLAPS",
310 ColumnExpression(dtype=timespan.Timespan) as lhs,
311 ColumnExpression(dtype=astropy.time.Time) as rhs,
312 ]:
313 return lhs.predicate_method("overlaps", rhs)
314 case [
315 "OVERLAPS",
316 ColumnExpression(dtype=astropy.time.Time) as lhs,
317 ColumnExpression(dtype=timespan.Timespan) as rhs,
318 ]:
319 return rhs.predicate_method("overlaps", lhs)
320 # Enable arithmetic operators on numeric types, without any type
321 # coercion or broadening.
322 case [
323 "+" | "-" | "*",
324 ColumnExpression(dtype=b.int | b.float) as lhs,
325 ColumnExpression() as rhs,
326 ] if lhs.dtype is rhs.dtype:
327 return lhs.method(self.OPERATOR_MAP[operator], rhs, dtype=lhs.dtype)
328 case ["/", ColumnExpression(dtype=b.float) as lhs, ColumnExpression(dtype=b.float) as rhs]:
329 return lhs.method("__truediv__", rhs, dtype=b.float)
330 case ["/", ColumnExpression(dtype=b.int) as lhs, ColumnExpression(dtype=b.int) as rhs]:
331 # SQLAlchemy maps Python's '/' (__truediv__) operator directly
332 # to SQL's '/', despite those being defined differently for
333 # integers. Our expression language uses the SQL definition,
334 # and we only care about these expressions being evaluated in
335 # SQL right now, but we still want to guard against it being
336 # evaluated in Python and producing a surprising answer, so we
337 # mark it as being supported only by a SQL engine.
338 return lhs.method(
339 "__truediv__",
340 rhs,
341 dtype=b.int,
342 supporting_engine_types={sql.Engine},
343 )
344 case ["%", ColumnExpression(dtype=b.int) as lhs, ColumnExpression(dtype=b.int) as rhs]:
345 return lhs.method("__mod__", rhs, dtype=b.int)
346 assert (
347 lhs.dtype is not None and rhs.dtype is not None
348 ), "Expression converter should not yield untyped nodes."
349 raise ExpressionTypeError(
350 f"Invalid types {lhs.dtype.__name__}, {rhs.dtype.__name__} for binary operator {operator!r} "
351 f"in expression {node!s}."
352 )
354 def visitIdentifier(self, name: str, node: Node) -> VisitorResult:
355 # Docstring inherited.
356 if name in self.bind:
357 value = self.bind[name]
358 if isinstance(value, (list, tuple, Set)):
359 elements = []
360 all_dtypes = set()
361 for item in value:
362 dtype = type(item)
363 all_dtypes.add(dtype)
364 elements.append(ColumnExpression.literal(item, dtype=dtype))
365 if len(all_dtypes) > 1:
366 raise ExpressionTypeError(
367 f"Mismatched types in bind iterable: {value} has a mix of {all_dtypes}."
368 )
369 elif not elements:
370 # Empty container
371 return ColumnContainer.sequence([])
372 else:
373 (dtype,) = all_dtypes
374 return ColumnContainer.sequence(elements, dtype=dtype)
375 return ColumnExpression.literal(value, dtype=type(value))
376 tag: ColumnTag
377 match categorizeConstant(name):
378 case ExpressionConstant.INGEST_DATE:
379 assert self.dataset_type_name is not None
380 tag = DatasetColumnTag(self.dataset_type_name, "ingest_date")
381 return ColumnExpression.reference(tag, self.column_types.ingest_date_pytype)
382 case ExpressionConstant.NULL:
383 return ColumnExpression.literal(None, type(None))
384 case None:
385 pass
386 case _:
387 raise AssertionError("Check for enum values should be exhaustive.")
388 element, column = categorizeElementId(self.universe, name)
389 if column is not None:
390 tag = DimensionRecordColumnTag(element.name, column)
391 dtype = (
392 timespan.Timespan
393 if column == timespan.TimespanDatabaseRepresentation.NAME
394 else element.RecordClass.fields.standard[column].getPythonType()
395 )
396 return ColumnExpression.reference(tag, dtype)
397 else:
398 tag = DimensionKeyColumnTag(element.name)
399 assert isinstance(element, Dimension)
400 return ColumnExpression.reference(tag, element.primaryKey.getPythonType())
402 def visitIsIn(
403 self, lhs: VisitorResult, values: list[VisitorResult], not_in: bool, node: Node
404 ) -> VisitorResult:
405 # Docstring inherited.
406 clauses: list[Predicate] = []
407 items: list[ColumnExpression] = []
408 assert isinstance(lhs, ColumnExpression), "LHS of IN guaranteed to be scalar by parser."
409 for rhs_item in values:
410 match rhs_item:
411 case ColumnExpressionSequence(
412 items=rhs_items, dtype=rhs_dtype
413 ) if rhs_dtype is None or rhs_dtype == lhs.dtype:
414 items.extend(rhs_items)
415 case ColumnContainer(dtype=lhs.dtype):
416 clauses.append(rhs_item.contains(lhs))
417 case ColumnExpression(dtype=lhs.dtype):
418 items.append(rhs_item)
419 case _:
420 raise ExpressionTypeError(
421 f"Invalid type {rhs_item.dtype} for element in {lhs.dtype} IN expression '{node}'."
422 )
423 if items:
424 clauses.append(ColumnContainer.sequence(items, dtype=lhs.dtype).contains(lhs))
425 result = Predicate.logical_or(*clauses)
426 if not_in:
427 result = result.logical_not()
428 return result
430 def visitNumericLiteral(self, value: str, node: Node) -> VisitorResult:
431 # Docstring inherited.
432 try:
433 return ColumnExpression.literal(int(value), dtype=int)
434 except ValueError:
435 return ColumnExpression.literal(float(value), dtype=float)
437 def visitParens(self, expression: VisitorResult, node: Node) -> VisitorResult:
438 # Docstring inherited.
439 return expression
441 def visitPointNode(self, ra: VisitorResult, dec: VisitorResult, node: Node) -> VisitorResult:
442 # Docstring inherited.
444 # this is a placeholder for future extension, we enabled syntax but
445 # do not support actual use just yet.
446 raise NotImplementedError("POINT() function is not supported yet")
448 def visitRangeLiteral(self, start: int, stop: int, stride: int | None, node: Node) -> VisitorResult:
449 # Docstring inherited.
450 return ColumnContainer.range_literal(range(start, stop + 1, stride or 1))
452 def visitStringLiteral(self, value: str, node: Node) -> VisitorResult:
453 # Docstring inherited.
454 return ColumnExpression.literal(value, dtype=str)
456 def visitTimeLiteral(self, value: astropy.time.Time, node: Node) -> VisitorResult:
457 # Docstring inherited.
458 return ColumnExpression.literal(value, dtype=astropy.time.Time)
460 def visitTupleNode(self, items: tuple[VisitorResult, ...], node: Node) -> VisitorResult:
461 # Docstring inherited.
462 match items:
463 case [
464 ColumnLiteral(value=begin, dtype=astropy.time.Time | types.NoneType),
465 ColumnLiteral(value=end, dtype=astropy.time.Time | types.NoneType),
466 ]:
467 return ColumnExpression.literal(timespan.Timespan(begin, end), dtype=timespan.Timespan)
468 raise ExpressionTypeError(
469 f'Invalid type(s) ({items[0].dtype}, {items[1].dtype}) in timespan tuple "{node}" '
470 '(Note that date/time strings must be preceded by "T" to be recognized).'
471 )
473 def visitUnaryOp(self, operator: str, operand: VisitorResult, node: Node) -> VisitorResult:
474 # Docstring inherited.
475 match (operator, operand):
476 case ["NOT", Predicate() as operand]:
477 return operand.logical_not()
478 case ["+", ColumnExpression(dtype=builtins.int | builtins.float) as operand]:
479 return operand.method("__pos__")
480 case ["-", ColumnExpression(dtype=builtins.int | builtins.float) as operand]:
481 return operand.method("__neg__")
482 raise ExpressionTypeError(
483 f"Unary operator {operator!r} is not valid for operand of type {operand.dtype!s} in {node!s}."
484 )