Coverage for python/lsst/daf/butler/queries/tree/_predicate.py: 51%

235 statements  

« prev     ^ index     » next       coverage.py v7.4.3, created at 2024-03-12 10:07 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27 

28from __future__ import annotations 

29 

# Names exported by ``from ... import *``.  The concrete leaf classes defined
# in this module (`LogicalNot`, `IsNull`, `Comparison`, ...) are not listed
# here; only `Predicate`, the operator alias and the union aliases are.
__all__ = (
    "Predicate",
    "PredicateLeaf",
    "LogicalNotOperand",
    "PredicateOperands",
    "ComparisonOperator",
)

37 

38import itertools 

39from abc import ABC, abstractmethod 

40from typing import TYPE_CHECKING, Annotated, Iterable, Literal, TypeAlias, TypeVar, Union, cast, final 

41 

42import pydantic 

43 

44from ._base import InvalidQueryError, QueryTreeBase 

45from ._column_expression import ColumnExpression 

46 

47if TYPE_CHECKING: 

48 from ..visitors import PredicateVisitFlags, PredicateVisitor 

49 from ._column_set import ColumnSet 

50 from ._query_tree import QueryTree 

51 

# Closed set of operator strings accepted by `Comparison.operator` and by
# `Predicate.compare`.
ComparisonOperator: TypeAlias = Literal["==", "!=", "<", ">", ">=", "<=", "overlaps"]


# Type variables used to parameterize `PredicateVisitor[_A, _O, _L]` in the
# ``visit`` signatures below: `_A` is the type returned by `Predicate.visit`
# and `_L` is the type returned by each leaf's ``visit``.
# NOTE(review): `_O` is presumably the result type for an OR group; it is not
# returned by anything in this module - confirm against `PredicateVisitor`.
_L = TypeVar("_L")
_A = TypeVar("_A")
_O = TypeVar("_O")

58 

59 

class PredicateLeafBase(QueryTreeBase, ABC):
    """Abstract base for the leaf nodes of a `Predicate` tree.

    The hierarchy is closed: its concrete, `~typing.final` subclasses are
    exactly the members of the `PredicateLeaf` union.  Prefer that union in
    type annotations over this (technically open) base class.
    """

    @abstractmethod
    def gather_required_columns(self, columns: ColumnSet) -> None:
        """Record every column needed to evaluate this leaf in the given
        column set.

        Parameters
        ----------
        columns : `ColumnSet`
            Set of columns to modify in place.
        """
        raise NotImplementedError()

    def invert(self) -> PredicateLeaf:
        """Return a new leaf that is the logical negation of this one.

        Returns
        -------
        inverted : `PredicateLeaf`
            A `LogicalNot` leaf wrapping this one.  `LogicalNot` itself
            overrides this method to return its operand instead.
        """
        # Every concrete subclass that reaches this implementation is a
        # member of the LogicalNotOperand union, hence the cast.
        wrapped = cast("LogicalNotOperand", self)
        # model_construct skips validation; the operand is already a
        # constructed leaf.
        return LogicalNot.model_construct(operand=wrapped)

    @abstractmethod
    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        """Dispatch to the visitor method that handles this kind of leaf.

        Parameters
        ----------
        visitor : `PredicateVisitor`
            Visitor to invoke a method on.
        flags : `PredicateVisitFlags`
            Flags describing where this leaf sits within the larger
            predicate tree.

        Returns
        -------
        result : `object`
            Whatever the visitor method returned.
        """
        raise NotImplementedError()

102 

103 

