Coverage for python/lsst/daf/butler/queries/tree/_predicate.py: 51%

237 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-30 02:53 -0700

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27 

28from __future__ import annotations 

29 

# Public names of this module.  Only `Predicate` is meant to be used directly
# by callers; the other names are exported because they appear in type
# annotations and in the serialized form of a predicate.
__all__ = (
    "Predicate",
    "PredicateLeaf",
    "LogicalNotOperand",
    "PredicateOperands",
    "ComparisonOperator",
)

37 

38import itertools 

39from abc import ABC, abstractmethod 

40from collections.abc import Iterable 

41from typing import TYPE_CHECKING, Annotated, Literal, TypeAlias, TypeVar, cast, final 

42 

43import pydantic 

44 

45from ..._exceptions import InvalidQueryError 

46from ._base import QueryTreeBase 

47from ._column_expression import ColumnExpression 

48 

49if TYPE_CHECKING: 

50 from ..visitors import PredicateVisitFlags, PredicateVisitor 

51 from ._column_set import ColumnSet 

52 from ._query_tree import QueryTree 

53 

54ComparisonOperator: TypeAlias = Literal["==", "!=", "<", ">", ">=", "<=", "overlaps"] 

55 

56 

57_L = TypeVar("_L") 

58_A = TypeVar("_A") 

59_O = TypeVar("_O") 

60 

61 

class PredicateLeafBase(QueryTreeBase, ABC):
    """Abstract base for the leaf nodes of a `Predicate` tree.

    The hierarchy is closed: its concrete, `~typing.final` subclasses are
    exactly the members of the `PredicateLeaf` union.  Type annotations should
    generally use that union instead of this (technically open) base class.
    """

    @abstractmethod
    def gather_required_columns(self, columns: ColumnSet) -> None:
        """Record the columns needed to evaluate this leaf.

        Parameters
        ----------
        columns : `ColumnSet`
            Set of columns to modify in place.
        """
        raise NotImplementedError()

    def invert(self) -> PredicateLeaf:
        """Return a new leaf that is the logical not of this one."""
        # The cast is only for the type checker: this default implementation
        # is used by the leaf types that may be wrapped in a `LogicalNot`.
        wrapped = cast("LogicalNotOperand", self)
        return LogicalNot.model_construct(operand=wrapped)

    @abstractmethod
    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        """Dispatch to the visitor method that handles this kind of leaf.

        Parameters
        ----------
        visitor : `PredicateVisitor`
            Visitor to invoke a method on.
        flags : `PredicateVisitFlags`
            Flags that provide information about where this leaf appears in
            the larger predicate tree.

        Returns
        -------
        result : `object`
            Forwarded result from the visitor.
        """
        raise NotImplementedError()

104 

105 

