Coverage for python/lsst/daf/butler/queries/tree/_predicate.py: 51%
235 statements
« prev ^ index » next coverage.py v7.4.3, created at 2024-03-05 11:36 +0000
« prev ^ index » next coverage.py v7.4.3, created at 2024-03-05 11:36 +0000
# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
28from __future__ import annotations
30__all__ = (
31 "Predicate",
32 "PredicateLeaf",
33 "LogicalNotOperand",
34 "PredicateOperands",
35 "ComparisonOperator",
36)
38import itertools
39from abc import ABC, abstractmethod
40from typing import TYPE_CHECKING, Annotated, Iterable, Literal, TypeAlias, TypeVar, Union, cast, final
42import pydantic
44from ._base import InvalidQueryError, QueryTreeBase
45from ._column_expression import ColumnExpression
47if TYPE_CHECKING:
48 from ..visitors import PredicateVisitFlags, PredicateVisitor
49 from ._column_set import ColumnSet
50 from ._query_tree import QueryTree
52ComparisonOperator: TypeAlias = Literal["==", "!=", "<", ">", ">=", "<=", "overlaps"]
55_L = TypeVar("_L")
56_A = TypeVar("_A")
57_O = TypeVar("_O")
class PredicateLeafBase(QueryTreeBase, ABC):
    """Base class for leaf nodes of the `Predicate` tree.

    This is a closed hierarchy whose concrete, `~typing.final` derived classes
    are members of the `PredicateLeaf` union. That union should generally
    be used in type annotations rather than the technically-open base class.
    """

    @abstractmethod
    def gather_required_columns(self, columns: ColumnSet) -> None:
        """Add any columns required to evaluate this predicate leaf to the
        given column set.

        Parameters
        ----------
        columns : `ColumnSet`
            Set of columns to modify in place.
        """
        raise NotImplementedError()

    def invert(self) -> PredicateLeaf:
        """Return a new leaf that is the logical not of this one.

        Returns
        -------
        inverted : `PredicateLeaf`
            A `LogicalNot` wrapping this leaf. `LogicalNot` overrides this
            method to unwrap its operand instead of nesting another NOT.
        """
        # The cast only informs the type checker: every concrete subclass
        # other than `LogicalNot` (which overrides this method) is a member
        # of the `LogicalNotOperand` union.
        return LogicalNot.model_construct(operand=cast("LogicalNotOperand", self))

    @abstractmethod
    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        """Invoke the visitor interface.

        Parameters
        ----------
        visitor : `PredicateVisitor`
            Visitor to invoke a method on.
        flags : `PredicateVisitFlags`
            Flags that provide information about where this leaf appears in the
            larger predicate tree.

        Returns
        -------
        result : `object`
            Forwarded result from the visitor.
        """
        raise NotImplementedError()
