Coverage for python/lsst/daf/butler/queries/tree/_predicate.py: 51%
237 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-27 03:00 -0700
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-27 03:00 -0700
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
28from __future__ import annotations
# Names exported by this module; the leaf classes themselves are left out.
__all__ = (
    "Predicate",
    "PredicateLeaf",
    "LogicalNotOperand",
    "PredicateOperands",
    "ComparisonOperator",
)
38import itertools
39from abc import ABC, abstractmethod
40from collections.abc import Iterable
41from typing import TYPE_CHECKING, Annotated, Literal, TypeAlias, TypeVar, cast, final
43import pydantic
45from ..._exceptions import InvalidQueryError
46from ._base import QueryTreeBase
47from ._column_expression import ColumnExpression
49if TYPE_CHECKING:
50 from ..visitors import PredicateVisitFlags, PredicateVisitor
51 from ._column_set import ColumnSet
52 from ._query_tree import QueryTree
54ComparisonOperator: TypeAlias = Literal["==", "!=", "<", ">", ">=", "<=", "overlaps"]
57_L = TypeVar("_L")
58_A = TypeVar("_A")
59_O = TypeVar("_O")
class PredicateLeafBase(QueryTreeBase, ABC):
    """Abstract base for the leaf nodes of a `Predicate` tree.

    The hierarchy is closed: its concrete, `~typing.final` subclasses are
    exactly the members of the `PredicateLeaf` union.  Prefer that union in
    type annotations over this technically-open base class.
    """

    @abstractmethod
    def gather_required_columns(self, columns: ColumnSet) -> None:
        """Add the columns needed to evaluate this leaf to a column set.

        Parameters
        ----------
        columns : `ColumnSet`
            Set of columns to modify in place.
        """
        raise NotImplementedError()

    def invert(self) -> PredicateLeaf:
        """Return a new leaf that is the logical NOT of this one.

        Returns
        -------
        inverted : `PredicateLeaf`
            A `LogicalNot` leaf whose operand is this leaf.
        """
        # The cast only informs the type checker; nothing is converted at
        # runtime.  `LogicalNot` overrides this method, so it never ends up
        # wrapped in another `LogicalNot` here.
        operand = cast("LogicalNotOperand", self)
        return LogicalNot.model_construct(operand=operand)

    @abstractmethod
    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        """Dispatch to the visitor method that handles this kind of leaf.

        Parameters
        ----------
        visitor : `PredicateVisitor`
            Visitor to invoke a method on.
        flags : `PredicateVisitFlags`
            Flags describing where this leaf sits within the larger
            predicate tree.

        Returns
        -------
        result : `object`
            Whatever the visitor method returned.
        """
        raise NotImplementedError()
