Coverage for python/lsst/daf/butler/queries/tree/_predicate.py: 51%
236 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-18 09:55 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-18 09:55 +0000
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
28from __future__ import annotations
# Public names of this module. The concrete leaf classes (`Comparison`,
# `IsNull`, ...) are not exported: per the `Predicate` docstring, `Predicate`
# is the class that should be used outside this module.
__all__ = (
    "Predicate",
    "PredicateLeaf",
    "LogicalNotOperand",
    "PredicateOperands",
    "ComparisonOperator",
)
38import itertools
39from abc import ABC, abstractmethod
40from collections.abc import Iterable
41from typing import TYPE_CHECKING, Annotated, Literal, TypeAlias, TypeVar, cast, final
43import pydantic
45from ._base import InvalidQueryError, QueryTreeBase
46from ._column_expression import ColumnExpression
48if TYPE_CHECKING:
49 from ..visitors import PredicateVisitFlags, PredicateVisitor
50 from ._column_set import ColumnSet
51 from ._query_tree import QueryTree
# Operators accepted by `Comparison` (and `Predicate.compare`). Which column
# types each operator is valid for is checked in
# `Comparison._validate_column_types`.
ComparisonOperator: TypeAlias = Literal["==", "!=", "<", ">", ">=", "<=", "overlaps"]


# Type variables used to parametrize `PredicateVisitor`. `Predicate.visit`
# returns `_A` (the result of ``_visit_logical_and``) and the leaf ``visit``
# methods return `_L`. `_O` is presumably the result type for OR groups;
# NOTE(review): confirm against `PredicateVisitor`.
_L = TypeVar("_L")
_A = TypeVar("_A")
_O = TypeVar("_O")
class PredicateLeafBase(QueryTreeBase, ABC):
    """Abstract base for the leaf nodes of a `Predicate` tree.

    The hierarchy is closed: every concrete subclass is marked
    `~typing.final` and is a member of the `PredicateLeaf` union. Type
    annotations should normally use that union rather than this
    (technically open) base class.
    """

    @abstractmethod
    def gather_required_columns(self, columns: ColumnSet) -> None:
        """Record in ``columns`` every column needed to evaluate this leaf.

        Parameters
        ----------
        columns : `ColumnSet`
            Column set that is updated in place.
        """
        raise NotImplementedError

    def invert(self) -> PredicateLeaf:
        """Return a new leaf representing the logical NOT of this one.

        The default implementation wraps ``self`` in a `LogicalNot` using
        ``model_construct``, i.e. without re-running validation.
        """
        negated_operand = cast("LogicalNotOperand", self)
        return LogicalNot.model_construct(operand=negated_operand)

    @abstractmethod
    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        """Dispatch to the matching method of a visitor.

        Parameters
        ----------
        visitor : `PredicateVisitor`
            Visitor whose method is called for this leaf.
        flags : `PredicateVisitFlags`
            Flags describing where this leaf sits within the enclosing
            predicate tree.

        Returns
        -------
        result : `object`
            Whatever the visitor method returned.
        """
        raise NotImplementedError