@final
class Predicate(QueryTreeBase):
    """A boolean column expression.

    Notes
    -----
    Predicate is the only class representing a boolean column expression that
    should be used outside of this module (though the objects it nests appear
    in its serialized form and hence are not fully private).  Its
    `classmethod` factories build the nested leaf types inside a `Predicate`
    instance, and `PredicateVisitor` subclasses should be used to process
    them.
    """

    operands: PredicateOperands
    """Nested tuple of operands, with outer items combined via AND and inner
    items combined via OR.
    """

    @property
    def column_type(self) -> Literal["bool"]:
        """A string enumeration value representing the type of the column
        expression.
        """
        return "bool"

    @classmethod
    def from_bool(cls, value: bool) -> Predicate:
        """Construct a predicate that always evaluates to `True` or `False`.

        Parameters
        ----------
        value : `bool`
            Value the predicate should evaluate to.

        Returns
        -------
        predicate : `Predicate`
            Predicate that evaluates to the given boolean value.
        """
        # `operands` is interpreted as
        #
        #     all(any(or_group) for or_group in operands)
        #
        # so an empty outer tuple is True (`all` of nothing), while an outer
        # tuple holding one empty OR group is False (`any` of nothing).
        if value:
            return cls.model_construct(operands=())
        return cls.model_construct(operands=((),))

    @classmethod
    def compare(cls, a: ColumnExpression, operator: ComparisonOperator, b: ColumnExpression) -> Predicate:
        """Construct a predicate representing a binary comparison between
        two non-boolean column expressions.

        Parameters
        ----------
        a : `ColumnExpression`
            First column expression in the comparison.
        operator : `str`
            Enumerated string representing the comparison operator to apply.
            May be any of "==", "!=", "<", ">", "<=", ">=", or "overlaps".
        b : `ColumnExpression`
            Second column expression in the comparison.

        Returns
        -------
        predicate : `Predicate`
            Predicate representing the comparison.
        """
        leaf = Comparison(a=a, operator=operator, b=b)
        return cls._from_leaf(leaf)

    @classmethod
    def is_null(cls, operand: ColumnExpression) -> Predicate:
        """Construct a predicate that tests whether a column expression is
        NULL.

        Parameters
        ----------
        operand : `ColumnExpression`
            Column expression to test.

        Returns
        -------
        predicate : `Predicate`
            Predicate representing the NULL check.
        """
        leaf = IsNull(operand=operand)
        return cls._from_leaf(leaf)

    @classmethod
    def in_container(cls, member: ColumnExpression, container: Iterable[ColumnExpression]) -> Predicate:
        """Construct a predicate that tests whether one column expression is
        a member of a container of other column expressions.

        Parameters
        ----------
        member : `ColumnExpression`
            Column expression that may be a member of the container.
        container : `~collections.abc.Iterable` [ `ColumnExpression` ]
            Container of column expressions to test for membership in.

        Returns
        -------
        predicate : `Predicate`
            Predicate representing the membership test.
        """
        # The leaf stores a tuple, so any iterable is materialized here.
        items = tuple(container)
        return cls._from_leaf(InContainer(member=member, container=items))

    @classmethod
    def in_range(
        cls, member: ColumnExpression, start: int = 0, stop: int | None = None, step: int = 1
    ) -> Predicate:
        """Construct a predicate that tests whether an integer column
        expression is part of a strided range.

        Parameters
        ----------
        member : `ColumnExpression`
            Column expression that may be a member of the range.
        start : `int`, optional
            Beginning of the range, inclusive.
        stop : `int` or `None`, optional
            End of the range, exclusive.
        step : `int`, optional
            Offset between values in the range.

        Returns
        -------
        predicate : `Predicate`
            Predicate representing the membership test.
        """
        leaf = InRange(member=member, start=start, stop=stop, step=step)
        return cls._from_leaf(leaf)

    @classmethod
    def in_query(cls, member: ColumnExpression, column: ColumnExpression, query_tree: QueryTree) -> Predicate:
        """Construct a predicate that tests whether a column expression is
        present in a single-column projection of a query tree.

        Parameters
        ----------
        member : `ColumnExpression`
            Column expression that may be present in the query.
        column : `ColumnExpression`
            Column to project from the query.
        query_tree : `QueryTree`
            Query tree to select from.

        Returns
        -------
        predicate : `Predicate`
            Predicate representing the membership test.
        """
        leaf = InQuery(member=member, column=column, query_tree=query_tree)
        return cls._from_leaf(leaf)

    def gather_required_columns(self, columns: ColumnSet) -> None:
        """Add any columns required to evaluate this predicate to the given
        column set.

        Parameters
        ----------
        columns : `ColumnSet`
            Set of columns to modify in place.
        """
        # Visit every leaf of every OR group, in order.
        for leaf in itertools.chain.from_iterable(self.operands):
            leaf.gather_required_columns(columns)

    def logical_and(self, *args: Predicate) -> Predicate:
        """Construct a predicate representing the logical AND of this predicate
        and one or more others.

        Parameters
        ----------
        *args : `Predicate`
            Other predicates.

        Returns
        -------
        predicate : `Predicate`
            Predicate representing the logical AND.
        """
        combined = self.operands
        for other in args:
            combined = self._impl_and(combined, other.operands)
        # An empty OR group is False, and False AND anything is False, so the
        # whole expression collapses to the canonical False form.
        if any(not or_group for or_group in combined):
            combined = ((),)
        return Predicate.model_construct(operands=combined)

    def logical_or(self, *args: Predicate) -> Predicate:
        """Construct a predicate representing the logical OR of this predicate
        and one or more others.

        Parameters
        ----------
        *args : `Predicate`
            Other predicates.

        Returns
        -------
        predicate : `Predicate`
            Predicate representing the logical OR.
        """
        combined = self.operands
        for other in args:
            combined = self._impl_or(combined, other.operands)
        return Predicate.model_construct(operands=combined)

    def logical_not(self) -> Predicate:
        """Construct a predicate representing the logical NOT of this
        predicate.

        Returns
        -------
        predicate : `Predicate`
            Predicate representing the logical NOT.
        """
        # NOT (G1 AND G2 AND ...) == (NOT G1) OR (NOT G2) OR ..., and the
        # negation of an OR group is the AND of its inverted leaves.  Start
        # from False, the identity element for OR.
        negated: PredicateOperands = ((),)
        for or_group in self.operands:
            # Start from True, the identity element for AND.
            negated_group: PredicateOperands = ()
            for leaf in or_group:
                inverted: PredicateOperands = ((leaf.invert(),),)
                negated_group = self._impl_and(negated_group, inverted)
            negated = self._impl_or(negated, negated_group)
        return Predicate.model_construct(operands=negated)

    def __str__(self) -> str:
        # No AND terms at all means the predicate is trivially True.
        if not self.operands:
            return "True"
        # Multi-term OR groups only need parentheses when they are ANDed with
        # something else.
        parenthesize = len(self.operands) > 1
        pieces = []
        for or_group in self.operands:
            if not or_group:
                pieces.append("False")
            elif len(or_group) == 1:
                pieces.append(str(or_group[0]))
            else:
                joined = " OR ".join(map(str, or_group))
                pieces.append(f"({joined})" if parenthesize else joined)
        return " AND ".join(pieces)

    def visit(self, visitor: PredicateVisitor[_A, _O, _L]) -> _A:
        """Invoke the visitor interface.

        Parameters
        ----------
        visitor : `PredicateVisitor`
            Visitor to invoke a method on.

        Returns
        -------
        result : `object`
            Forwarded result from the visitor.
        """
        result = visitor._visit_logical_and(self.operands)
        return result

    @classmethod
    def _from_leaf(cls, leaf: PredicateLeaf) -> Predicate:
        # A single leaf is an OR group of one.
        single_group = (leaf,)
        return cls._from_or_group(single_group)

    @classmethod
    def _from_or_group(cls, or_group: tuple[PredicateLeaf, ...]) -> Predicate:
        # A single OR group is an AND of one.
        and_of_one = (or_group,)
        return Predicate.model_construct(operands=and_of_one)

    @classmethod
    def _impl_and(cls, a: PredicateOperands, b: PredicateOperands) -> PredicateOperands:
        # ANDing an operand tuple with itself changes nothing.
        #
        # More simplification is possible when both sides share leaf
        # expressions; even 'is' tests on leaves would help where converting
        # to conjunctive normal form twice causes heavy duplication, e.g.
        # NOT ((A AND B) OR (C AND D)) or any kind of double-negation.  Those
        # cases currently seem too pathological to be worth the effort.
        if a is b:
            return a
        return a + b

    @classmethod
    def _impl_or(cls, a: PredicateOperands, b: PredicateOperands) -> PredicateOperands:
        # Distribute OR over AND: every OR group on the left is merged with
        # every OR group on the right.  The simplification remarks in
        # _impl_and apply here as well.
        return tuple(left + right for left, right in itertools.product(a, b))

