Coverage for python/lsst/daf/butler/queries/tree/_predicate.py: 51%

236 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-15 02:03 -0700

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27 

28from __future__ import annotations 

29 

# Names exported by ``from ... import *`` for this module.  Only the
# `Predicate` class and the type aliases are listed; the concrete leaf classes
# defined below are not.
__all__ = (
    "Predicate",
    "PredicateLeaf",
    "LogicalNotOperand",
    "PredicateOperands",
    "ComparisonOperator",
)

37 

38import itertools 

39from abc import ABC, abstractmethod 

40from collections.abc import Iterable 

41from typing import TYPE_CHECKING, Annotated, Literal, TypeAlias, TypeVar, cast, final 

42 

43import pydantic 

44 

45from ._base import InvalidQueryError, QueryTreeBase 

46from ._column_expression import ColumnExpression 

47 

48if TYPE_CHECKING: 

49 from ..visitors import PredicateVisitFlags, PredicateVisitor 

50 from ._column_set import ColumnSet 

51 from ._query_tree import QueryTree 

52 

# Closed set of strings accepted as the operator of a `Comparison`.
ComparisonOperator: TypeAlias = Literal["==", "!=", "<", ">", ">=", "<=", "overlaps"]


# Type variables used to parametrize `PredicateVisitor` in the ``visit``
# signatures below: leaf-level ``visit`` methods return ``_L`` and
# `Predicate.visit` returns ``_A``.
# NOTE(review): ``_O`` is presumably the result type for OR groups - confirm
# against `PredicateVisitor`.
_L = TypeVar("_L")
_A = TypeVar("_A")
_O = TypeVar("_O")

59 

60 

class PredicateLeafBase(QueryTreeBase, ABC):
    """Abstract base for the leaf nodes of a `Predicate` tree.

    The hierarchy is closed: its concrete, `~typing.final` subclasses are
    exactly the members of the `PredicateLeaf` union.  Prefer that union in
    type annotations over this (technically open) base class.
    """

    @abstractmethod
    def gather_required_columns(self, columns: ColumnSet) -> None:
        """Record the columns needed to evaluate this leaf.

        Parameters
        ----------
        columns : `ColumnSet`
            Column set that is updated in place.
        """
        raise NotImplementedError()

    @abstractmethod
    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        """Dispatch to the visitor method that handles this kind of leaf.

        Parameters
        ----------
        visitor : `PredicateVisitor`
            Visitor whose method is invoked.
        flags : `PredicateVisitFlags`
            Flags describing where this leaf sits within the enclosing
            predicate tree.

        Returns
        -------
        result : `object`
            Whatever the visitor method returned.
        """
        raise NotImplementedError()

    def invert(self) -> PredicateLeaf:
        """Return a new leaf that is the logical negation of this one.

        Returns
        -------
        inverted : `PredicateLeaf`
            A `LogicalNot` wrapping this leaf, built without validation.
        """
        # The cast is for the type checker only: subclasses that are not
        # valid `LogicalNot` operands override this method.
        wrapped = cast("LogicalNotOperand", self)
        return LogicalNot.model_construct(operand=wrapped)

103 

104 