104@final
105class Predicate(QueryTreeBase):
106 """A boolean column expression.
108 Notes
109 -----
110 Predicate is the only class representing a boolean column expression that
111 should be used outside of this module (though the objects it nests appear
112 in its serialized form and hence are not fully private). It provides
113 several `classmethod` factories for constructing those nested types inside
114 a `Predicate` instance, and `PredicateVisitor` subclasses should be used
115 to process them.
116 """
118 operands: PredicateOperands
119 """Nested tuple of operands, with outer items combined via AND and inner
120 items combined via OR.
121 """
123 @property
124 def column_type(self) -> Literal["bool"]:
125 """A string enumeration value representing the type of the column
126 expression.
127 """
128 return "bool"
130 @classmethod
131 def from_bool(cls, value: bool) -> Predicate:
132 """Construct a predicate that always evaluates to `True` or `False`.
134 Parameters
135 ----------
136 value : `bool`
137 Value the predicate should evaluate to.
139 Returns
140 -------
141 predicate : `Predicate`
142 Predicate that evaluates to the given boolean value.
143 """
144 return cls.model_construct(operands=() if value else ((),))
146 @classmethod
147 def compare(cls, a: ColumnExpression, operator: ComparisonOperator, b: ColumnExpression) -> Predicate:
148 """Construct a predicate representing a binary comparison between
149 two non-boolean column expressions.
151 Parameters
152 ----------
153 a : `ColumnExpression`
154 First column expression in the comparison.
155 operator : `str`
156 Enumerated string representing the comparison operator to apply.
157 May be and of "==", "!=", "<", ">", "<=", ">=", or "overlaps".
158 b : `ColumnExpression`
159 Second column expression in the comparison.
161 Returns
162 -------
163 predicate : `Predicate`
164 Predicate representing the comparison.
165 """
166 return cls._from_leaf(Comparison(a=a, operator=operator, b=b))
168 @classmethod
169 def is_null(cls, operand: ColumnExpression) -> Predicate:
170 """Construct a predicate that tests whether a column expression is
171 NULL.
173 Parameters
174 ----------
175 operand : `ColumnExpression`
176 Column expression to test.
178 Returns
179 -------
180 predicate : `Predicate`
181 Predicate representing the NULL check.
182 """
183 return cls._from_leaf(IsNull(operand=operand))
185 @classmethod
186 def in_container(cls, member: ColumnExpression, container: Iterable[ColumnExpression]) -> Predicate:
187 """Construct a predicate that tests whether one column expression is
188 a member of a container of other column expressions.
190 Parameters
191 ----------
192 member : `ColumnExpression`
193 Column expression that may be a member of the container.
194 container : `~collections.abc.Iterable` [ `ColumnExpression` ]
195 Container of column expressions to test for membership in.
197 Returns
198 -------
199 predicate : `Predicate`
200 Predicate representing the membership test.
201 """
202 return cls._from_leaf(InContainer(member=member, container=tuple(container)))
204 @classmethod
205 def in_range(
206 cls, member: ColumnExpression, start: int = 0, stop: int | None = None, step: int = 1
207 ) -> Predicate:
208 """Construct a predicate that tests whether an integer column
209 expression is part of a strided range.
211 Parameters
212 ----------
213 member : `ColumnExpression`
214 Column expression that may be a member of the range.
215 start : `int`, optional
216 Beginning of the range, inclusive.
217 stop : `int` or `None`, optional
218 End of the range, exclusive.
219 step : `int`, optional
220 Offset between values in the range.
222 Returns
223 -------
224 predicate : `Predicate`
225 Predicate representing the membership test.
226 """
227 return cls._from_leaf(InRange(member=member, start=start, stop=stop, step=step))
229 @classmethod
230 def in_query(cls, member: ColumnExpression, column: ColumnExpression, query_tree: QueryTree) -> Predicate:
231 """Construct a predicate that tests whether a column expression is
232 present in a single-column projection of a query tree.
234 Parameters
235 ----------
236 member : `ColumnExpression`
237 Column expression that may be present in the query.
238 column : `ColumnExpression`
239 Column to project from the query.
240 query_tree : `QueryTree`
241 Query tree to select from.
243 Returns
244 -------
245 predicate : `Predicate`
246 Predicate representing the membership test.
247 """
248 return cls._from_leaf(InQuery(member=member, column=column, query_tree=query_tree))
250 def gather_required_columns(self, columns: ColumnSet) -> None:
251 """Add any columns required to evaluate this predicate to the given
252 column set.
254 Parameters
255 ----------
256 columns : `ColumnSet`
257 Set of columns to modify in place.
258 """
259 for or_group in self.operands:
260 for operand in or_group:
261 operand.gather_required_columns(columns)
263 def logical_and(self, *args: Predicate) -> Predicate:
264 """Construct a predicate representing the logical AND of this predicate
265 and one or more others.
267 Parameters
268 ----------
269 *args : `Predicate`
270 Other predicates.
272 Returns
273 -------
274 predicate : `Predicate`
275 Predicate representing the logical AND.
276 """
277 operands = self.operands
278 for arg in args:
279 operands = self._impl_and(operands, arg.operands)
280 if not all(operands):
281 # If any item in operands is an empty tuple (i.e. False), simplify.
282 operands = ((),)
283 return Predicate.model_construct(operands=operands)
285 def logical_or(self, *args: Predicate) -> Predicate:
286 """Construct a predicate representing the logical OR of this predicate
287 and one or more others.
289 Parameters
290 ----------
291 *args : `Predicate`
292 Other predicates.
294 Returns
295 -------
296 predicate : `Predicate`
297 Predicate representing the logical OR.
298 """
299 operands = self.operands
300 for arg in args:
301 operands = self._impl_or(operands, arg.operands)
302 return Predicate.model_construct(operands=operands)
304 def logical_not(self) -> Predicate:
305 """Construct a predicate representing the logical NOT of this
306 predicate.
308 Returns
309 -------
310 predicate : `Predicate`
311 Predicate representing the logical NOT.
312 """
313 new_operands: PredicateOperands = ((),)
314 for or_group in self.operands:
315 new_group: PredicateOperands = ()
316 for leaf in or_group:
317 new_group = self._impl_and(new_group, ((leaf.invert(),),))
318 new_operands = self._impl_or(new_operands, new_group)
319 return Predicate.model_construct(operands=new_operands)
321 def __str__(self) -> str:
322 and_terms = []
323 for or_group in self.operands:
324 match len(or_group):
325 case 0:
326 and_terms.append("False")
327 case 1:
328 and_terms.append(str(or_group[0]))
329 case _:
330 or_str = " OR ".join(str(operand) for operand in or_group)
331 if len(self.operands) > 1:
332 and_terms.append(f"({or_str})")
333 else:
334 and_terms.append(or_str)
335 if not and_terms:
336 return "True"
337 return " AND ".join(and_terms)
339 def visit(self, visitor: PredicateVisitor[_A, _O, _L]) -> _A:
340 """Invoke the visitor interface.
342 Parameters
343 ----------
344 visitor : `PredicateVisitor`
345 Visitor to invoke a method on.
347 Returns
348 -------
349 result : `object`
350 Forwarded result from the visitor.
351 """
352 return visitor._visit_logical_and(self.operands)
354 @classmethod
355 def _from_leaf(cls, leaf: PredicateLeaf) -> Predicate:
356 return cls._from_or_group((leaf,))
358 @classmethod
359 def _from_or_group(cls, or_group: tuple[PredicateLeaf, ...]) -> Predicate:
360 return Predicate.model_construct(operands=(or_group,))
362 @classmethod
363 def _impl_and(cls, a: PredicateOperands, b: PredicateOperands) -> PredicateOperands:
364 # We could simplify cases where both sides have some of the same leaf
365 # expressions; even using 'is' tests would simplify some cases where
366 # converting to conjunctive normal form twice leads to a lot of
367 # duplication, e.g. NOT ((A AND B) OR (C AND D)) or any kind of
368 # double-negation. Right now those cases seem pathological enough to
369 # be not worth our time.
370 return a + b if a is not b else a
372 @classmethod
373 def _impl_or(cls, a: PredicateOperands, b: PredicateOperands) -> PredicateOperands:
374 # Same comment re simplification as in _impl_and applies here.
375 return tuple([a_operand + b_operand for a_operand, b_operand in itertools.product(a, b)])
@final
class LogicalNot(PredicateLeafBase):
    """A boolean column expression that inverts its operand."""

    predicate_type: Literal["not"] = "not"

    operand: LogicalNotOperand
    """Upstream boolean expression to invert."""

    def gather_required_columns(self, columns: ColumnSet) -> None:
        # Docstring inherited.
        self.operand.gather_required_columns(columns)

    def __str__(self) -> str:
        return f"NOT {self.operand}"

    def invert(self) -> LogicalNotOperand:
        # Docstring inherited.
        # Negating a NOT just unwraps the operand rather than nesting
        # another LogicalNot.
        return self.operand

    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        # Docstring inherited.
        return visitor._visit_logical_not(self.operand, flags)