@final
class Predicate(QueryTreeBase):
    """A boolean column expression.

    Notes
    -----
    `Predicate` is the only class representing a boolean column expression
    that should be used outside of this module (the objects nested inside it
    do appear in its serialized form, so they are not fully private).  Use
    its `classmethod` factories to build the nested leaf types wrapped in a
    `Predicate`, and use `PredicateVisitor` subclasses to process them.
    """

    operands: PredicateOperands
    """Nested tuple of operands, with outer items combined via AND and inner
    items combined via OR.
    """

    @property
    def column_type(self) -> Literal["bool"]:
        """A string enumeration value representing the type of the column
        expression.
        """
        return "bool"

    @classmethod
    def from_bool(cls, value: bool) -> Predicate:
        """Construct a predicate that always evaluates to `True` or `False`.

        Parameters
        ----------
        value : `bool`
            Value the predicate should evaluate to.

        Returns
        -------
        predicate : `Predicate`
            Predicate that evaluates to the given boolean value.
        """
        # Operands are interpreted as
        #
        #     all(any(or_group) for or_group in operands)
        #
        # so an empty outer tuple is True (``all`` of nothing) and an outer
        # tuple holding one empty OR group is False (``any`` of nothing).
        if value:
            constant: PredicateOperands = ()
        else:
            constant = ((),)
        return cls.model_construct(operands=constant)

    @classmethod
    def compare(cls, a: ColumnExpression, operator: ComparisonOperator, b: ColumnExpression) -> Predicate:
        """Construct a predicate representing a binary comparison between
        two non-boolean column expressions.

        Parameters
        ----------
        a : `ColumnExpression`
            First column expression in the comparison.
        operator : `str`
            Enumerated string representing the comparison operator to apply.
            May be any of "==", "!=", "<", ">", "<=", ">=", or "overlaps".
        b : `ColumnExpression`
            Second column expression in the comparison.

        Returns
        -------
        predicate : `Predicate`
            Predicate representing the comparison.
        """
        leaf = Comparison(a=a, operator=operator, b=b)
        return cls._from_leaf(leaf)

    @classmethod
    def is_null(cls, operand: ColumnExpression) -> Predicate:
        """Construct a predicate that tests whether a column expression is
        NULL.

        Parameters
        ----------
        operand : `ColumnExpression`
            Column expression to test.

        Returns
        -------
        predicate : `Predicate`
            Predicate representing the NULL check.
        """
        leaf = IsNull(operand=operand)
        return cls._from_leaf(leaf)

    @classmethod
    def in_container(cls, member: ColumnExpression, container: Iterable[ColumnExpression]) -> Predicate:
        """Construct a predicate that tests whether one column expression is
        a member of a container of other column expressions.

        Parameters
        ----------
        member : `ColumnExpression`
            Column expression that may be a member of the container.
        container : `~collections.abc.Iterable` [ `ColumnExpression` ]
            Container of column expressions to test for membership in.

        Returns
        -------
        predicate : `Predicate`
            Predicate representing the membership test.
        """
        # Materialize the iterable once; the leaf stores a tuple.
        items = tuple(container)
        return cls._from_leaf(InContainer(member=member, container=items))

    @classmethod
    def in_range(
        cls, member: ColumnExpression, start: int = 0, stop: int | None = None, step: int = 1
    ) -> Predicate:
        """Construct a predicate that tests whether an integer column
        expression is part of a strided range.

        Parameters
        ----------
        member : `ColumnExpression`
            Column expression that may be a member of the range.
        start : `int`, optional
            Beginning of the range, inclusive.
        stop : `int` or `None`, optional
            End of the range, exclusive.
        step : `int`, optional
            Offset between values in the range.

        Returns
        -------
        predicate : `Predicate`
            Predicate representing the membership test.
        """
        leaf = InRange(member=member, start=start, stop=stop, step=step)
        return cls._from_leaf(leaf)

    @classmethod
    def in_query(cls, member: ColumnExpression, column: ColumnExpression, query_tree: QueryTree) -> Predicate:
        """Construct a predicate that tests whether a column expression is
        present in a single-column projection of a query tree.

        Parameters
        ----------
        member : `ColumnExpression`
            Column expression that may be present in the query.
        column : `ColumnExpression`
            Column to project from the query.
        query_tree : `QueryTree`
            Query tree to select from.

        Returns
        -------
        predicate : `Predicate`
            Predicate representing the membership test.
        """
        leaf = InQuery(member=member, column=column, query_tree=query_tree)
        return cls._from_leaf(leaf)

    def gather_required_columns(self, columns: ColumnSet) -> None:
        """Add any columns required to evaluate this predicate to the given
        column set.

        Parameters
        ----------
        columns : `ColumnSet`
            Set of columns to modify in place.
        """
        # Flatten the AND-of-ORs nesting; every leaf contributes columns.
        for leaf in itertools.chain.from_iterable(self.operands):
            leaf.gather_required_columns(columns)

    def logical_and(self, *args: Predicate) -> Predicate:
        """Construct a predicate representing the logical AND of this
        predicate and one or more others.

        Parameters
        ----------
        *args : `Predicate`
            Other predicates.

        Returns
        -------
        predicate : `Predicate`
            Predicate representing the logical AND.
        """
        combined = self.operands
        for other in args:
            combined = self._impl_and(combined, other.operands)
            # An empty OR group is False, and False AND anything is False,
            # so collapse the whole expression to the canonical False.
            if any(not or_group for or_group in combined):
                combined = ((),)
        return Predicate.model_construct(operands=combined)

    def logical_or(self, *args: Predicate) -> Predicate:
        """Construct a predicate representing the logical OR of this
        predicate and one or more others.

        Parameters
        ----------
        *args : `Predicate`
            Other predicates.

        Returns
        -------
        predicate : `Predicate`
            Predicate representing the logical OR.
        """
        combined = self.operands
        for other in args:
            combined = self._impl_or(combined, other.operands)
        return Predicate.model_construct(operands=combined)

    def logical_not(self) -> Predicate:
        """Construct a predicate representing the logical NOT of this
        predicate.

        Returns
        -------
        predicate : `Predicate`
            Predicate representing the logical NOT.
        """
        # NOT (g1 AND g2 AND ...) == (NOT g1) OR (NOT g2) OR ..., and each
        # NOT (l1 OR l2 OR ...) == (NOT l1) AND (NOT l2) AND ....  Start
        # from False, the identity for OR.
        negated: PredicateOperands = ((),)
        for or_group in self.operands:
            # Start from True, the identity for AND.
            inverted_group: PredicateOperands = ()
            for leaf in or_group:
                inverted_group = self._impl_and(inverted_group, ((leaf.invert(),),))
            negated = self._impl_or(negated, inverted_group)
        return Predicate.model_construct(operands=negated)

    def __str__(self) -> str:
        if not self.operands:
            # No AND terms at all: the expression is trivially True.
            return "True"
        # OR groups only need parentheses when there is more than one AND
        # term to separate them from.
        parenthesize = len(self.operands) > 1
        pieces: list[str] = []
        for or_group in self.operands:
            if not or_group:
                pieces.append("False")
            elif len(or_group) == 1:
                pieces.append(str(or_group[0]))
            else:
                joined = " OR ".join(str(leaf) for leaf in or_group)
                pieces.append(f"({joined})" if parenthesize else joined)
        return " AND ".join(pieces)

    def visit(self, visitor: PredicateVisitor[_A, _O, _L]) -> _A:
        """Invoke the visitor interface.

        Parameters
        ----------
        visitor : `PredicateVisitor`
            Visitor to invoke a method on.

        Returns
        -------
        result : `object`
            Forwarded result from the visitor.
        """
        return visitor._visit_logical_and(self.operands)

    @classmethod
    def _from_leaf(cls, leaf: PredicateLeaf) -> Predicate:
        # Wrap a single leaf as a one-element OR group.
        return cls._from_or_group((leaf,))

    @classmethod
    def _from_or_group(cls, or_group: tuple[PredicateLeaf, ...]) -> Predicate:
        # Wrap a single OR group as the only AND term.
        return Predicate.model_construct(operands=(or_group,))

    @classmethod
    def _impl_and(cls, a: PredicateOperands, b: PredicateOperands) -> PredicateOperands:
        # We could simplify cases where both sides share some of the same
        # leaf expressions; even 'is' tests on leaves would help where
        # converting to conjunctive normal form twice causes a lot of
        # duplication, e.g. NOT ((A AND B) OR (C AND D)) or any kind of
        # double-negation.  Right now those cases seem pathological enough
        # to be not worth our time.
        if a is b:
            # X AND X is just X.
            return a
        return a + b

    @classmethod
    def _impl_or(cls, a: PredicateOperands, b: PredicateOperands) -> PredicateOperands:
        # Distribute OR over AND: every OR group of `a` is merged with every
        # OR group of `b`.  Same comment re simplification as in _impl_and
        # applies here.
        return tuple([left + right for left, right in itertools.product(a, b)])