@final
class Predicate(QueryTreeBase):
    """A boolean column expression.

    Notes
    -----
    This is the only boolean column-expression class meant to be used outside
    this module (the objects nested inside it show up in the serialized form,
    so they are not entirely private).  Instances are created through the
    `classmethod` factories defined here and combined with `logical_and`,
    `logical_or` and `logical_not`; use `PredicateVisitor` subclasses to
    process them.
    """

    operands: PredicateOperands
    """Nested tuple of operands, with outer items combined via AND and inner
    items combined via OR.
    """

    @property
    def column_type(self) -> Literal["bool"]:
        """A string enumeration value representing the type of the column
        expression.
        """
        return "bool"

    @classmethod
    def from_bool(cls, value: bool) -> Predicate:
        """Construct a predicate with a constant truth value.

        Parameters
        ----------
        value : `bool`
            Value the predicate should evaluate to.

        Returns
        -------
        predicate : `Predicate`
            Predicate that always evaluates to ``value``.
        """
        # ``operands`` is interpreted as
        #
        #     all(any(or_group) for or_group in operands)
        #
        # so an empty outer tuple is True (``all`` of nothing) while a single
        # empty inner tuple is False (``any`` of nothing).
        if value:
            constant: PredicateOperands = ()
        else:
            constant = ((),)
        return cls.model_construct(operands=constant)

    @classmethod
    def compare(cls, a: ColumnExpression, operator: ComparisonOperator, b: ColumnExpression) -> Predicate:
        """Construct a predicate for a binary comparison of two non-boolean
        column expressions.

        Parameters
        ----------
        a : `ColumnExpression`
            Left-hand operand.
        operator : `str`
            Enumerated string naming the comparison operator.  May be any of
            "==", "!=", "<", ">", "<=", ">=", or "overlaps".
        b : `ColumnExpression`
            Right-hand operand.

        Returns
        -------
        predicate : `Predicate`
            Predicate representing the comparison.
        """
        leaf = Comparison(a=a, operator=operator, b=b)
        return cls._from_leaf(leaf)

    @classmethod
    def is_null(cls, operand: ColumnExpression) -> Predicate:
        """Construct a predicate that checks a column expression for NULL.

        Parameters
        ----------
        operand : `ColumnExpression`
            Column expression to test.

        Returns
        -------
        predicate : `Predicate`
            Predicate representing the NULL check.
        """
        leaf = IsNull(operand=operand)
        return cls._from_leaf(leaf)

    @classmethod
    def in_container(cls, member: ColumnExpression, container: Iterable[ColumnExpression]) -> Predicate:
        """Construct a predicate that checks whether a column expression
        equals any item of an explicit collection of column expressions.

        Parameters
        ----------
        member : `ColumnExpression`
            Column expression that may be a member of the container.
        container : `~collections.abc.Iterable` [ `ColumnExpression` ]
            Column expressions to test for membership in.

        Returns
        -------
        predicate : `Predicate`
            Predicate representing the membership test.
        """
        # The iterable is consumed exactly once here.
        items = tuple(container)
        leaf = InContainer(member=member, container=items)
        return cls._from_leaf(leaf)

    @classmethod
    def in_range(
        cls, member: ColumnExpression, start: int = 0, stop: int | None = None, step: int = 1
    ) -> Predicate:
        """Construct a predicate that checks whether an integer column
        expression falls in a strided range.

        Parameters
        ----------
        member : `ColumnExpression`
            Column expression that may be a member of the range.
        start : `int`, optional
            Beginning of the range, inclusive.
        stop : `int` or `None`, optional
            End of the range, exclusive.
        step : `int`, optional
            Offset between values in the range.

        Returns
        -------
        predicate : `Predicate`
            Predicate representing the membership test.
        """
        leaf = InRange(member=member, start=start, stop=stop, step=step)
        return cls._from_leaf(leaf)

    @classmethod
    def in_query(cls, member: ColumnExpression, column: ColumnExpression, query_tree: QueryTree) -> Predicate:
        """Construct a predicate that checks whether a column expression
        appears in a single-column projection of a query tree.

        Parameters
        ----------
        member : `ColumnExpression`
            Column expression that may be present in the query.
        column : `ColumnExpression`
            Column to project from the query.
        query_tree : `QueryTree`
            Query tree to select from.

        Returns
        -------
        predicate : `Predicate`
            Predicate representing the membership test.
        """
        leaf = InQuery(member=member, column=column, query_tree=query_tree)
        return cls._from_leaf(leaf)

    def gather_required_columns(self, columns: ColumnSet) -> None:
        """Record the columns needed to evaluate this predicate.

        Parameters
        ----------
        columns : `ColumnSet`
            Column set that is updated in place.
        """
        # Visit every leaf of every OR group, in order.
        for leaf in itertools.chain.from_iterable(self.operands):
            leaf.gather_required_columns(columns)

    def logical_and(self, *args: Predicate) -> Predicate:
        """Construct the logical AND of this predicate and the given ones.

        Parameters
        ----------
        *args : `Predicate`
            Other predicates.

        Returns
        -------
        predicate : `Predicate`
            Predicate representing the logical AND.
        """
        combined = self.operands
        for other in args:
            combined = self._impl_and(combined, other.operands)
        # An empty OR group is False, which makes the whole conjunction
        # False; collapse to the canonical False form in that case.
        if any(not or_group for or_group in combined):
            combined = ((),)
        return Predicate.model_construct(operands=combined)

    def logical_or(self, *args: Predicate) -> Predicate:
        """Construct the logical OR of this predicate and the given ones.

        Parameters
        ----------
        *args : `Predicate`
            Other predicates.

        Returns
        -------
        predicate : `Predicate`
            Predicate representing the logical OR.
        """
        combined = self.operands
        for other in args:
            combined = self._impl_or(combined, other.operands)
        return Predicate.model_construct(operands=combined)

    def logical_not(self) -> Predicate:
        """Construct the logical NOT of this predicate.

        Returns
        -------
        predicate : `Predicate`
            Predicate representing the logical NOT.
        """
        # De Morgan: NOT (g1 AND g2 ...) == (NOT g1) OR (NOT g2) ..., and
        # NOT (x OR y ...) == (NOT x) AND (NOT y) ...  Start from False, the
        # identity element for OR.
        negated: PredicateOperands = ((),)
        for or_group in self.operands:
            # Start from True, the identity element for AND.
            negated_group: PredicateOperands = ()
            for leaf in or_group:
                negated_group = self._impl_and(negated_group, ((leaf.invert(),),))
            negated = self._impl_or(negated, negated_group)
        return Predicate.model_construct(operands=negated)

    def __str__(self) -> str:
        # Parentheses are only needed around an OR group when it is one of
        # several AND-ed terms.
        needs_parens = len(self.operands) > 1
        rendered = []
        for or_group in self.operands:
            if not or_group:
                rendered.append("False")
            elif len(or_group) == 1:
                rendered.append(str(or_group[0]))
            else:
                joined = " OR ".join([str(leaf) for leaf in or_group])
                rendered.append(f"({joined})" if needs_parens else joined)
        return " AND ".join(rendered) if rendered else "True"

    def visit(self, visitor: PredicateVisitor[_A, _O, _L]) -> _A:
        """Invoke the visitor interface.

        Parameters
        ----------
        visitor : `PredicateVisitor`
            Visitor whose method is invoked.

        Returns
        -------
        result : `object`
            Whatever the visitor returned.
        """
        return visitor._visit_logical_and(self.operands)

    @classmethod
    def _from_leaf(cls, leaf: PredicateLeaf) -> Predicate:
        # Wrap a single leaf as a one-element OR group.
        single_group = (leaf,)
        return cls._from_or_group(single_group)

    @classmethod
    def _from_or_group(cls, or_group: tuple[PredicateLeaf, ...]) -> Predicate:
        # Wrap a single OR group as the only term of the outer AND.
        return Predicate.model_construct(operands=(or_group,))

    @classmethod
    def _impl_and(cls, a: PredicateOperands, b: PredicateOperands) -> PredicateOperands:
        # AND of two conjunctive-normal-form operand tuples is concatenation
        # of their OR groups.  The only simplification attempted is the
        # trivial one where both arguments are the very same object; sharing
        # of individual leaves (e.g. from double negation or from expanding
        # NOT ((A AND B) OR (C AND D))) is not detected, since those cases
        # seem too pathological to be worth the effort.
        if a is b:
            return a
        return a + b

    @classmethod
    def _impl_or(cls, a: PredicateOperands, b: PredicateOperands) -> PredicateOperands:
        # OR distributes over AND: every OR group of ``a`` is merged with
        # every OR group of ``b``.  As in `_impl_and`, no simplification of
        # duplicated leaves is attempted.
        return tuple(left + right for left in a for right in b)