384 

385 

@final
class LogicalNot(PredicateLeafBase):
    """A leaf that is the logical negation of another leaf."""

    predicate_type: Literal["not"] = "not"

    operand: LogicalNotOperand
    """Upstream boolean expression to invert."""

    def gather_required_columns(self, columns: ColumnSet) -> None:
        # Docstring inherited.
        # Negation needs no columns beyond those of the wrapped leaf.
        wrapped = self.operand
        wrapped.gather_required_columns(columns)

    def __str__(self) -> str:
        inner = self.operand
        return f"NOT {inner}"

    def invert(self) -> LogicalNotOperand:
        # Docstring inherited.
        # Inverting a negation unwraps it rather than nesting a second NOT.
        return self.operand

    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        # Docstring inherited.
        result = visitor._visit_logical_not(self.operand, flags)
        return result

409 

410 

@final
class IsNull(PredicateLeafBase):
    """A leaf that is true when its operand expression is NULL."""

    predicate_type: Literal["is_null"] = "is_null"

    operand: ColumnExpression
    """Upstream expression to test."""

    def gather_required_columns(self, columns: ColumnSet) -> None:
        # Docstring inherited.
        tested = self.operand
        tested.gather_required_columns(columns)

    def __str__(self) -> str:
        tested = self.operand
        return f"{tested} IS NULL"

    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        # Docstring inherited.
        result = visitor.visit_is_null(self.operand, flags)
        return result

