Coverage for python/lsst/daf/butler/queries/tree/_predicate.py: 51%
235 statements
« prev ^ index » next coverage.py v7.4.3, created at 2024-03-07 11:04 +0000
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
28from __future__ import annotations
# Public names exported by this module. The concrete leaf classes are not
# listed: `Predicate` is the class meant to be used outside this module, and
# the leaf types are reached through it (and the `PredicateLeaf` union).
__all__ = (
    "Predicate",
    "PredicateLeaf",
    "LogicalNotOperand",
    "PredicateOperands",
    "ComparisonOperator",
)
38import itertools
39from abc import ABC, abstractmethod
40from typing import TYPE_CHECKING, Annotated, Iterable, Literal, TypeAlias, TypeVar, Union, cast, final
42import pydantic
44from ._base import InvalidQueryError, QueryTreeBase
45from ._column_expression import ColumnExpression
47if TYPE_CHECKING:
48 from ..visitors import PredicateVisitFlags, PredicateVisitor
49 from ._column_set import ColumnSet
50 from ._query_tree import QueryTree
# Enumerated strings accepted as the ``operator`` of `Predicate.compare` and
# stored in `Comparison.operator`.
ComparisonOperator: TypeAlias = Literal["==", "!=", "<", ">", ">=", "<=", "overlaps"]


# Type variables for the result types of a `PredicateVisitor`: leaf ``visit``
# methods return ``_L``, and `Predicate.visit` returns ``_A`` (the result of
# visiting the top-level AND). ``_O`` is presumably the result type for a
# visited OR-group; NOTE(review): confirm against `PredicateVisitor`.
_L = TypeVar("_L")
_A = TypeVar("_A")
_O = TypeVar("_O")
class PredicateLeafBase(QueryTreeBase, ABC):
    """Abstract base for the leaf nodes of a `Predicate` tree.

    The hierarchy is closed: each concrete subclass is marked `~typing.final`
    and appears in the `PredicateLeaf` union. Type annotations should normally
    use that union rather than this (technically open) base class.
    """

    @abstractmethod
    def gather_required_columns(self, columns: ColumnSet) -> None:
        """Insert into ``columns`` every column needed to evaluate this leaf.

        Parameters
        ----------
        columns : `ColumnSet`
            Column set that is updated in place.
        """
        raise NotImplementedError()

    def invert(self) -> PredicateLeaf:
        """Return a new leaf that is the logical NOT of this one."""
        # Every concrete leaf except `LogicalNot` (which overrides this method)
        # is a member of `LogicalNotOperand`; the cast only informs the type
        # checker and performs no runtime conversion.
        wrapped = cast("LogicalNotOperand", self)
        return LogicalNot.model_construct(operand=wrapped)

    @abstractmethod
    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        """Dispatch this leaf to the matching method of ``visitor``.

        Parameters
        ----------
        visitor : `PredicateVisitor`
            Visitor whose method is called.
        flags : `PredicateVisitFlags`
            Information about where this leaf sits within the enclosing
            predicate tree.

        Returns
        -------
        result : `object`
            Whatever the visitor method returned.
        """
        raise NotImplementedError()