105@final
106class Predicate(QueryTreeBase):
107 """A boolean column expression.
109 Notes
110 -----
111 Predicate is the only class representing a boolean column expression that
112 should be used outside of this module (though the objects it nests appear
113 in its serialized form and hence are not fully private). It provides
114 several `classmethod` factories for constructing those nested types inside
115 a `Predicate` instance, and `PredicateVisitor` subclasses should be used
116 to process them.
117 """
119 operands: PredicateOperands
120 """Nested tuple of operands, with outer items combined via AND and inner
121 items combined via OR.
122 """
124 @property
125 def column_type(self) -> Literal["bool"]:
126 """A string enumeration value representing the type of the column
127 expression.
128 """
129 return "bool"
131 @classmethod
132 def from_bool(cls, value: bool) -> Predicate:
133 """Construct a predicate that always evaluates to `True` or `False`.
135 Parameters
136 ----------
137 value : `bool`
138 Value the predicate should evaluate to.
140 Returns
141 -------
142 predicate : `Predicate`
143 Predicate that evaluates to the given boolean value.
144 """
145 # The values for True and False here make sense if you think about
146 # calling `all` and `any` with empty sequences; note that the
147 # `self.operands` attribute is evaluated as:
148 #
149 # value = all(any(or_group) for or_group in self.operands)
150 #
151 return cls.model_construct(operands=() if value else ((),))
153 @classmethod
154 def compare(cls, a: ColumnExpression, operator: ComparisonOperator, b: ColumnExpression) -> Predicate:
155 """Construct a predicate representing a binary comparison between
156 two non-boolean column expressions.
158 Parameters
159 ----------
160 a : `ColumnExpression`
161 First column expression in the comparison.
162 operator : `str`
163 Enumerated string representing the comparison operator to apply.
164 May be and of "==", "!=", "<", ">", "<=", ">=", or "overlaps".
165 b : `ColumnExpression`
166 Second column expression in the comparison.
168 Returns
169 -------
170 predicate : `Predicate`
171 Predicate representing the comparison.
172 """
173 return cls._from_leaf(Comparison(a=a, operator=operator, b=b))
175 @classmethod
176 def is_null(cls, operand: ColumnExpression) -> Predicate:
177 """Construct a predicate that tests whether a column expression is
178 NULL.
180 Parameters
181 ----------
182 operand : `ColumnExpression`
183 Column expression to test.
185 Returns
186 -------
187 predicate : `Predicate`
188 Predicate representing the NULL check.
189 """
190 return cls._from_leaf(IsNull(operand=operand))
192 @classmethod
193 def in_container(cls, member: ColumnExpression, container: Iterable[ColumnExpression]) -> Predicate:
194 """Construct a predicate that tests whether one column expression is
195 a member of a container of other column expressions.
197 Parameters
198 ----------
199 member : `ColumnExpression`
200 Column expression that may be a member of the container.
201 container : `~collections.abc.Iterable` [ `ColumnExpression` ]
202 Container of column expressions to test for membership in.
204 Returns
205 -------
206 predicate : `Predicate`
207 Predicate representing the membership test.
208 """
209 return cls._from_leaf(InContainer(member=member, container=tuple(container)))
211 @classmethod
212 def in_range(
213 cls, member: ColumnExpression, start: int = 0, stop: int | None = None, step: int = 1
214 ) -> Predicate:
215 """Construct a predicate that tests whether an integer column
216 expression is part of a strided range.
218 Parameters
219 ----------
220 member : `ColumnExpression`
221 Column expression that may be a member of the range.
222 start : `int`, optional
223 Beginning of the range, inclusive.
224 stop : `int` or `None`, optional
225 End of the range, exclusive.
226 step : `int`, optional
227 Offset between values in the range.
229 Returns
230 -------
231 predicate : `Predicate`
232 Predicate representing the membership test.
233 """
234 return cls._from_leaf(InRange(member=member, start=start, stop=stop, step=step))
236 @classmethod
237 def in_query(cls, member: ColumnExpression, column: ColumnExpression, query_tree: QueryTree) -> Predicate:
238 """Construct a predicate that tests whether a column expression is
239 present in a single-column projection of a query tree.
241 Parameters
242 ----------
243 member : `ColumnExpression`
244 Column expression that may be present in the query.
245 column : `ColumnExpression`
246 Column to project from the query.
247 query_tree : `QueryTree`
248 Query tree to select from.
250 Returns
251 -------
252 predicate : `Predicate`
253 Predicate representing the membership test.
254 """
255 return cls._from_leaf(InQuery(member=member, column=column, query_tree=query_tree))
257 def gather_required_columns(self, columns: ColumnSet) -> None:
258 """Add any columns required to evaluate this predicate to the given
259 column set.
261 Parameters
262 ----------
263 columns : `ColumnSet`
264 Set of columns to modify in place.
265 """
266 for or_group in self.operands:
267 for operand in or_group:
268 operand.gather_required_columns(columns)
270 def logical_and(self, *args: Predicate) -> Predicate:
271 """Construct a predicate representing the logical AND of this predicate
272 and one or more others.
274 Parameters
275 ----------
276 *args : `Predicate`
277 Other predicates.
279 Returns
280 -------
281 predicate : `Predicate`
282 Predicate representing the logical AND.
283 """
284 operands = self.operands
285 for arg in args:
286 operands = self._impl_and(operands, arg.operands)
287 if not all(operands):
288 # If any item in operands is an empty tuple (i.e. False), simplify.
289 operands = ((),)
290 return Predicate.model_construct(operands=operands)
292 def logical_or(self, *args: Predicate) -> Predicate:
293 """Construct a predicate representing the logical OR of this predicate
294 and one or more others.
296 Parameters
297 ----------
298 *args : `Predicate`
299 Other predicates.
301 Returns
302 -------
303 predicate : `Predicate`
304 Predicate representing the logical OR.
305 """
306 operands = self.operands
307 for arg in args:
308 operands = self._impl_or(operands, arg.operands)
309 return Predicate.model_construct(operands=operands)
311 def logical_not(self) -> Predicate:
312 """Construct a predicate representing the logical NOT of this
313 predicate.
315 Returns
316 -------
317 predicate : `Predicate`
318 Predicate representing the logical NOT.
319 """
320 new_operands: PredicateOperands = ((),)
321 for or_group in self.operands:
322 new_group: PredicateOperands = ()
323 for leaf in or_group:
324 new_group = self._impl_and(new_group, ((leaf.invert(),),))
325 new_operands = self._impl_or(new_operands, new_group)
326 return Predicate.model_construct(operands=new_operands)
328 def __str__(self) -> str:
329 and_terms = []
330 for or_group in self.operands:
331 match len(or_group):
332 case 0:
333 and_terms.append("False")
334 case 1:
335 and_terms.append(str(or_group[0]))
336 case _:
337 or_str = " OR ".join(str(operand) for operand in or_group)
338 if len(self.operands) > 1:
339 and_terms.append(f"({or_str})")
340 else:
341 and_terms.append(or_str)
342 if not and_terms:
343 return "True"
344 return " AND ".join(and_terms)
346 def visit(self, visitor: PredicateVisitor[_A, _O, _L]) -> _A:
347 """Invoke the visitor interface.
349 Parameters
350 ----------
351 visitor : `PredicateVisitor`
352 Visitor to invoke a method on.
354 Returns
355 -------
356 result : `object`
357 Forwarded result from the visitor.
358 """
359 return visitor._visit_logical_and(self.operands)
361 @classmethod
362 def _from_leaf(cls, leaf: PredicateLeaf) -> Predicate:
363 return cls._from_or_group((leaf,))
365 @classmethod
366 def _from_or_group(cls, or_group: tuple[PredicateLeaf, ...]) -> Predicate:
367 return Predicate.model_construct(operands=(or_group,))
369 @classmethod
370 def _impl_and(cls, a: PredicateOperands, b: PredicateOperands) -> PredicateOperands:
371 # We could simplify cases where both sides have some of the same leaf
372 # expressions; even using 'is' tests would simplify some cases where
373 # converting to conjunctive normal form twice leads to a lot of
374 # duplication, e.g. NOT ((A AND B) OR (C AND D)) or any kind of
375 # double-negation. Right now those cases seem pathological enough to
376 # be not worth our time.
377 return a + b if a is not b else a
379 @classmethod
380 def _impl_or(cls, a: PredicateOperands, b: PredicateOperands) -> PredicateOperands:
381 # Same comment re simplification as in _impl_and applies here.
382 return tuple([a_operand + b_operand for a_operand, b_operand in itertools.product(a, b)])
@final
class LogicalNot(PredicateLeafBase):
    """Leaf node that negates a single non-`LogicalNot` leaf."""

    predicate_type: Literal["not"] = "not"

    operand: LogicalNotOperand
    """Upstream boolean expression to invert."""

    def gather_required_columns(self, columns: ColumnSet) -> None:
        # Docstring inherited.
        # A negation needs exactly the columns its operand needs.
        inner = self.operand
        inner.gather_required_columns(columns)

    def __str__(self) -> str:
        return f"NOT {self.operand}"

    def invert(self) -> LogicalNotOperand:
        # Docstring inherited.
        # Double negation: unwrap rather than nesting another LogicalNot.
        inner = self.operand
        return inner

    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        # Docstring inherited.
        inner = self.operand
        return visitor._visit_logical_not(inner, flags)