@final
class Predicate(QueryTreeBase):
    """A boolean column expression.

    Notes
    -----
    `Predicate` is the only boolean-expression class meant for use outside
    this module; the leaf objects nested inside it show up in its serialized
    form, so they are not fully private.  Build those nested objects through
    the `classmethod` factories provided here, and process them with
    `PredicateVisitor` subclasses.
    """

    operands: PredicateOperands
    """Nested tuple of operands: the outer items are combined with AND and
    the items inside each inner tuple are combined with OR.
    """

    @property
    def column_type(self) -> Literal["bool"]:
        """A string enumeration value naming the type of this column
        expression; always ``"bool"``.
        """
        return "bool"

    @classmethod
    def from_bool(cls, value: bool) -> Predicate:
        """Construct a predicate that is constantly `True` or `False`.

        Parameters
        ----------
        value : `bool`
            Value the predicate should evaluate to.

        Returns
        -------
        predicate : `Predicate`
            Predicate that evaluates to the given boolean value.
        """
        # Think of how `all` and `any` treat empty sequences; the operands
        # are interpreted as
        #
        #     all(any(or_group) for or_group in self.operands)
        #
        # so "no OR-groups" is True and "one empty OR-group" is False.
        if value:
            operands: PredicateOperands = ()
        else:
            operands = ((),)
        return cls.model_construct(operands=operands)

    @classmethod
    def compare(cls, a: ColumnExpression, operator: ComparisonOperator, b: ColumnExpression) -> Predicate:
        """Construct a predicate for a binary comparison between two
        non-boolean column expressions.

        Parameters
        ----------
        a : `ColumnExpression`
            First column expression in the comparison.
        operator : `str`
            Enumerated string representing the comparison operator to apply.
            May be any of "==", "!=", "<", ">", "<=", ">=", or "overlaps".
        b : `ColumnExpression`
            Second column expression in the comparison.

        Returns
        -------
        predicate : `Predicate`
            Predicate representing the comparison.
        """
        leaf = Comparison(a=a, operator=operator, b=b)
        return cls._from_leaf(leaf)

    @classmethod
    def is_null(cls, operand: ColumnExpression) -> Predicate:
        """Construct a predicate that checks a column expression for NULL.

        Parameters
        ----------
        operand : `ColumnExpression`
            Column expression to test.

        Returns
        -------
        predicate : `Predicate`
            Predicate representing the NULL check.
        """
        leaf = IsNull(operand=operand)
        return cls._from_leaf(leaf)

    @classmethod
    def in_container(cls, member: ColumnExpression, container: Iterable[ColumnExpression]) -> Predicate:
        """Construct a predicate that checks whether one column expression
        is a member of a container of other column expressions.

        Parameters
        ----------
        member : `ColumnExpression`
            Column expression that may be a member of the container.
        container : `~collections.abc.Iterable` [ `ColumnExpression` ]
            Container of column expressions to test for membership in.

        Returns
        -------
        predicate : `Predicate`
            Predicate representing the membership test.
        """
        # Materialize the iterable, since the leaf stores a tuple.
        items = tuple(container)
        return cls._from_leaf(InContainer(member=member, container=items))

    @classmethod
    def in_range(
        cls, member: ColumnExpression, start: int = 0, stop: int | None = None, step: int = 1
    ) -> Predicate:
        """Construct a predicate that checks whether an integer column
        expression falls in a strided range.

        Parameters
        ----------
        member : `ColumnExpression`
            Column expression that may be a member of the range.
        start : `int`, optional
            Beginning of the range, inclusive.
        stop : `int` or `None`, optional
            End of the range, exclusive.
        step : `int`, optional
            Offset between values in the range.

        Returns
        -------
        predicate : `Predicate`
            Predicate representing the membership test.
        """
        leaf = InRange(member=member, start=start, stop=stop, step=step)
        return cls._from_leaf(leaf)

    @classmethod
    def in_query(cls, member: ColumnExpression, column: ColumnExpression, query_tree: QueryTree) -> Predicate:
        """Construct a predicate that checks whether a column expression is
        present in a single-column projection of a query tree.

        Parameters
        ----------
        member : `ColumnExpression`
            Column expression that may be present in the query.
        column : `ColumnExpression`
            Column to project from the query.
        query_tree : `QueryTree`
            Query tree to select from.

        Returns
        -------
        predicate : `Predicate`
            Predicate representing the membership test.
        """
        leaf = InQuery(member=member, column=column, query_tree=query_tree)
        return cls._from_leaf(leaf)

    def gather_required_columns(self, columns: ColumnSet) -> None:
        """Add the columns needed to evaluate this predicate to a column
        set.

        Parameters
        ----------
        columns : `ColumnSet`
            Set of columns to modify in place.
        """
        # Flatten the AND-of-ORs nesting; every leaf is visited once, in
        # the same order as nested loops over the operands would give.
        for leaf in itertools.chain.from_iterable(self.operands):
            leaf.gather_required_columns(columns)

    def logical_and(self, *args: Predicate) -> Predicate:
        """Construct the logical AND of this predicate and one or more
        others.

        Parameters
        ----------
        *args : `Predicate`
            Other predicates.

        Returns
        -------
        predicate : `Predicate`
            Predicate representing the logical AND.
        """
        combined = self.operands
        for other in args:
            combined = self._impl_and(combined, other.operands)
        # An empty OR-group is False, and False AND anything is False, so
        # collapse the whole expression to the canonical False form.
        if any(not or_group for or_group in combined):
            combined = ((),)
        return Predicate.model_construct(operands=combined)

    def logical_or(self, *args: Predicate) -> Predicate:
        """Construct the logical OR of this predicate and one or more
        others.

        Parameters
        ----------
        *args : `Predicate`
            Other predicates.

        Returns
        -------
        predicate : `Predicate`
            Predicate representing the logical OR.
        """
        combined = self.operands
        for other in args:
            combined = self._impl_or(combined, other.operands)
        return Predicate.model_construct(operands=combined)

    def logical_not(self) -> Predicate:
        """Construct the logical NOT of this predicate.

        Returns
        -------
        predicate : `Predicate`
            Predicate representing the logical NOT.
        """
        # De Morgan: NOT (g1 AND g2 ...) == (NOT g1) OR (NOT g2) ..., and
        # NOT (x OR y ...) == (NOT x) AND (NOT y) ....  Start from False,
        # the identity for OR, and fold each negated group in.
        negated: PredicateOperands = ((),)
        for or_group in self.operands:
            # The negated OR-group is an AND of single-leaf OR-groups.
            negated_group: PredicateOperands = tuple((leaf.invert(),) for leaf in or_group)
            negated = self._impl_or(negated, negated_group)
        return Predicate.model_construct(operands=negated)

    def __str__(self) -> str:
        # OR-groups only need parentheses when they are AND-ed with others.
        parenthesize = len(self.operands) > 1
        terms = []
        for or_group in self.operands:
            if not or_group:
                terms.append("False")
            elif len(or_group) == 1:
                terms.append(str(or_group[0]))
            else:
                joined = " OR ".join(map(str, or_group))
                terms.append(f"({joined})" if parenthesize else joined)
        return " AND ".join(terms) if terms else "True"

    def visit(self, visitor: PredicateVisitor[_A, _O, _L]) -> _A:
        """Hand this predicate's operands to the visitor.

        Parameters
        ----------
        visitor : `PredicateVisitor`
            Visitor to invoke a method on.

        Returns
        -------
        result : `object`
            Whatever the visitor returned.
        """
        return visitor._visit_logical_and(self.operands)

    @classmethod
    def _from_leaf(cls, leaf: PredicateLeaf) -> Predicate:
        # A lone leaf is an OR-group with one member.
        return cls._from_or_group((leaf,))

    @classmethod
    def _from_or_group(cls, or_group: tuple[PredicateLeaf, ...]) -> Predicate:
        # A lone OR-group is an AND with one operand.
        return Predicate.model_construct(operands=(or_group,))

    @classmethod
    def _impl_and(cls, a: PredicateOperands, b: PredicateOperands) -> PredicateOperands:
        # More simplification is possible when both sides share leaf
        # expressions; even 'is' tests on leaves would help where
        # converting to conjunctive normal form twice duplicates a lot,
        # e.g. NOT ((A AND B) OR (C AND D)) or any double negation.  Those
        # cases look pathological enough not to be worth the effort yet,
        # so only AND-ing an operand tuple with itself is short-circuited.
        if a is b:
            return a
        return a + b

    @classmethod
    def _impl_or(cls, a: PredicateOperands, b: PredicateOperands) -> PredicateOperands:
        # Distribute OR over AND: each OR-group on one side is merged with
        # each OR-group on the other.  The simplification remarks in
        # _impl_and apply here too.
        return tuple(a_group + b_group for a_group in a for b_group in b)