104@final
105class Predicate(QueryTreeBase):
106 """A boolean column expression.
108 Notes
109 -----
110 Predicate is the only class representing a boolean column expression that
111 should be used outside of this module (though the objects it nests appear
112 in its serialized form and hence are not fully private). It provides
113 several `classmethod` factories for constructing those nested types inside
114 a `Predicate` instance, and `PredicateVisitor` subclasses should be used
115 to process them.
116 """
118 operands: PredicateOperands
119 """Nested tuple of operands, with outer items combined via AND and inner
120 items combined via OR.
121 """
123 @property
124 def column_type(self) -> Literal["bool"]:
125 """A string enumeration value representing the type of the column
126 expression.
127 """
128 return "bool"
130 @classmethod
131 def from_bool(cls, value: bool) -> Predicate:
132 """Construct a predicate that always evaluates to `True` or `False`.
134 Parameters
135 ----------
136 value : `bool`
137 Value the predicate should evaluate to.
139 Returns
140 -------
141 predicate : `Predicate`
142 Predicate that evaluates to the given boolean value.
143 """
144 # The values for True and False here make sense if you think about
145 # calling `all` and `any` with empty sequences; note that the
146 # `self.operands` attribute is evaluated as:
147 #
148 # value = all(any(or_group) for or_group in self.operands)
149 #
150 return cls.model_construct(operands=() if value else ((),))
152 @classmethod
153 def compare(cls, a: ColumnExpression, operator: ComparisonOperator, b: ColumnExpression) -> Predicate:
154 """Construct a predicate representing a binary comparison between
155 two non-boolean column expressions.
157 Parameters
158 ----------
159 a : `ColumnExpression`
160 First column expression in the comparison.
161 operator : `str`
162 Enumerated string representing the comparison operator to apply.
163 May be and of "==", "!=", "<", ">", "<=", ">=", or "overlaps".
164 b : `ColumnExpression`
165 Second column expression in the comparison.
167 Returns
168 -------
169 predicate : `Predicate`
170 Predicate representing the comparison.
171 """
172 return cls._from_leaf(Comparison(a=a, operator=operator, b=b))
174 @classmethod
175 def is_null(cls, operand: ColumnExpression) -> Predicate:
176 """Construct a predicate that tests whether a column expression is
177 NULL.
179 Parameters
180 ----------
181 operand : `ColumnExpression`
182 Column expression to test.
184 Returns
185 -------
186 predicate : `Predicate`
187 Predicate representing the NULL check.
188 """
189 return cls._from_leaf(IsNull(operand=operand))
191 @classmethod
192 def in_container(cls, member: ColumnExpression, container: Iterable[ColumnExpression]) -> Predicate:
193 """Construct a predicate that tests whether one column expression is
194 a member of a container of other column expressions.
196 Parameters
197 ----------
198 member : `ColumnExpression`
199 Column expression that may be a member of the container.
200 container : `~collections.abc.Iterable` [ `ColumnExpression` ]
201 Container of column expressions to test for membership in.
203 Returns
204 -------
205 predicate : `Predicate`
206 Predicate representing the membership test.
207 """
208 return cls._from_leaf(InContainer(member=member, container=tuple(container)))
210 @classmethod
211 def in_range(
212 cls, member: ColumnExpression, start: int = 0, stop: int | None = None, step: int = 1
213 ) -> Predicate:
214 """Construct a predicate that tests whether an integer column
215 expression is part of a strided range.
217 Parameters
218 ----------
219 member : `ColumnExpression`
220 Column expression that may be a member of the range.
221 start : `int`, optional
222 Beginning of the range, inclusive.
223 stop : `int` or `None`, optional
224 End of the range, exclusive.
225 step : `int`, optional
226 Offset between values in the range.
228 Returns
229 -------
230 predicate : `Predicate`
231 Predicate representing the membership test.
232 """
233 return cls._from_leaf(InRange(member=member, start=start, stop=stop, step=step))
235 @classmethod
236 def in_query(cls, member: ColumnExpression, column: ColumnExpression, query_tree: QueryTree) -> Predicate:
237 """Construct a predicate that tests whether a column expression is
238 present in a single-column projection of a query tree.
240 Parameters
241 ----------
242 member : `ColumnExpression`
243 Column expression that may be present in the query.
244 column : `ColumnExpression`
245 Column to project from the query.
246 query_tree : `QueryTree`
247 Query tree to select from.
249 Returns
250 -------
251 predicate : `Predicate`
252 Predicate representing the membership test.
253 """
254 return cls._from_leaf(InQuery(member=member, column=column, query_tree=query_tree))
256 def gather_required_columns(self, columns: ColumnSet) -> None:
257 """Add any columns required to evaluate this predicate to the given
258 column set.
260 Parameters
261 ----------
262 columns : `ColumnSet`
263 Set of columns to modify in place.
264 """
265 for or_group in self.operands:
266 for operand in or_group:
267 operand.gather_required_columns(columns)
269 def logical_and(self, *args: Predicate) -> Predicate:
270 """Construct a predicate representing the logical AND of this predicate
271 and one or more others.
273 Parameters
274 ----------
275 *args : `Predicate`
276 Other predicates.
278 Returns
279 -------
280 predicate : `Predicate`
281 Predicate representing the logical AND.
282 """
283 operands = self.operands
284 for arg in args:
285 operands = self._impl_and(operands, arg.operands)
286 if not all(operands):
287 # If any item in operands is an empty tuple (i.e. False), simplify.
288 operands = ((),)
289 return Predicate.model_construct(operands=operands)
291 def logical_or(self, *args: Predicate) -> Predicate:
292 """Construct a predicate representing the logical OR of this predicate
293 and one or more others.
295 Parameters
296 ----------
297 *args : `Predicate`
298 Other predicates.
300 Returns
301 -------
302 predicate : `Predicate`
303 Predicate representing the logical OR.
304 """
305 operands = self.operands
306 for arg in args:
307 operands = self._impl_or(operands, arg.operands)
308 return Predicate.model_construct(operands=operands)
310 def logical_not(self) -> Predicate:
311 """Construct a predicate representing the logical NOT of this
312 predicate.
314 Returns
315 -------
316 predicate : `Predicate`
317 Predicate representing the logical NOT.
318 """
319 new_operands: PredicateOperands = ((),)
320 for or_group in self.operands:
321 new_group: PredicateOperands = ()
322 for leaf in or_group:
323 new_group = self._impl_and(new_group, ((leaf.invert(),),))
324 new_operands = self._impl_or(new_operands, new_group)
325 return Predicate.model_construct(operands=new_operands)
327 def __str__(self) -> str:
328 and_terms = []
329 for or_group in self.operands:
330 match len(or_group):
331 case 0:
332 and_terms.append("False")
333 case 1:
334 and_terms.append(str(or_group[0]))
335 case _:
336 or_str = " OR ".join(str(operand) for operand in or_group)
337 if len(self.operands) > 1:
338 and_terms.append(f"({or_str})")
339 else:
340 and_terms.append(or_str)
341 if not and_terms:
342 return "True"
343 return " AND ".join(and_terms)
345 def visit(self, visitor: PredicateVisitor[_A, _O, _L]) -> _A:
346 """Invoke the visitor interface.
348 Parameters
349 ----------
350 visitor : `PredicateVisitor`
351 Visitor to invoke a method on.
353 Returns
354 -------
355 result : `object`
356 Forwarded result from the visitor.
357 """
358 return visitor._visit_logical_and(self.operands)
360 @classmethod
361 def _from_leaf(cls, leaf: PredicateLeaf) -> Predicate:
362 return cls._from_or_group((leaf,))
364 @classmethod
365 def _from_or_group(cls, or_group: tuple[PredicateLeaf, ...]) -> Predicate:
366 return Predicate.model_construct(operands=(or_group,))
368 @classmethod
369 def _impl_and(cls, a: PredicateOperands, b: PredicateOperands) -> PredicateOperands:
370 # We could simplify cases where both sides have some of the same leaf
371 # expressions; even using 'is' tests would simplify some cases where
372 # converting to conjunctive normal form twice leads to a lot of
373 # duplication, e.g. NOT ((A AND B) OR (C AND D)) or any kind of
374 # double-negation. Right now those cases seem pathological enough to
375 # be not worth our time.
376 return a + b if a is not b else a
378 @classmethod
379 def _impl_or(cls, a: PredicateOperands, b: PredicateOperands) -> PredicateOperands:
380 # Same comment re simplification as in _impl_and applies here.
381 return tuple([a_operand + b_operand for a_operand, b_operand in itertools.product(a, b)])
@final
class LogicalNot(PredicateLeafBase):
    """A boolean column expression that inverts its operand."""

    predicate_type: Literal["not"] = "not"

    operand: LogicalNotOperand
    """Upstream boolean expression to invert."""

    def invert(self) -> LogicalNotOperand:
        # Docstring inherited.
        # Negating a negation cancels out: return the wrapped leaf itself
        # instead of nesting another LogicalNot around this one.
        return self.operand

    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        # Docstring inherited.
        inner = self.operand
        return visitor._visit_logical_not(inner, flags)

    def gather_required_columns(self, columns: ColumnSet) -> None:
        # Docstring inherited.
        # A negation needs exactly the columns its operand needs.
        inner = self.operand
        inner.gather_required_columns(columns)

    def __str__(self) -> str:
        return f"NOT {self.operand}"