382 

383 

@final
class LogicalNot(PredicateLeafBase):
    """A boolean column expression that inverts its operand."""

    predicate_type: Literal["not"] = "not"

    operand: LogicalNotOperand
    """Upstream boolean expression to invert."""

    def gather_required_columns(self, columns: ColumnSet) -> None:
        # Docstring inherited.
        # Negation needs exactly the columns its operand needs.
        inner = self.operand
        inner.gather_required_columns(columns)

    def __str__(self) -> str:
        return f"NOT {self.operand}"

    def invert(self) -> LogicalNotOperand:
        # Docstring inherited.
        # Negating a negation unwraps the operand instead of nesting
        # another LogicalNot around it.
        inner = self.operand
        return inner

    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        # Docstring inherited.
        inner = self.operand
        return visitor._visit_logical_not(inner, flags)

407 

408 

@final
class IsNull(PredicateLeafBase):
    """A boolean column expression that tests whether its operand is NULL."""

    predicate_type: Literal["is_null"] = "is_null"

    operand: ColumnExpression
    """Upstream expression to test."""

    def gather_required_columns(self, columns: ColumnSet) -> None:
        # Docstring inherited.
        # The only columns needed are those of the tested expression.
        tested = self.operand
        tested.gather_required_columns(columns)

    def __str__(self) -> str:
        return f"{self.operand} IS NULL"

    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        # Docstring inherited.
        tested = self.operand
        return visitor.visit_is_null(tested, flags)