@final
class IsNull(PredicateLeafBase):
    """Leaf node that is true when its operand evaluates to NULL."""

    predicate_type: Literal["is_null"] = "is_null"

    operand: ColumnExpression
    """Upstream expression to test."""

    def gather_required_columns(self, columns: ColumnSet) -> None:
        # Docstring inherited.
        tested = self.operand
        tested.gather_required_columns(columns)

    def __str__(self) -> str:
        tested = self.operand
        return f"{tested} IS NULL"

    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        # Docstring inherited.
        tested = self.operand
        return visitor.visit_is_null(tested, flags)
431@final
432class Comparison(PredicateLeafBase):
433 """A boolean columns expression formed by comparing two non-boolean
434 expressions.
435 """
437 predicate_type: Literal["comparison"] = "comparison"
439 a: ColumnExpression
440 """Left-hand side expression for the comparison."""
442 b: ColumnExpression
443 """Right-hand side expression for the comparison."""
445 operator: ComparisonOperator
446 """Comparison operator."""
448 def gather_required_columns(self, columns: ColumnSet) -> None:
449 # Docstring inherited.
450 self.a.gather_required_columns(columns)
451 self.b.gather_required_columns(columns)
453 def __str__(self) -> str:
454 return f"{self.a} {self.operator.upper()} {self.b}"
456 def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
457 # Docstring inherited.
458 return visitor.visit_comparison(self.a, self.operator, self.b, flags)
460 @pydantic.model_validator(mode="after")
461 def _validate_column_types(self) -> Comparison:
462 if self.a.column_type != self.b.column_type:
463 raise InvalidQueryError(
464 f"Column types for comparison {self} do not agree "
465 f"({self.a.column_type}, {self.b.column_type})."
466 )
467 match (self.operator, self.a.column_type):
468 case ("==" | "!=", _):
469 pass
470 case ("<" | ">" | ">=" | "<=", "int" | "string" | "float" | "datetime"):
471 pass
472 case ("overlaps", "region" | "timespan"):
473 pass
474 case _:
475 raise InvalidQueryError(
476 f"Invalid column type {self.a.column_type} for operator {self.operator!r}."
477 )
478 return self
@final
class InContainer(PredicateLeafBase):
    """Leaf node testing whether an expression equals any element of an
    explicit sequence of other expressions.
    """

    predicate_type: Literal["in_container"] = "in_container"

    member: ColumnExpression
    """Expression to test for membership."""

    container: tuple[ColumnExpression, ...]
    """Expressions representing the elements of the container."""

    def gather_required_columns(self, columns: ColumnSet) -> None:
        # Docstring inherited.
        # The member first, then each container element, in order.
        for expression in (self.member, *self.container):
            expression.gather_required_columns(columns)

    def __str__(self) -> str:
        rendered_items = ", ".join(str(item) for item in self.container)
        return f"{self.member} IN [{rendered_items}]"

    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        # Docstring inherited.
        return visitor.visit_in_container(self.member, self.container, flags)

    @pydantic.model_validator(mode="after")
    def _validate(self) -> InContainer:
        # Spatial/temporal types cannot be used in IN, and every element
        # must share the member's column type.
        member_type = self.member.column_type
        if member_type in ("timespan", "region"):
            raise InvalidQueryError(
                f"Timespan or region column {self.member} may not be used in IN expressions."
            )
        if any(item.column_type != member_type for item in self.container):
            raise InvalidQueryError(f"Column types for membership test {self} do not agree.")
        return self