@final
class IsNull(PredicateLeafBase):
    """A boolean column expression that tests whether its operand is NULL."""

    predicate_type: Literal["is_null"] = "is_null"

    operand: ColumnExpression
    """Upstream expression to test."""

    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        # Docstring inherited.
        tested = self.operand
        return visitor.visit_is_null(tested, flags)

    def gather_required_columns(self, columns: ColumnSet) -> None:
        # Docstring inherited.
        # Only the tested expression contributes columns.
        tested = self.operand
        tested.gather_required_columns(columns)

    def __str__(self) -> str:
        return f"{self.operand} IS NULL"
430@final
431class Comparison(PredicateLeafBase):
432 """A boolean columns expression formed by comparing two non-boolean
433 expressions.
434 """
436 predicate_type: Literal["comparison"] = "comparison"
438 a: ColumnExpression
439 """Left-hand side expression for the comparison."""
441 b: ColumnExpression
442 """Right-hand side expression for the comparison."""
444 operator: ComparisonOperator
445 """Comparison operator."""
447 def gather_required_columns(self, columns: ColumnSet) -> None:
448 # Docstring inherited.
449 self.a.gather_required_columns(columns)
450 self.b.gather_required_columns(columns)
452 def __str__(self) -> str:
453 return f"{self.a} {self.operator.upper()} {self.b}"
455 def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
456 # Docstring inherited.
457 return visitor.visit_comparison(self.a, self.operator, self.b, flags)
459 @pydantic.model_validator(mode="after")
460 def _validate_column_types(self) -> Comparison:
461 if self.a.column_type != self.b.column_type:
462 raise InvalidQueryError(
463 f"Column types for comparison {self} do not agree "
464 f"({self.a.column_type}, {self.b.column_type})."
465 )
466 match (self.operator, self.a.column_type):
467 case ("==" | "!=", _):
468 pass
469 case ("<" | ">" | ">=" | "<=", "int" | "string" | "float" | "datetime"):
470 pass
471 case ("overlaps", "region" | "timespan"):
472 pass
473 case _:
474 raise InvalidQueryError(
475 f"Invalid column type {self.a.column_type} for operator {self.operator!r}."
476 )
477 return self
@final
class InContainer(PredicateLeafBase):
    """A boolean column expression that tests whether one expression is a
    member of an explicit sequence of other expressions.
    """

    predicate_type: Literal["in_container"] = "in_container"

    member: ColumnExpression
    """Expression to test for membership."""

    container: tuple[ColumnExpression, ...]
    """Expressions representing the elements of the container."""

    def gather_required_columns(self, columns: ColumnSet) -> None:
        # Docstring inherited.
        # The member first, then every container element, in order.
        for expression in (self.member, *self.container):
            expression.gather_required_columns(columns)

    def __str__(self) -> str:
        rendered_items = ", ".join(str(item) for item in self.container)
        return f"{self.member} IN [{rendered_items}]"

    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        # Docstring inherited.
        return visitor.visit_in_container(self.member, self.container, flags)

    @pydantic.model_validator(mode="after")
    def _validate(self) -> InContainer:
        """Reject timespan/region members and elements whose column type
        differs from the member's.
        """
        member_type = self.member.column_type
        if member_type in ("timespan", "region"):
            raise InvalidQueryError(
                f"Timespan or region column {self.member} may not be used in IN expressions."
            )
        if any(item.column_type != member_type for item in self.container):
            raise InvalidQueryError(f"Column types for membership test {self} do not agree.")
        return self