383 

384 

@final
class LogicalNot(PredicateLeafBase):
    """A boolean column expression that negates its operand."""

    predicate_type: Literal["not"] = "not"

    operand: LogicalNotOperand
    """Upstream boolean expression to invert."""

    def gather_required_columns(self, columns: ColumnSet) -> None:
        # Docstring inherited.
        # Negation needs no columns of its own; delegate to the operand.
        wrapped = self.operand
        wrapped.gather_required_columns(columns)

    def invert(self) -> LogicalNotOperand:
        # Docstring inherited.
        # Negating a negation just unwraps the original operand.
        return self.operand

    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        # Docstring inherited.
        wrapped = self.operand
        return visitor._visit_logical_not(wrapped, flags)

    def __str__(self) -> str:
        wrapped = self.operand
        return f"NOT {wrapped}"

408 

409 

@final
class IsNull(PredicateLeafBase):
    """A boolean column expression that is true when its operand is NULL."""

    predicate_type: Literal["is_null"] = "is_null"

    operand: ColumnExpression
    """Upstream expression to test."""

    def gather_required_columns(self, columns: ColumnSet) -> None:
        # Docstring inherited.
        tested = self.operand
        tested.gather_required_columns(columns)

    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        # Docstring inherited.
        tested = self.operand
        return visitor.visit_is_null(tested, flags)

    def __str__(self) -> str:
        tested = self.operand
        return f"{tested} IS NULL"