@final
class InRange(PredicateLeafBase):
    """Leaf node testing whether an integer expression lies in a strided
    ``start:stop:step`` range.
    """

    predicate_type: Literal["in_range"] = "in_range"

    member: ColumnExpression
    """Expression to test for membership."""

    start: int = 0
    """Inclusive lower bound for the range."""

    stop: int | None = None
    """Exclusive upper bound for the range."""

    step: int = 1
    """Difference between values in the range."""

    def gather_required_columns(self, columns: ColumnSet) -> None:
        # Docstring inherited.
        tested = self.member
        tested.gather_required_columns(columns)

    def __str__(self) -> str:
        # Rendered slice-style: a zero start, a missing stop and a unit step
        # are all omitted.
        start_text = str(self.start) if self.start else ""
        stop_text = "" if self.stop is None else str(self.stop)
        range_text = f"{start_text}:{stop_text}"
        if self.step != 1:
            range_text += f":{self.step}"
        return f"{self.member} IN {range_text}"

    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        # Docstring inherited.
        return visitor.visit_in_range(self.member, self.start, self.stop, self.step, flags)

    @pydantic.model_validator(mode="after")
    def _validate(self) -> InRange:
        if self.member.column_type != "int":
            raise InvalidQueryError(f"Column {self.member} is not an integer.")
        if not self.step >= 1:
            raise InvalidQueryError("Range step must be >= 1.")
        bounded = self.stop is not None
        if bounded and self.stop < self.start:
            raise InvalidQueryError("Range stop must be >= start.")
        return self