@final
class InRange(PredicateLeafBase):
    """A boolean column expression that tests whether its expression is
    included in an integer range.
    """

    predicate_type: Literal["in_range"] = "in_range"

    member: ColumnExpression
    """Expression to test for membership."""

    start: int = 0
    """Inclusive lower bound for the range."""

    stop: int | None = None
    """Exclusive upper bound for the range."""

    step: int = 1
    """Difference between values in the range."""

    def gather_required_columns(self, columns: ColumnSet) -> None:
        # Docstring inherited.
        # The bounds are plain integers, so only the member needs columns.
        self.member.gather_required_columns(columns)

    def __str__(self) -> str:
        # Slice-like rendering: a zero start and a missing stop are left
        # blank, and the step is shown only when it is not 1.
        lower = str(self.start) if self.start else ""
        upper = "" if self.stop is None else str(self.stop)
        span = f"{lower}:{upper}"
        if self.step != 1:
            span += f":{self.step}"
        return f"{self.member} IN {span}"

    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        # Docstring inherited.
        return visitor.visit_in_range(self.member, self.start, self.stop, self.step, flags)

    @pydantic.model_validator(mode="after")
    def _validate(self) -> InRange:
        """Require an integer member, a positive step, and ``stop >= start``
        when a stop is given.
        """
        if self.member.column_type != "int":
            raise InvalidQueryError(f"Column {self.member} is not an integer.")
        if self.step < 1:
            raise InvalidQueryError("Range step must be >= 1.")
        if self.stop is not None and self.stop < self.start:
            raise InvalidQueryError("Range stop must be >= start.")
        return self