@final
class LogicalNot(PredicateLeafBase):
    """A boolean column expression that negates its operand."""

    predicate_type: Literal["not"] = "not"

    operand: LogicalNotOperand
    """Upstream boolean expression being negated."""

    def gather_required_columns(self, columns: ColumnSet) -> None:
        # Docstring inherited.
        # Negation needs exactly the columns its operand needs.
        self.operand.gather_required_columns(columns)

    def __str__(self) -> str:
        return f"NOT {self.operand}"

    def invert(self) -> LogicalNotOperand:
        # Docstring inherited.
        # Double negation cancels, so hand back the wrapped operand instead
        # of nesting another LogicalNot.
        return self.operand

    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        # Docstring inherited.
        return visitor._visit_logical_not(self.operand, flags)
@final
class IsNull(PredicateLeafBase):
    """A boolean column expression that checks its operand for NULL."""

    predicate_type: Literal["is_null"] = "is_null"

    operand: ColumnExpression
    """Upstream expression being tested."""

    def gather_required_columns(self, columns: ColumnSet) -> None:
        # Docstring inherited.
        # The only columns involved are those of the tested expression.
        self.operand.gather_required_columns(columns)

    def __str__(self) -> str:
        return f"{self.operand} IS NULL"

    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        # Docstring inherited.
        return visitor.visit_is_null(self.operand, flags)
