Coverage for python/lsst/daf/butler/registry/queries/expressions/_predicate.py: 12%
207 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-10-02 08:00 +0000
« prev ^ index » next coverage.py v7.3.1, created at 2023-10-02 08:00 +0000
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
27from __future__ import annotations
29__all__ = ("make_string_expression_predicate", "ExpressionTypeError")
31import builtins
32import datetime
33import types
34import warnings
35from collections.abc import Mapping, Set
36from typing import Any, cast
38import astropy.time
39import astropy.utils.exceptions
40from lsst.daf.relation import (
41 ColumnContainer,
42 ColumnExpression,
43 ColumnExpressionSequence,
44 ColumnLiteral,
45 ColumnTag,
46 Predicate,
47 sql,
48)
50# We import the timespan module rather than types within it because match
51# syntax uses qualified names with periods to distinguish literals from
52# captures.
53from ....core import (
54 ColumnTypeInfo,
55 DataCoordinate,
56 DatasetColumnTag,
57 Dimension,
58 DimensionGraph,
59 DimensionKeyColumnTag,
60 DimensionRecordColumnTag,
61 DimensionUniverse,
62 timespan,
63)
64from ..._exceptions import UserExpressionError, UserExpressionSyntaxError
65from .categorize import ExpressionConstant, categorizeConstant, categorizeElementId
66from .check import CheckVisitor
67from .normalForm import NormalForm, NormalFormExpression
68from .parser import Node, ParserYacc, TreeVisitor # type: ignore
70# As of astropy 4.2, the erfa interface is shipped independently and
71# ErfaWarning is no longer an AstropyWarning
72try:
73 import erfa
74except ImportError:
75 erfa = None
class ExpressionTypeError(TypeError):
    """Exception raised for type mismatches in a user query expression.

    Raised when the operand types appearing in a string query expression
    cannot be combined with the operators or other syntax used there.
    """
def make_string_expression_predicate(
    string: str,
    dimensions: DimensionGraph,
    *,
    column_types: ColumnTypeInfo,
    bind: Mapping[str, Any] | None = None,
    data_id: DataCoordinate | None = None,
    defaults: DataCoordinate | None = None,
    dataset_type_name: str | None = None,
    allow_orphans: bool = False,
) -> tuple[Predicate | None, Mapping[str, Set[str]]]:
    """Create a predicate by parsing and analyzing a string expression.

    Parameters
    ----------
    string : `str`
        String to parse.
    dimensions : `DimensionGraph`
        The dimensions the query would include in the absence of this WHERE
        expression.
    column_types : `ColumnTypeInfo`
        Information about column types.
    bind : `~collections.abc.Mapping` [ `str`, `Any` ], optional
        Literal values referenced in the expression.
    data_id : `DataCoordinate`, optional
        A fully-expanded data ID identifying dimensions known in advance.
        If not provided, will be set to an empty data ID.
        ``dataId.hasRecords()`` must return `True`.
    defaults : `DataCoordinate`, optional
        A data ID containing defaults for governor dimensions.  Ignored
        unless ``check=True``.
    dataset_type_name : `str` or `None`, optional
        The name of the dataset type to assume for unqualified dataset
        columns, or `None` if there are no such identifiers.
    allow_orphans : `bool`, optional
        If `True`, permit expressions to refer to dimensions without
        providing a value for their governor dimensions (e.g. referring to
        a visit without an instrument).  Should be left to default to
        `False` in essentially all new code.

    Returns
    -------
    predicate : `lsst.daf.relation.column_expressions.Predicate` or `None`
        New predicate derived from the string expression, or `None` if the
        string is empty.
    governor_constraints : `~collections.abc.Mapping` [ `str` , \
            `~collections.abc.Set` ]
        Constraints on dimension values derived from the expression and data
        ID.
    """
    governor_constraints: dict[str, Set[str]] = {}
    if data_id is None:
        data_id = DataCoordinate.makeEmpty(dimensions.universe)
    if not string:
        # An empty expression constrains nothing beyond what the data ID
        # itself pins down for governor dimensions.
        for dimension in data_id.graph.governors:
            governor_constraints[dimension.name] = {cast(str, data_id[dimension])}
        return None, governor_constraints
    try:
        parser = ParserYacc()
        tree = parser.parse(string)
    except Exception as exc:
        raise UserExpressionSyntaxError(f"Failed to parse user expression {string!r}.") from exc
    if bind is None:
        bind = {}
    if bind:
        # Reject bind keys that would be ambiguous with dimension elements
        # or dimension-record columns in the expression language.
        for identifier in bind:
            if identifier in dimensions.universe.getStaticElements().names:
                raise RuntimeError(f"Bind parameter key {identifier!r} conflicts with a dimension element.")
            table, _, column = identifier.partition(".")
            if column and table in dimensions.universe.getStaticElements().names:
                raise RuntimeError(f"Bind parameter key {identifier!r} looks like a dimension column.")
    if defaults is None:
        defaults = DataCoordinate.makeEmpty(dimensions.universe)
    # Convert the expression to disjunctive normal form (ORs of ANDs).
    # That's potentially super expensive in the general case (where there's
    # a ton of nesting of ANDs and ORs).  That won't be the case for the
    # expressions we expect, and we actually use disjunctive normal instead
    # of conjunctive (i.e. ANDs of ORs) because I think the worst-case is
    # a long list of OR'd-together data IDs, which is already in or very
    # close to disjunctive normal form.
    expr = NormalFormExpression.fromTree(tree, NormalForm.DISJUNCTIVE)
    # Check the expression for consistency and completeness.
    visitor = CheckVisitor(data_id, dimensions, bind, defaults, allow_orphans=allow_orphans)
    try:
        summary = expr.visit(visitor)
    except UserExpressionError as err:
        # Report against the original expression; mention the normalized
        # form only when it differs, to avoid confusing the user.
        exprOriginal = str(tree)
        exprNormal = str(expr.toTree())
        if exprNormal == exprOriginal:
            msg = f'Error in query expression "{exprOriginal}": {err}'
        else:
            msg = f'Error in query expression "{exprOriginal}" (normalized to "{exprNormal}"): {err}'
        raise UserExpressionError(msg) from None
    for dimension_name, values in summary.dimension_constraints.items():
        if dimension_name in dimensions.universe.getGovernorDimensions().names:
            governor_constraints[dimension_name] = cast(Set[str], values)
    converter = PredicateConversionVisitor(bind, dataset_type_name, dimensions.universe, column_types)
    predicate = tree.visit(converter)
    return predicate, governor_constraints