429 

430 

@final
class Comparison(PredicateLeafBase):
    """A boolean column expression formed by comparing two non-boolean
    expressions.
    """

    predicate_type: Literal["comparison"] = "comparison"

    a: ColumnExpression
    """Left-hand side expression for the comparison."""

    b: ColumnExpression
    """Right-hand side expression for the comparison."""

    operator: ComparisonOperator
    """Comparison operator."""

    def gather_required_columns(self, columns: ColumnSet) -> None:
        # Docstring inherited.
        # Left-hand side first, then right-hand side.
        for side in (self.a, self.b):
            side.gather_required_columns(columns)

    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        # Docstring inherited.
        return visitor.visit_comparison(self.a, self.operator, self.b, flags)

    def __str__(self) -> str:
        # Upper-casing only affects the word operator ("overlaps").
        shown_operator = self.operator.upper()
        return f"{self.a} {shown_operator} {self.b}"

    @pydantic.model_validator(mode="after")
    def _validate_column_types(self) -> Comparison:
        # Both sides must have the same column type, and that type must be
        # one the operator supports.
        lhs_type = self.a.column_type
        if lhs_type != self.b.column_type:
            raise InvalidQueryError(
                f"Column types for comparison {self} do not agree "
                f"({self.a.column_type}, {self.b.column_type})."
            )
        chosen = self.operator
        if chosen in ("==", "!="):
            # Equality is allowed for every column type.
            supported = True
        elif chosen in ("<", ">", ">=", "<="):
            # Ordering comparisons are limited to these column types.
            supported = lhs_type in ("int", "string", "float", "datetime")
        elif chosen == "overlaps":
            supported = lhs_type in ("region", "timespan")
        else:
            supported = False
        if not supported:
            raise InvalidQueryError(
                f"Invalid column type {self.a.column_type} for operator {self.operator!r}."
            )
        return self

479 

480 

@final
class InContainer(PredicateLeafBase):
    """A boolean column expression that tests whether one expression is a
    member of an explicit sequence of other expressions.
    """

    predicate_type: Literal["in_container"] = "in_container"

    member: ColumnExpression
    """Expression to test for membership."""

    container: tuple[ColumnExpression, ...]
    """Expressions representing the elements of the container."""

    def gather_required_columns(self, columns: ColumnSet) -> None:
        # Docstring inherited.
        # The member comes first, then each container element in order.
        for expression in (self.member, *self.container):
            expression.gather_required_columns(columns)

    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        # Docstring inherited.
        return visitor.visit_in_container(self.member, self.container, flags)

    def __str__(self) -> str:
        listed = ", ".join([str(element) for element in self.container])
        return f"{self.member} IN [{listed}]"

    @pydantic.model_validator(mode="after")
    def _validate(self) -> InContainer:
        # Timespans and regions cannot be tested for membership, and every
        # container element must have the same column type as the member.
        member_type = self.member.column_type
        if member_type in ("timespan", "region"):
            raise InvalidQueryError(
                f"Timespan or region column {self.member} may not be used in IN expressions."
            )
        for element in self.container:
            if element.column_type != member_type:
                raise InvalidQueryError(f"Column types for membership test {self} do not agree.")
        return self

517 

518 