@final
class InQuery(PredicateLeafBase):
    """A boolean column expression that tests whether its expression is
    included in a single-column projection of a relation.

    This is primarily intended to be used on dataset ID columns, but it may
    be useful for other columns as well.
    """

    predicate_type: Literal["in_query"] = "in_query"

    member: ColumnExpression
    """Expression to test for membership."""

    column: ColumnExpression
    """Expression to extract from `query_tree`."""

    query_tree: QueryTree
    """Relation whose rows from `column` represent the container."""

    def gather_required_columns(self, columns: ColumnSet) -> None:
        # Docstring inherited.
        # We're only gathering columns from the query_tree this predicate is
        # attached to, not `self.column`, which belongs to `self.query_tree`.
        self.member.gather_required_columns(columns)

    def __str__(self) -> str:
        return f"{self.member} IN (query).{self.column}"

    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        # Docstring inherited.
        return visitor.visit_in_query_tree(self.member, self.column, self.query_tree, flags)

    @pydantic.model_validator(mode="after")
    def _validate_column_types(self) -> InQuery:
        """Check that this membership test is well-formed.

        The member may not be a timespan or region, and its column type must
        match that of ``column``. ``column`` may only require dimensions and
        dataset types that ``query_tree`` already has. Violations raise
        `InvalidQueryError`.
        """
        if self.member.column_type == "timespan" or self.member.column_type == "region":
            raise InvalidQueryError(
                f"Timespan or region column {self.member} may not be used in IN expressions."
            )
        if self.member.column_type != self.column.column_type:
            raise InvalidQueryError(
                f"Column types for membership test {self} do not agree "
                f"({self.member.column_type}, {self.column.column_type})."
            )

        # At module scope ColumnSet is imported only under TYPE_CHECKING, so
        # it is imported here for runtime use (presumably to avoid an import
        # cycle between the query-tree modules; NOTE(review): confirm).
        from ._column_set import ColumnSet

        # Seed the set with the tree's own dimensions: if gathering the
        # column's requirements grows the dimension set, the column needs
        # dimensions the tree does not provide.
        columns_required_in_tree = ColumnSet(self.query_tree.dimensions)
        self.column.gather_required_columns(columns_required_in_tree)
        if columns_required_in_tree.dimensions != self.query_tree.dimensions:
            raise InvalidQueryError(
                f"Column {self.column} requires dimensions {columns_required_in_tree.dimensions}, "
                f"but query tree only has {self.query_tree.dimensions}."
            )
        if not columns_required_in_tree.dataset_fields.keys() <= self.query_tree.datasets.keys():
            raise InvalidQueryError(
                f"Column {self.column} requires dataset types "
                f"{set(columns_required_in_tree.dataset_fields.keys())} that are not present in query tree."
            )
        return self
# Every leaf type that `LogicalNot` may wrap: all concrete leaves except
# `LogicalNot` itself. `LogicalNot.invert` unwraps instead of re-wrapping, so
# a negation never needs to hold another negation.
LogicalNotOperand: TypeAlias = Union[
    IsNull,
    Comparison,
    InContainer,
    InRange,
    InQuery,
]
# Union of all concrete leaf types. The ``predicate_type`` literal field on
# each leaf class is the pydantic discriminator used to select the right
# class when validating/deserializing.
PredicateLeaf: TypeAlias = Annotated[
    Union[LogicalNotOperand, LogicalNot], pydantic.Field(discriminator="predicate_type")
]

# Conjunctive-normal-form operands of a `Predicate`: a tuple of OR-groups
# (each a tuple of leaves) whose results are combined with AND.
PredicateOperands: TypeAlias = tuple[tuple[PredicateLeaf, ...], ...]