@final
class Comparison(PredicateLeafBase):
    """A boolean column expression formed by comparing two non-boolean
    expressions.
    """

    predicate_type: Literal["comparison"] = "comparison"

    a: ColumnExpression
    """Expression on the left-hand side of the comparison."""

    b: ColumnExpression
    """Expression on the right-hand side of the comparison."""

    operator: ComparisonOperator
    """Comparison operator."""

    def gather_required_columns(self, columns: ColumnSet) -> None:
        # Docstring inherited.
        for side in (self.a, self.b):
            side.gather_required_columns(columns)

    def __str__(self) -> str:
        # Upper-casing only affects the word operator ("overlaps").
        return f"{self.a} {self.operator.upper()} {self.b}"

    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        # Docstring inherited.
        return visitor.visit_comparison(self.a, self.operator, self.b, flags)

    @pydantic.model_validator(mode="after")
    def _validate_column_types(self) -> Comparison:
        # Both sides must share a column type, and that type must support
        # the requested operator.
        a_type = self.a.column_type
        b_type = self.b.column_type
        if a_type != b_type:
            raise InvalidQueryError(
                f"Column types for comparison {self} do not agree "
                f"({a_type}, {b_type})."
            )
        operator = self.operator
        if operator in ("==", "!="):
            # Equality is allowed for every column type.
            supported = True
        elif operator in ("<", ">", ">=", "<="):
            # Ordering needs an orderable scalar type.
            supported = a_type in ("int", "string", "float", "datetime")
        elif operator == "overlaps":
            # Overlap tests only apply to regions and timespans.
            supported = a_type in ("region", "timespan")
        else:
            supported = False
        if not supported:
            raise InvalidQueryError(
                f"Invalid column type {a_type} for operator {operator!r}."
            )
        return self
@final
class InContainer(PredicateLeafBase):
    """A boolean column expression that checks whether one expression is a
    member of an explicit sequence of other expressions.
    """

    predicate_type: Literal["in_container"] = "in_container"

    member: ColumnExpression
    """Expression whose membership is tested."""

    container: tuple[ColumnExpression, ...]
    """Expressions that make up the elements of the container."""

    def gather_required_columns(self, columns: ColumnSet) -> None:
        # Docstring inherited.
        # The member goes first, then each container element in order.
        for expression in (self.member, *self.container):
            expression.gather_required_columns(columns)

    def __str__(self) -> str:
        items = ", ".join(map(str, self.container))
        return f"{self.member} IN [{items}]"

    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        # Docstring inherited.
        return visitor.visit_in_container(self.member, self.container, flags)

    @pydantic.model_validator(mode="after")
    def _validate(self) -> InContainer:
        # Timespans and regions cannot be tested with IN, and every element
        # must have the member's column type.
        member_type = self.member.column_type
        if member_type in ("timespan", "region"):
            raise InvalidQueryError(
                f"Timespan or region column {self.member} may not be used in IN expressions."
            )
        if any(item.column_type != member_type for item in self.container):
            raise InvalidQueryError(f"Column types for membership test {self} do not agree.")
        return self