@final
class IsNull(PredicateLeafBase):
    """A boolean column expression that tests whether its operand is NULL."""

    predicate_type: Literal["is_null"] = "is_null"

    operand: ColumnExpression
    """Upstream expression to test."""

    def gather_required_columns(self, columns: ColumnSet) -> None:
        # Docstring inherited.
        self.operand.gather_required_columns(columns)

    def __str__(self) -> str:
        return f"{self.operand} IS NULL"

    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        # Docstring inherited.
        return visitor.visit_is_null(self.operand, flags)
424@final
425class Comparison(PredicateLeafBase):
426 """A boolean columns expression formed by comparing two non-boolean
427 expressions.
428 """
430 predicate_type: Literal["comparison"] = "comparison"
432 a: ColumnExpression
433 """Left-hand side expression for the comparison."""
435 b: ColumnExpression
436 """Right-hand side expression for the comparison."""
438 operator: ComparisonOperator
439 """Comparison operator."""
441 def gather_required_columns(self, columns: ColumnSet) -> None:
442 # Docstring inherited.
443 self.a.gather_required_columns(columns)
444 self.b.gather_required_columns(columns)
446 def __str__(self) -> str:
447 return f"{self.a} {self.operator.upper()} {self.b}"
449 def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
450 # Docstring inherited.
451 return visitor.visit_comparison(self.a, self.operator, self.b, flags)
453 @pydantic.model_validator(mode="after")
454 def _validate_column_types(self) -> Comparison:
455 if self.a.column_type != self.b.column_type:
456 raise InvalidQueryError(
457 f"Column types for comparison {self} do not agree "
458 f"({self.a.column_type}, {self.b.column_type})."
459 )
460 match (self.operator, self.a.column_type):
461 case ("==" | "!=", _):
462 pass
463 case ("<" | ">" | ">=" | "<=", "int" | "string" | "float" | "datetime"):
464 pass
465 case ("overlaps", "region" | "timespan"):
466 pass
467 case _:
468 raise InvalidQueryError(
469 f"Invalid column type {self.a.column_type} for operator {self.operator!r}."
470 )
471 return self
@final
class InContainer(PredicateLeafBase):
    """A boolean column expression that tests whether one expression is a
    member of an explicit sequence of other expressions.
    """

    predicate_type: Literal["in_container"] = "in_container"

    member: ColumnExpression
    """Expression to test for membership."""

    container: tuple[ColumnExpression, ...]
    """Expressions representing the elements of the container."""

    def gather_required_columns(self, columns: ColumnSet) -> None:
        # Docstring inherited.
        self.member.gather_required_columns(columns)
        for item in self.container:
            item.gather_required_columns(columns)

    def __str__(self) -> str:
        return f"{self.member} IN [{', '.join(str(item) for item in self.container)}]"

    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        # Docstring inherited.
        return visitor.visit_in_container(self.member, self.container, flags)

    @pydantic.model_validator(mode="after")
    def _validate(self) -> InContainer:
        # Reject timespan/region members, then require every container item
        # to share the member's column type. An empty container passes.
        if self.member.column_type == "timespan" or self.member.column_type == "region":
            raise InvalidQueryError(
                f"Timespan or region column {self.member} may not be used in IN expressions."
            )
        if not all(item.column_type == self.member.column_type for item in self.container):
            raise InvalidQueryError(f"Column types for membership test {self} do not agree.")
        return self