428 

429 

@final
class Comparison(PredicateLeafBase):
    """A boolean column expression formed by comparing two non-boolean
    expressions.
    """

    predicate_type: Literal["comparison"] = "comparison"

    a: ColumnExpression
    """Left-hand side expression for the comparison."""

    b: ColumnExpression
    """Right-hand side expression for the comparison."""

    operator: ComparisonOperator
    """Comparison operator."""

    def gather_required_columns(self, columns: ColumnSet) -> None:
        # Docstring inherited.
        # Both sides of the comparison contribute columns, left side first.
        for side in (self.a, self.b):
            side.gather_required_columns(columns)

    def __str__(self) -> str:
        return f"{self.a} {self.operator.upper()} {self.b}"

    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        # Docstring inherited.
        return visitor.visit_comparison(self.a, self.operator, self.b, flags)

    @pydantic.model_validator(mode="after")
    def _validate_column_types(self) -> Comparison:
        # Both operands must have the same column type...
        if self.a.column_type != self.b.column_type:
            raise InvalidQueryError(
                f"Column types for comparison {self} do not agree "
                f"({self.a.column_type}, {self.b.column_type})."
            )
        # ...and that shared type must be one the operator supports.
        op = self.operator
        shared_type = self.a.column_type
        if op in ("==", "!="):
            # Equality tests are accepted for every column type.
            return self
        if op in ("<", ">", ">=", "<=") and shared_type in ("int", "string", "float", "datetime"):
            # Ordering operators are only accepted for these four types.
            return self
        if op == "overlaps" and shared_type in ("region", "timespan"):
            return self
        raise InvalidQueryError(
            f"Invalid column type {self.a.column_type} for operator {self.operator!r}."
        )

478 

479 

@final
class InContainer(PredicateLeafBase):
    """A boolean column expression that tests whether one expression is a
    member of an explicit sequence of other expressions.
    """

    predicate_type: Literal["in_container"] = "in_container"

    member: ColumnExpression
    """Expression to test for membership."""

    container: tuple[ColumnExpression, ...]
    """Expressions representing the elements of the container."""

    def gather_required_columns(self, columns: ColumnSet) -> None:
        # Docstring inherited.
        # The tested expression comes first, then each container element.
        for expression in (self.member, *self.container):
            expression.gather_required_columns(columns)

    def __str__(self) -> str:
        return f"{self.member} IN [{', '.join(str(item) for item in self.container)}]"

    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        # Docstring inherited.
        return visitor.visit_in_container(self.member, self.container, flags)

    @pydantic.model_validator(mode="after")
    def _validate(self) -> InContainer:
        member_type = self.member.column_type
        # Timespan and region columns are rejected outright.
        if member_type in ("timespan", "region"):
            raise InvalidQueryError(
                f"Timespan or region column {self.member} may not be used in IN expressions."
            )
        # Every container element must share the member's column type.
        if any(item.column_type != member_type for item in self.container):
            raise InvalidQueryError(f"Column types for membership test {self} do not agree.")
        return self

516 

517 