185VisitorResult = Predicate | ColumnExpression | ColumnContainer
188class PredicateConversionVisitor(TreeVisitor[VisitorResult]):
189 def __init__(
190 self,
191 bind: Mapping[str, Any],
192 dataset_type_name: str | None,
193 universe: DimensionUniverse,
194 column_types: ColumnTypeInfo,
195 ):
196 self.bind = bind
197 self.dataset_type_name = dataset_type_name
198 self.universe = universe
199 self.column_types = column_types
201 OPERATOR_MAP = {
202 "=": "__eq__",
203 "!=": "__ne__",
204 "<": "__lt__",
205 ">": "__gt__",
206 "<=": "__le__",
207 ">=": "__ge__",
208 "+": "__add__",
209 "-": "__sub__",
210 "/": "__mul__",
211 }
213 def to_datetime(self, time: astropy.time.Time) -> datetime.datetime:
214 with warnings.catch_warnings():
215 warnings.simplefilter("ignore", category=astropy.utils.exceptions.AstropyWarning)
216 if erfa is not None:
217 warnings.simplefilter("ignore", category=erfa.ErfaWarning)
218 return time.to_datetime()
220 def visitBinaryOp(
221 self, operator: str, lhs: VisitorResult, rhs: VisitorResult, node: Node
222 ) -> VisitorResult:
223 # Docstring inherited.
224 b = builtins
225 match (operator, lhs, rhs):
226 case ["OR", Predicate() as lhs, Predicate() as rhs]:
227 return lhs.logical_or(rhs)
228 case ["AND", Predicate() as lhs, Predicate() as rhs]:
229 return lhs.logical_and(rhs)
230 # Allow all comparisons between expressions of the same type for
231 # sortable types.
232 case [
233 "=" | "!=" | "<" | ">" | "<=" | ">=",
234 ColumnExpression(
235 dtype=b.int | b.float | b.str | astropy.time.Time | datetime.datetime
236 ) as lhs,
237 ColumnExpression() as rhs,
238 ] if lhs.dtype is rhs.dtype:
239 return lhs.predicate_method(self.OPERATOR_MAP[operator], rhs)
240 # Allow comparisons between datetime expressions and
241 # astropy.time.Time literals/binds (only), by coercing the
242 # astropy.time.Time version to datetime.
243 case [
244 "=" | "!=" | "<" | ">" | "<=" | ">=",
245 ColumnLiteral(dtype=astropy.time.Time) as lhs,
246 ColumnExpression(dtype=datetime.datetime) as rhs,
247 ]:
248 lhs = ColumnLiteral(self.to_datetime(lhs.value), datetime.datetime)
249 return lhs.predicate_method(self.OPERATOR_MAP[operator], rhs)
250 case [
251 "=" | "!=" | "<" | ">" | "<=" | ">=",
252 ColumnExpression(dtype=datetime.datetime) as lhs,
253 ColumnLiteral(dtype=astropy.time.Time) as rhs,
254 ]:
255 rhs = ColumnLiteral(self.to_datetime(rhs.value), datetime.datetime)
256 return lhs.predicate_method(self.OPERATOR_MAP[operator], rhs)
257 # Allow comparisons between astropy.time.Time expressions and
258 # datetime literals/binds, by coercing the
259 # datetime literals to astropy.time.Time (in UTC scale).
260 case [
261 "=" | "!=" | "<" | ">" | "<=" | ">=",
262 ColumnLiteral(dtype=datetime.datetime) as lhs,
263 ColumnExpression(dtype=astropy.time.Time) as rhs,
264 ]:
265 lhs = ColumnLiteral(astropy.time.Time(lhs.value, scale="utc"), astropy.time.Time)
266 return lhs.predicate_method(self.OPERATOR_MAP[operator], rhs)
267 case [
268 "=" | "!=" | "<" | ">" | "<=" | ">=",
269 ColumnExpression(dtype=astropy.time.Time) as lhs,
270 ColumnLiteral(dtype=datetime.datetime) as rhs,
271 ]:
272 rhs = ColumnLiteral(astropy.time.Time(rhs.value, scale="utc"), astropy.time.Time)
273 return lhs.predicate_method(self.OPERATOR_MAP[operator], rhs)
274 # Allow equality comparisons with None/NULL. We don't have an 'IS'
275 # operator.
276 case ["=" | "!=", ColumnExpression(dtype=types.NoneType) as lhs, ColumnExpression() as rhs]:
277 return lhs.predicate_method(self.OPERATOR_MAP[operator], rhs)
278 case ["=" | "!=", ColumnExpression() as lhs, ColumnExpression(dtype=types.NoneType) as rhs]:
279 return lhs.predicate_method(self.OPERATOR_MAP[operator], rhs)
280 # Comparisions between Time and Timespan need have the Timespan on
281 # the lhs, since that (actually TimespanDatabaseRepresentation) is
282 # what actually has the methods.
283 case [
284 "<",
285 ColumnExpression(dtype=astropy.time.Time) as lhs,
286 ColumnExpression(dtype=timespan.Timespan) as rhs,
287 ]:
288 return rhs.predicate_method(self.OPERATOR_MAP[">"], lhs)
289 case [
290 ">",
291 ColumnExpression(dtype=astropy.time.Time) as lhs,
292 ColumnExpression(dtype=timespan.Timespan) as rhs,
293 ]:
294 return rhs.predicate_method(self.OPERATOR_MAP["<"], lhs)
295 # Enable other comparisons between times and Timespans (many of the
296 # combinations matched by this branch will have already been
297 # covered by a preceding branch).
298 case [
299 "<" | ">",
300 ColumnExpression(dtype=timespan.Timespan | astropy.time.Time) as lhs,
301 ColumnExpression(dtype=timespan.Timespan | astropy.time.Time) as rhs,
302 ]:
303 return lhs.predicate_method(self.OPERATOR_MAP[operator], rhs)
304 # Enable "overlaps" operations between timespans, and between times
305 # and timespans. The latter resolve to the `Timespan.contains` or
306 # `TimespanDatabaseRepresentation.contains` methods, but we use
307 # OVERLAPS in the string expression language to keep that simple.
308 case [
309 "OVERLAPS",
310 ColumnExpression(dtype=timespan.Timespan) as lhs,
311 ColumnExpression(dtype=timespan.Timespan) as rhs,
312 ]:
313 return lhs.predicate_method("overlaps", rhs)
314 case [
315 "OVERLAPS",
316 ColumnExpression(dtype=timespan.Timespan) as lhs,
317 ColumnExpression(dtype=astropy.time.Time) as rhs,
318 ]:
319 return lhs.predicate_method("overlaps", rhs)
320 case [
321 "OVERLAPS",
322 ColumnExpression(dtype=astropy.time.Time) as lhs,
323 ColumnExpression(dtype=timespan.Timespan) as rhs,
324 ]:
325 return rhs.predicate_method("overlaps", lhs)
326 # Enable arithmetic operators on numeric types, without any type
327 # coercion or broadening.
328 case [
329 "+" | "-" | "*",
330 ColumnExpression(dtype=b.int | b.float) as lhs,
331 ColumnExpression() as rhs,
332 ] if lhs.dtype is rhs.dtype:
333 return lhs.method(self.OPERATOR_MAP[operator], rhs, dtype=lhs.dtype)
334 case ["/", ColumnExpression(dtype=b.float) as lhs, ColumnExpression(dtype=b.float) as rhs]:
335 return lhs.method("__truediv__", rhs, dtype=b.float)
336 case ["/", ColumnExpression(dtype=b.int) as lhs, ColumnExpression(dtype=b.int) as rhs]:
337 # SQLAlchemy maps Python's '/' (__truediv__) operator directly
338 # to SQL's '/', despite those being defined differently for
339 # integers. Our expression language uses the SQL definition,
340 # and we only care about these expressions being evaluated in
341 # SQL right now, but we still want to guard against it being
342 # evaluated in Python and producing a surprising answer, so we
343 # mark it as being supported only by a SQL engine.
344 return lhs.method(
345 "__truediv__",
346 rhs,
347 dtype=b.int,
348 supporting_engine_types={sql.Engine},
349 )
350 case ["%", ColumnExpression(dtype=b.int) as lhs, ColumnExpression(dtype=b.int) as rhs]:
351 return lhs.method("__mod__", rhs, dtype=b.int)
352 assert (
353 lhs.dtype is not None and rhs.dtype is not None
354 ), "Expression converter should not yield untyped nodes."
355 raise ExpressionTypeError(
356 f"Invalid types {lhs.dtype.__name__}, {rhs.dtype.__name__} for binary operator {operator!r} "
357 f"in expression {node!s}."
358 )
360 def visitIdentifier(self, name: str, node: Node) -> VisitorResult:
361 # Docstring inherited.
362 if name in self.bind:
363 value = self.bind[name]
364 if isinstance(value, list | tuple | Set):
365 elements = []
366 all_dtypes = set()
367 for item in value:
368 dtype = type(item)
369 all_dtypes.add(dtype)
370 elements.append(ColumnExpression.literal(item, dtype=dtype))
371 if len(all_dtypes) > 1:
372 raise ExpressionTypeError(
373 f"Mismatched types in bind iterable: {value} has a mix of {all_dtypes}."
374 )
375 elif not elements:
376 # Empty container
377 return ColumnContainer.sequence([])
378 else:
379 (dtype,) = all_dtypes
380 return ColumnContainer.sequence(elements, dtype=dtype)
381 return ColumnExpression.literal(value, dtype=type(value))
382 tag: ColumnTag
383 match categorizeConstant(name):
384 case ExpressionConstant.INGEST_DATE:
385 assert self.dataset_type_name is not None
386 tag = DatasetColumnTag(self.dataset_type_name, "ingest_date")
387 return ColumnExpression.reference(tag, self.column_types.ingest_date_pytype)
388 case ExpressionConstant.NULL:
389 return ColumnExpression.literal(None, type(None))
390 case None:
391 pass
392 case _:
393 raise AssertionError("Check for enum values should be exhaustive.")
394 element, column = categorizeElementId(self.universe, name)
395 if column is not None:
396 tag = DimensionRecordColumnTag(element.name, column)
397 dtype = (
398 timespan.Timespan
399 if column == timespan.TimespanDatabaseRepresentation.NAME
400 else element.RecordClass.fields.standard[column].getPythonType()
401 )
402 return ColumnExpression.reference(tag, dtype)
403 else:
404 tag = DimensionKeyColumnTag(element.name)
405 assert isinstance(element, Dimension)
406 return ColumnExpression.reference(tag, element.primaryKey.getPythonType())
408 def visitIsIn(
409 self, lhs: VisitorResult, values: list[VisitorResult], not_in: bool, node: Node
410 ) -> VisitorResult:
411 # Docstring inherited.
412 clauses: list[Predicate] = []
413 items: list[ColumnExpression] = []
414 assert isinstance(lhs, ColumnExpression), "LHS of IN guaranteed to be scalar by parser."
415 for rhs_item in values:
416 match rhs_item:
417 case ColumnExpressionSequence(
418 items=rhs_items, dtype=rhs_dtype
419 ) if rhs_dtype is None or rhs_dtype == lhs.dtype:
420 items.extend(rhs_items)
421 case ColumnContainer(dtype=lhs.dtype):
422 clauses.append(rhs_item.contains(lhs))
423 case ColumnExpression(dtype=lhs.dtype):
424 items.append(rhs_item)
425 case _:
426 raise ExpressionTypeError(
427 f"Invalid type {rhs_item.dtype} for element in {lhs.dtype} IN expression '{node}'."
428 )
429 if items:
430 clauses.append(ColumnContainer.sequence(items, dtype=lhs.dtype).contains(lhs))
431 result = Predicate.logical_or(*clauses)
432 if not_in:
433 result = result.logical_not()
434 return result
436 def visitNumericLiteral(self, value: str, node: Node) -> VisitorResult:
437 # Docstring inherited.
438 try:
439 return ColumnExpression.literal(int(value), dtype=int)
440 except ValueError:
441 return ColumnExpression.literal(float(value), dtype=float)
443 def visitParens(self, expression: VisitorResult, node: Node) -> VisitorResult:
444 # Docstring inherited.
445 return expression
447 def visitPointNode(self, ra: VisitorResult, dec: VisitorResult, node: Node) -> VisitorResult:
448 # Docstring inherited.
450 # this is a placeholder for future extension, we enabled syntax but
451 # do not support actual use just yet.
452 raise NotImplementedError("POINT() function is not supported yet")
454 def visitRangeLiteral(self, start: int, stop: int, stride: int | None, node: Node) -> VisitorResult:
455 # Docstring inherited.
456 return ColumnContainer.range_literal(range(start, stop + 1, stride or 1))
458 def visitStringLiteral(self, value: str, node: Node) -> VisitorResult:
459 # Docstring inherited.
460 return ColumnExpression.literal(value, dtype=str)
462 def visitTimeLiteral(self, value: astropy.time.Time, node: Node) -> VisitorResult:
463 # Docstring inherited.
464 return ColumnExpression.literal(value, dtype=astropy.time.Time)
466 def visitTupleNode(self, items: tuple[VisitorResult, ...], node: Node) -> VisitorResult:
467 # Docstring inherited.
468 match items:
469 case [
470 ColumnLiteral(value=begin, dtype=astropy.time.Time | types.NoneType),
471 ColumnLiteral(value=end, dtype=astropy.time.Time | types.NoneType),
472 ]:
473 return ColumnExpression.literal(timespan.Timespan(begin, end), dtype=timespan.Timespan)
474 raise ExpressionTypeError(
475 f'Invalid type(s) ({items[0].dtype}, {items[1].dtype}) in timespan tuple "{node}" '
476 '(Note that date/time strings must be preceded by "T" to be recognized).'
477 )
479 def visitUnaryOp(self, operator: str, operand: VisitorResult, node: Node) -> VisitorResult:
480 # Docstring inherited.
481 match (operator, operand):
482 case ["NOT", Predicate() as operand]:
483 return operand.logical_not()
484 case ["+", ColumnExpression(dtype=builtins.int | builtins.float) as operand]:
485 return operand.method("__pos__")
486 case ["-", ColumnExpression(dtype=builtins.int | builtins.float) as operand]:
487 return operand.method("__neg__")
488 raise ExpressionTypeError(
489 f"Unary operator {operator!r} is not valid for operand of type {operand.dtype!s} in {node!s}."
490 )