@final
class InRange(PredicateLeafBase):
    """A boolean column expression that tests whether its expression is
    included in an integer range.
    """

    predicate_type: Literal["in_range"] = "in_range"

    member: ColumnExpression
    """Expression to test for membership."""

    start: int = 0
    """Inclusive lower bound for the range."""

    stop: int | None = None
    """Exclusive upper bound for the range."""

    step: int = 1
    """Difference between values in the range."""

    def gather_required_columns(self, columns: ColumnSet) -> None:
        # Docstring inherited.
        self.member.gather_required_columns(columns)

    def __str__(self) -> str:
        # Slice-like rendering: a zero start and a missing stop are left
        # blank, and the step is only shown when it is not 1.
        s = f"{self.start if self.start else ''}:{self.stop if self.stop is not None else ''}"
        if self.step != 1:
            s = f"{s}:{self.step}"
        return f"{self.member} IN {s}"

    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        # Docstring inherited.
        return visitor.visit_in_range(self.member, self.start, self.stop, self.step, flags)

    @pydantic.model_validator(mode="after")
    def _validate(self) -> InRange:
        # The member must be an integer column, the step positive, and any
        # explicit stop no smaller than start.
        if self.member.column_type != "int":
            raise InvalidQueryError(f"Column {self.member} is not an integer.")
        if self.step < 1:
            raise InvalidQueryError("Range step must be >= 1.")
        if self.stop is not None and self.stop < self.start:
            raise InvalidQueryError("Range stop must be >= start.")
        return self