@final
class InRange(PredicateLeafBase):
    """A boolean column expression that checks whether its expression falls
    in an integer range.
    """

    predicate_type: Literal["in_range"] = "in_range"

    member: ColumnExpression
    """Expression whose membership is tested."""

    start: int = 0
    """Lower bound of the range, inclusive."""

    stop: int | None = None
    """Upper bound of the range, exclusive."""

    step: int = 1
    """Spacing between values in the range."""

    def gather_required_columns(self, columns: ColumnSet) -> None:
        # Docstring inherited.
        self.member.gather_required_columns(columns)

    def __str__(self) -> str:
        # Slice-like rendering: a zero start and a missing stop are left
        # blank, and the step is only shown when it is not 1.
        start_text = self.start or ""
        stop_text = "" if self.stop is None else self.stop
        range_text = f"{start_text}:{stop_text}"
        if self.step != 1:
            range_text = f"{range_text}:{self.step}"
        return f"{self.member} IN {range_text}"

    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        # Docstring inherited.
        return visitor.visit_in_range(self.member, self.start, self.stop, self.step, flags)

    @pydantic.model_validator(mode="after")
    def _validate(self) -> InRange:
        # The checks run in a fixed order, so the first failure decides
        # which message is reported.
        if self.member.column_type != "int":
            raise InvalidQueryError(f"Column {self.member} is not an integer.")
        if self.step < 1:
            raise InvalidQueryError("Range step must be >= 1.")
        stop = self.stop
        if stop is not None and stop < self.start:
            raise InvalidQueryError("Range stop must be >= start.")
        return self
@final
class InQuery(PredicateLeafBase):
    """A boolean column expression that checks whether its expression is
    included in a single-column projection of a relation.

    This is primarily intended to be used on dataset ID columns, but it may
    be useful for other columns as well.
    """

    predicate_type: Literal["in_query"] = "in_query"

    member: ColumnExpression
    """Expression whose membership is tested."""

    column: ColumnExpression
    """Expression to extract from `query_tree`."""

    query_tree: QueryTree
    """Relation whose rows from `column` form the container."""

    def gather_required_columns(self, columns: ColumnSet) -> None:
        # Docstring inherited.
        # Only columns from the query tree this predicate is attached to
        # are gathered; `self.column` belongs to `self.query_tree` and is
        # deliberately left out.
        self.member.gather_required_columns(columns)

    def __str__(self) -> str:
        return f"{self.member} IN (query).{self.column}"

    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        # Docstring inherited.
        return visitor.visit_in_query_tree(self.member, self.column, self.query_tree, flags)

    @pydantic.model_validator(mode="after")
    def _validate_column_types(self) -> InQuery:
        member_type = self.member.column_type
        if member_type in ("timespan", "region"):
            raise InvalidQueryError(
                f"Timespan or region column {self.member} may not be used in IN expressions."
            )
        column_type = self.column.column_type
        if member_type != column_type:
            raise InvalidQueryError(
                f"Column types for membership test {self} do not agree "
                f"({member_type}, {column_type})."
            )

        # Imported here rather than at module scope, where this name is
        # only available under TYPE_CHECKING.
        from ._column_set import ColumnSet

        # Collect what `column` needs, starting from the nested tree's own
        # dimensions, then check the tree can actually supply it.
        tree = self.query_tree
        required = ColumnSet(tree.dimensions)
        self.column.gather_required_columns(required)
        if required.dimensions != tree.dimensions:
            raise InvalidQueryError(
                f"Column {self.column} requires dimensions {required.dimensions}, "
                f"but query tree only has {tree.dimensions}."
            )
        if not required.dataset_fields.keys() <= tree.datasets.keys():
            raise InvalidQueryError(
                f"Column {self.column} requires dataset types "
                f"{set(required.dataset_fields.keys())} that are not present in query tree."
            )
        return self
# Every leaf type that `LogicalNot` may wrap, i.e. all leaves except
# `LogicalNot` itself.
LogicalNotOperand: TypeAlias = IsNull | Comparison | InContainer | InRange | InQuery

# Union of all leaf types, discriminated by each leaf's `predicate_type`
# field.
PredicateLeaf: TypeAlias = Annotated[
    LogicalNotOperand | LogicalNot,
    pydantic.Field(discriminator="predicate_type"),
]

# Outer tuple is AND-ed together; each inner tuple is an OR-group of leaves.
PredicateOperands: TypeAlias = tuple[tuple[PredicateLeaf, ...], ...]