@final
class InQuery(PredicateLeafBase):
    """A boolean column expression that tests whether its expression is
    included in a single-column projection of a relation.

    This is primarily intended to be used on dataset ID columns, but it may
    be useful for other columns as well.
    """

    predicate_type: Literal["in_query"] = "in_query"

    member: ColumnExpression
    """Expression to test for membership."""

    column: ColumnExpression
    """Expression to extract from `query_tree`."""

    query_tree: QueryTree
    """Relation whose rows from `column` represent the container."""

    def gather_required_columns(self, columns: ColumnSet) -> None:
        # Docstring inherited.
        # We're only gathering columns from the query_tree this predicate is
        # attached to, not `self.column`, which belongs to `self.query_tree`.
        self.member.gather_required_columns(columns)

    def __str__(self) -> str:
        return f"{self.member} IN (query).{self.column}"

    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        # Docstring inherited.
        return visitor.visit_in_query_tree(self.member, self.column, self.query_tree, flags)

    @pydantic.model_validator(mode="after")
    def _validate_column_types(self) -> InQuery:
        """Check that ``member`` and ``column`` are compatible and that
        ``column`` can be evaluated against ``query_tree``.

        Raises
        ------
        InvalidQueryError
            Raised if ``member`` is a timespan or region, if the two column
            types disagree, or if ``column`` needs dimensions or dataset
            types that ``query_tree`` does not provide.
        """
        if self.member.column_type in ("timespan", "region"):
            raise InvalidQueryError(
                f"Timespan or region column {self.member} may not be used in IN expressions."
            )
        if self.member.column_type != self.column.column_type:
            raise InvalidQueryError(
                f"Column types for membership test {self} do not agree "
                f"({self.member.column_type}, {self.column.column_type})."
            )

        # Deferred import: at module scope `ColumnSet` is only imported under
        # TYPE_CHECKING (presumably to avoid an import cycle).
        from ._column_set import ColumnSet

        # Gather what `column` needs, starting from the tree's own dimensions;
        # anything beyond them (or any extra dataset type) is unavailable.
        columns_required_in_tree = ColumnSet(self.query_tree.dimensions)
        self.column.gather_required_columns(columns_required_in_tree)
        if columns_required_in_tree.dimensions != self.query_tree.dimensions:
            raise InvalidQueryError(
                f"Column {self.column} requires dimensions {columns_required_in_tree.dimensions}, "
                f"but query tree only has {self.query_tree.dimensions}."
            )
        if not columns_required_in_tree.dataset_fields.keys() <= self.query_tree.datasets.keys():
            raise InvalidQueryError(
                f"Column {self.column} requires dataset types "
                f"{set(columns_required_in_tree.dataset_fields.keys())} that are not present in query tree."
            )
        return self
# Leaf types that a `LogicalNot` may wrap. `LogicalNot` itself is not a member,
# so a `LogicalNot` never wraps another one; `LogicalNot.invert` unwraps its
# operand instead.
LogicalNotOperand: TypeAlias = IsNull | Comparison | InContainer | InRange | InQuery
# Discriminated union of every concrete leaf type; pydantic selects the
# concrete class from the `predicate_type` field.
PredicateLeaf: TypeAlias = Annotated[
    LogicalNotOperand | LogicalNot, pydantic.Field(discriminator="predicate_type")
]

# Conjunctive normal form: outer tuple items are combined with AND, inner
# items with OR. ``()`` is True and ``((),)`` is False (see
# `Predicate.from_bool`).
PredicateOperands: TypeAlias = tuple[tuple[PredicateLeaf, ...], ...]