@final
class InRange(PredicateLeafBase):
    """A boolean column expression that tests whether its expression is
    included in an integer range.
    """

    predicate_type: Literal["in_range"] = "in_range"

    member: ColumnExpression
    """Expression to test for membership."""

    start: int = 0
    """Inclusive lower bound for the range."""

    stop: int | None = None
    """Exclusive upper bound for the range."""

    step: int = 1
    """Difference between values in the range."""

    def gather_required_columns(self, columns: ColumnSet) -> None:
        # Docstring inherited.
        # The bounds are plain integers, so only the member needs columns.
        self.member.gather_required_columns(columns)

    def __str__(self) -> str:
        # Slice-like rendering: a zero start and a missing stop are left
        # blank, and the step is only shown when it is not 1.
        start_text = str(self.start) if self.start else ""
        stop_text = "" if self.stop is None else str(self.stop)
        range_text = f"{start_text}:{stop_text}"
        if self.step != 1:
            range_text = f"{range_text}:{self.step}"
        return f"{self.member} IN {range_text}"

    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        # Docstring inherited.
        return visitor.visit_in_range(self.member, self.start, self.stop, self.step, flags)

    @pydantic.model_validator(mode="after")
    def _validate(self) -> InRange:
        # Only integer columns can be tested against an integer range.
        if self.member.column_type != "int":
            raise InvalidQueryError(f"Column {self.member} is not an integer.")
        if self.step < 1:
            raise InvalidQueryError("Range step must be >= 1.")
        # An open-ended range (stop is None) has no upper bound to check.
        if self.stop is not None and self.stop < self.start:
            raise InvalidQueryError("Range stop must be >= start.")
        return self

560 

561 

@final
class InQuery(PredicateLeafBase):
    """A boolean column expression that tests whether its expression is
    included in a single-column projection of a relation.

    This is primarily intended to be used on dataset ID columns, but it may
    be useful for other columns as well.
    """

    predicate_type: Literal["in_query"] = "in_query"

    member: ColumnExpression
    """Expression to test for membership."""

    column: ColumnExpression
    """Expression to extract from `query_tree`."""

    query_tree: QueryTree
    """Relation whose rows from `column` represent the container."""

    def gather_required_columns(self, columns: ColumnSet) -> None:
        # Docstring inherited.
        # Only `member` belongs to the query tree this predicate is attached
        # to; `self.column` belongs to `self.query_tree` and is deliberately
        # not gathered here.
        self.member.gather_required_columns(columns)

    def __str__(self) -> str:
        return f"{self.member} IN (query).{self.column}"

    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        # Docstring inherited.
        return visitor.visit_in_query_tree(self.member, self.column, self.query_tree, flags)

    @pydantic.model_validator(mode="after")
    def _validate_column_types(self) -> InQuery:
        member_type = self.member.column_type
        # Timespan and region columns are rejected outright.
        if member_type in ("timespan", "region"):
            raise InvalidQueryError(
                f"Timespan or region column {self.member} may not be used in IN expressions."
            )
        # The tested expression and the projected column must agree on type.
        if member_type != self.column.column_type:
            raise InvalidQueryError(
                f"Column types for membership test {self} do not agree "
                f"({self.member.column_type}, {self.column.column_type})."
            )

        # ColumnSet is needed at runtime here; at module level it is only
        # imported under TYPE_CHECKING.
        # NOTE(review): presumably deferred to avoid a circular import -
        # confirm before moving it to the top of the file.
        from ._column_set import ColumnSet

        # Gather what `column` needs into a set seeded with the subquery's
        # dimensions, then check that nothing beyond the subquery was added.
        tree = self.query_tree
        needed = ColumnSet(tree.dimensions)
        self.column.gather_required_columns(needed)
        if needed.dimensions != tree.dimensions:
            raise InvalidQueryError(
                f"Column {self.column} requires dimensions {needed.dimensions}, "
                f"but query tree only has {tree.dimensions}."
            )
        if not needed.dataset_fields.keys() <= tree.datasets.keys():
            raise InvalidQueryError(
                f"Column {self.column} requires dataset types "
                f"{set(needed.dataset_fields.keys())} that are not present in query tree."
            )
        return self

622 

623 

# Leaf types that may appear as the operand of `LogicalNot`: every concrete
# leaf defined in this module except `LogicalNot` itself.
LogicalNotOperand: TypeAlias = Union[
    IsNull,
    Comparison,
    InContainer,
    InRange,
    InQuery,
]
# Union of all concrete leaf types, discriminated on each class's
# `predicate_type` literal field.
PredicateLeaf: TypeAlias = Annotated[
    Union[LogicalNotOperand, LogicalNot], pydantic.Field(discriminator="predicate_type")
]

# Type of `Predicate.operands`: a tuple of OR groups that are ANDed together,
# where each OR group is a tuple of leaves.
PredicateOperands: TypeAlias = tuple[tuple[PredicateLeaf, ...], ...]