@final
class InRange(PredicateLeafBase):
    """A boolean column expression that tests whether its expression is
    included in an integer range.
    """

    predicate_type: Literal["in_range"] = "in_range"

    member: ColumnExpression
    """Expression to test for membership."""

    start: int = 0
    """Inclusive lower bound for the range."""

    stop: int | None = None
    """Exclusive upper bound for the range."""

    step: int = 1
    """Difference between values in the range."""

    def gather_required_columns(self, columns: ColumnSet) -> None:
        # Docstring inherited.
        # The bounds and step are plain integers; only the member can
        # contribute columns.
        tested = self.member
        tested.gather_required_columns(columns)

    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        # Docstring inherited.
        return visitor.visit_in_range(self.member, self.start, self.stop, self.step, flags)

    def __str__(self) -> str:
        # Slice-like rendering: a zero start and a missing stop are left
        # blank, and the step is shown only when it is not 1.
        lower = self.start if self.start else ""
        upper = self.stop if self.stop is not None else ""
        bounds = f"{lower}:{upper}"
        if self.step != 1:
            bounds = f"{bounds}:{self.step}"
        return f"{self.member} IN {bounds}"

    @pydantic.model_validator(mode="after")
    def _validate(self) -> InRange:
        # The member must be an integer column, the step positive, and the
        # stop (when given) not below the start.
        if self.member.column_type != "int":
            raise InvalidQueryError(f"Column {self.member} is not an integer.")
        if self.step < 1:
            raise InvalidQueryError("Range step must be >= 1.")
        upper = self.stop
        if upper is not None and upper < self.start:
            raise InvalidQueryError("Range stop must be >= start.")
        return self

561 

562 

@final
class InQuery(PredicateLeafBase):
    """A boolean column expression that tests whether its expression is
    included in a single-column projection of a relation.

    This is primarily intended to be used on dataset ID columns, but it may
    be useful for other columns as well.
    """

    predicate_type: Literal["in_query"] = "in_query"

    member: ColumnExpression
    """Expression to test for membership."""

    column: ColumnExpression
    """Expression to extract from `query_tree`."""

    query_tree: QueryTree
    """Relation whose rows from `column` represent the container."""

    def gather_required_columns(self, columns: ColumnSet) -> None:
        # Docstring inherited.
        # Only ``member`` belongs to the query this predicate is attached to.
        # ``column`` belongs to the nested ``query_tree`` and is therefore
        # deliberately not gathered here.
        tested = self.member
        tested.gather_required_columns(columns)

    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        # Docstring inherited.
        return visitor.visit_in_query_tree(self.member, self.column, self.query_tree, flags)

    def __str__(self) -> str:
        return f"{self.member} IN (query).{self.column}"

    @pydantic.model_validator(mode="after")
    def _validate_column_types(self) -> InQuery:
        # Timespans and regions cannot be tested for membership, and the
        # member must have the same column type as the projected column.
        member_type = self.member.column_type
        if member_type in ("timespan", "region"):
            raise InvalidQueryError(
                f"Timespan or region column {self.member} may not be used in IN expressions."
            )
        if member_type != self.column.column_type:
            raise InvalidQueryError(
                f"Column types for membership test {self} do not agree "
                f"({self.member.column_type}, {self.column.column_type})."
            )

        # Imported here rather than at module scope, where it is only
        # imported under TYPE_CHECKING.
        from ._column_set import ColumnSet

        # Gather what ``column`` needs, starting from the nested tree's own
        # dimensions.  If gathering changed the dimensions, the column needs
        # something the tree does not provide.
        tree = self.query_tree
        needed = ColumnSet(tree.dimensions)
        self.column.gather_required_columns(needed)
        if needed.dimensions != tree.dimensions:
            raise InvalidQueryError(
                f"Column {self.column} requires dimensions {columns_required_in_tree_dimensions(needed)}, "
                f"but query tree only has {self.query_tree.dimensions}."
            ) if False else InvalidQueryError(
                f"Column {self.column} requires dimensions {needed.dimensions}, "
                f"but query tree only has {self.query_tree.dimensions}."
            )
        # Every dataset type the column refers to must be present in the
        # nested tree.
        if not needed.dataset_fields.keys() <= tree.datasets.keys():
            raise InvalidQueryError(
                f"Column {self.column} requires dataset types "
                f"{set(needed.dataset_fields.keys())} that are not present in query tree."
            )
        return self

623 

624 

# These aliases are defined after the classes they refer to.

# Leaf types that may appear as the operand of a `LogicalNot`; `LogicalNot`
# itself is not included.
LogicalNotOperand: TypeAlias = IsNull | Comparison | InContainer | InRange | InQuery
# Union of all concrete leaf types, discriminated by each class's
# ``predicate_type`` field.
PredicateLeaf: TypeAlias = Annotated[
    LogicalNotOperand | LogicalNot, pydantic.Field(discriminator="predicate_type")
]

# Nested tuple of leaves: outer items are combined via AND, inner items via OR.
PredicateOperands: TypeAlias = tuple[tuple[PredicateLeaf, ...], ...]