@final
class InQuery(PredicateLeafBase):
    """A boolean column expression that tests whether its expression is
    included in a single-column projection of a relation.

    This is primarily intended to be used on dataset ID columns, but it may
    be useful for other columns as well.
    """

    predicate_type: Literal["in_query"] = "in_query"

    member: ColumnExpression
    """Expression to test for membership."""

    column: ColumnExpression
    """Expression to extract from `query_tree`."""

    query_tree: QueryTree
    """Relation whose rows from `column` represent the container."""

    def gather_required_columns(self, columns: ColumnSet) -> None:
        # Docstring inherited.
        # We're only gathering columns from the query_tree this predicate is
        # attached to, not `self.column`, which belongs to `self.query_tree`.
        self.member.gather_required_columns(columns)

    def __str__(self) -> str:
        return f"{self.member} IN (query).{self.column}"

    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        # Docstring inherited.
        return visitor.visit_in_query_tree(self.member, self.column, self.query_tree, flags)

    @pydantic.model_validator(mode="after")
    def _validate_column_types(self) -> InQuery:
        # Check the column types first, then check that `column` can actually
        # be provided by `query_tree`.
        if self.member.column_type == "timespan" or self.member.column_type == "region":
            raise InvalidQueryError(
                f"Timespan or region column {self.member} may not be used in IN expressions."
            )
        if self.member.column_type != self.column.column_type:
            raise InvalidQueryError(
                f"Column types for membership test {self} do not agree "
                f"({self.member.column_type}, {self.column.column_type})."
            )

        # At module level ColumnSet is imported only under TYPE_CHECKING, so
        # it has to be imported here for runtime use.
        # NOTE(review): presumably deferred to avoid a circular import;
        # confirm.
        from ._column_set import ColumnSet

        # Seed a column set with the tree's dimensions and let `column` add
        # what it needs; any change in dimensions means `column` needs
        # something the tree does not have.
        columns_required_in_tree = ColumnSet(self.query_tree.dimensions)
        self.column.gather_required_columns(columns_required_in_tree)
        if columns_required_in_tree.dimensions != self.query_tree.dimensions:
            raise InvalidQueryError(
                f"Column {self.column} requires dimensions {columns_required_in_tree.dimensions}, "
                f"but query tree only has {self.query_tree.dimensions}."
            )
        if not columns_required_in_tree.dataset_fields.keys() <= self.query_tree.datasets.keys():
            raise InvalidQueryError(
                f"Column {self.column} requires dataset types "
                f"{set(columns_required_in_tree.dataset_fields.keys())} that are not present in query tree."
            )
        return self
# Every leaf type that may appear inside a `LogicalNot`: all concrete leaves
# except `LogicalNot` itself.
LogicalNotOperand: TypeAlias = Union[
    IsNull,
    Comparison,
    InContainer,
    InRange,
    InQuery,
]

# Any leaf of a `Predicate` tree; pydantic selects the concrete class via
# each leaf's ``predicate_type`` field.
PredicateLeaf: TypeAlias = Annotated[
    Union[LogicalNotOperand, LogicalNot], pydantic.Field(discriminator="predicate_type")
]

# Outer tuple items are combined with AND, inner tuple items with OR.
PredicateOperands: TypeAlias = tuple[tuple[PredicateLeaf, ...], ...]