430 

431 

@final
class Comparison(PredicateLeafBase):
    """A boolean column expression formed by comparing two non-boolean
    expressions.
    """

    predicate_type: Literal["comparison"] = "comparison"

    a: ColumnExpression
    """Left-hand side expression for the comparison."""

    b: ColumnExpression
    """Right-hand side expression for the comparison."""

    operator: ComparisonOperator
    """Comparison operator."""

    def gather_required_columns(self, columns: ColumnSet) -> None:
        # Docstring inherited.
        # Left-hand side first, then right-hand side.
        for side in (self.a, self.b):
            side.gather_required_columns(columns)

    def __str__(self) -> str:
        # Upper-casing only affects the word operator ("overlaps").
        symbol = self.operator.upper()
        return f"{self.a} {symbol} {self.b}"

    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        # Docstring inherited.
        result = visitor.visit_comparison(self.a, self.operator, self.b, flags)
        return result

    @pydantic.model_validator(mode="after")
    def _validate_column_types(self) -> Comparison:
        # Both sides must have the same column type.
        lhs_type = self.a.column_type
        rhs_type = self.b.column_type
        if lhs_type != rhs_type:
            raise InvalidQueryError(
                f"Column types for comparison {self} do not agree "
                f"({lhs_type}, {rhs_type})."
            )
        # Equality works for every type; ordering and overlap tests only
        # accept the column types listed here.
        equality = self.operator in ("==", "!=")
        ordering = self.operator in ("<", ">", ">=", "<=") and lhs_type in (
            "int",
            "string",
            "float",
            "datetime",
        )
        overlap = self.operator == "overlaps" and lhs_type in ("region", "timespan")
        if not (equality or ordering or overlap):
            raise InvalidQueryError(
                f"Invalid column type {lhs_type} for operator {self.operator!r}."
            )
        return self

480 

481 

@final
class InContainer(PredicateLeafBase):
    """A leaf that tests whether one expression equals any element of an
    explicit sequence of other expressions.
    """

    predicate_type: Literal["in_container"] = "in_container"

    member: ColumnExpression
    """Expression to test for membership."""

    container: tuple[ColumnExpression, ...]
    """Expressions representing the elements of the container."""

    def gather_required_columns(self, columns: ColumnSet) -> None:
        # Docstring inherited.
        # The member comes first, followed by each container element.
        self.member.gather_required_columns(columns)
        for element in self.container:
            element.gather_required_columns(columns)

    def __str__(self) -> str:
        elements = ", ".join(map(str, self.container))
        return f"{self.member} IN [{elements}]"

    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        # Docstring inherited.
        result = visitor.visit_in_container(self.member, self.container, flags)
        return result

    @pydantic.model_validator(mode="after")
    def _validate(self) -> InContainer:
        # Timespan and region members are rejected outright.
        if self.member.column_type in ("timespan", "region"):
            raise InvalidQueryError(
                f"Timespan or region column {self.member} may not be used in IN expressions."
            )
        # Every container element must share the member's column type.
        if any(element.column_type != self.member.column_type for element in self.container):
            raise InvalidQueryError(f"Column types for membership test {self} do not agree.")
        return self

518 

519 

@final
class InRange(PredicateLeafBase):
    """A leaf that tests whether an integer expression falls within a
    strided integer range.
    """

    predicate_type: Literal["in_range"] = "in_range"

    member: ColumnExpression
    """Expression to test for membership."""

    start: int = 0
    """Inclusive lower bound for the range."""

    stop: int | None = None
    """Exclusive upper bound for the range."""

    step: int = 1
    """Difference between values in the range."""

    def gather_required_columns(self, columns: ColumnSet) -> None:
        # Docstring inherited.
        # The bounds and step are plain integers, so only the member counts.
        tested = self.member
        tested.gather_required_columns(columns)

    def __str__(self) -> str:
        # Slice-like rendering: a zero start, an open-ended stop and a unit
        # step are all left out.
        lower = f"{self.start}" if self.start else ""
        upper = "" if self.stop is None else f"{self.stop}"
        bounds = f"{lower}:{upper}"
        if self.step != 1:
            bounds = f"{bounds}:{self.step}"
        return f"{self.member} IN {bounds}"

    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        # Docstring inherited.
        result = visitor.visit_in_range(self.member, self.start, self.stop, self.step, flags)
        return result

    @pydantic.model_validator(mode="after")
    def _validate(self) -> InRange:
        # Checks run in a fixed order, so the first failing one decides the
        # error message.
        if self.member.column_type != "int":
            raise InvalidQueryError(f"Column {self.member} is not an integer.")
        if self.step < 1:
            raise InvalidQueryError("Range step must be >= 1.")
        upper = self.stop
        if upper is not None and upper < self.start:
            raise InvalidQueryError("Range stop must be >= start.")
        return self

562 

563 

@final
class InQuery(PredicateLeafBase):
    """A boolean column expression that tests whether its expression is
    included in a single-column projection of a relation.

    This is primarily intended to be used on dataset ID columns, but it may
    be useful for other columns as well.
    """

    predicate_type: Literal["in_query"] = "in_query"

    member: ColumnExpression
    """Expression to test for membership."""

    column: ColumnExpression
    """Expression to extract from `query_tree`."""

    query_tree: QueryTree
    """Relation whose rows from `column` represent the container."""

    def gather_required_columns(self, columns: ColumnSet) -> None:
        # Docstring inherited.
        # Only `self.member` is evaluated in the query tree this predicate is
        # attached to; `self.column` belongs to the nested `self.query_tree`
        # and is deliberately not gathered here.
        tested = self.member
        tested.gather_required_columns(columns)

    def __str__(self) -> str:
        return f"{self.member} IN (query).{self.column}"

    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        # Docstring inherited.
        result = visitor.visit_in_query_tree(self.member, self.column, self.query_tree, flags)
        return result

    @pydantic.model_validator(mode="after")
    def _validate_column_types(self) -> InQuery:
        # Timespan and region members are rejected outright.
        member_type = self.member.column_type
        if member_type in ("timespan", "region"):
            raise InvalidQueryError(
                f"Timespan or region column {self.member} may not be used in IN expressions."
            )
        # The projected column must have the member's column type.
        projected_type = self.column.column_type
        if member_type != projected_type:
            raise InvalidQueryError(
                f"Column types for membership test {self} do not agree "
                f"({member_type}, {projected_type})."
            )

        # ColumnSet is imported here, at call time, rather than at module
        # level, where it is only imported under TYPE_CHECKING.
        from ._column_set import ColumnSet

        # Gather what the projected column needs, starting from the nested
        # query tree's own dimensions, then check that it asks for nothing
        # the tree does not provide.
        tree = self.query_tree
        needed = ColumnSet(tree.dimensions)
        self.column.gather_required_columns(needed)
        if needed.dimensions != tree.dimensions:
            raise InvalidQueryError(
                f"Column {self.column} requires dimensions {needed.dimensions}, "
                f"but query tree only has {tree.dimensions}."
            )
        if not needed.dataset_fields.keys() <= tree.datasets.keys():
            raise InvalidQueryError(
                f"Column {self.column} requires dataset types "
                f"{set(needed.dataset_fields.keys())} that are not present in query tree."
            )
        return self

624 

625 

# Leaf types that may be wrapped in a `LogicalNot`: every concrete leaf except
# `LogicalNot` itself.
LogicalNotOperand: TypeAlias = IsNull | Comparison | InContainer | InRange | InQuery

# Union of all concrete leaf types; pydantic selects the member by its
# ``predicate_type`` field.
PredicateLeaf: TypeAlias = Annotated[
    LogicalNotOperand | LogicalNot, pydantic.Field(discriminator="predicate_type")
]

# Nested tuples of leaves: outer items are combined via AND and inner items
# via OR (see `Predicate.operands`).
PredicateOperands: TypeAlias = tuple[tuple[PredicateLeaf, ...], ...]