Coverage for python/lsst/daf/butler/queries/tree/_predicate.py: 51%

235 statements  

« prev     ^ index     » next       coverage.py v7.4.3, created at 2024-03-05 11:36 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27 

28from __future__ import annotations 

29 

# Names exported by ``from ._predicate import *``; the concrete leaf classes
# are deliberately left out of this tuple.
__all__ = (
    "Predicate",
    "PredicateLeaf",
    "LogicalNotOperand",
    "PredicateOperands",
    "ComparisonOperator",
)

37 

38import itertools 

39from abc import ABC, abstractmethod 

40from typing import TYPE_CHECKING, Annotated, Iterable, Literal, TypeAlias, TypeVar, Union, cast, final 

41 

42import pydantic 

43 

44from ._base import InvalidQueryError, QueryTreeBase 

45from ._column_expression import ColumnExpression 

46 

47if TYPE_CHECKING: 

48 from ..visitors import PredicateVisitFlags, PredicateVisitor 

49 from ._column_set import ColumnSet 

50 from ._query_tree import QueryTree 

51 

52ComparisonOperator: TypeAlias = Literal["==", "!=", "<", ">", ">=", "<=", "overlaps"] 

53 

54 

55_L = TypeVar("_L") 

56_A = TypeVar("_A") 

57_O = TypeVar("_O") 

58 

59 

class PredicateLeafBase(QueryTreeBase, ABC):
    """Abstract base for the leaf nodes of a `Predicate` tree.

    The hierarchy is closed: every concrete subclass is `~typing.final` and
    is a member of the `PredicateLeaf` union.  Type annotations should
    normally name that union instead of this technically-open base class.
    """

    @abstractmethod
    def gather_required_columns(self, columns: ColumnSet) -> None:
        """Record the columns needed to evaluate this predicate leaf.

        Parameters
        ----------
        columns : `ColumnSet`
            Set of columns to modify in place.
        """
        raise NotImplementedError()

    def invert(self) -> PredicateLeaf:
        """Return a new leaf that is the logical not of this one.

        Returns
        -------
        inverted : `PredicateLeaf`
            A `LogicalNot` that wraps this leaf.
        """
        # The cast only affects static typing; `LogicalNot` overrides this
        # method, so ``self`` is never itself a `LogicalNot` here.
        operand = cast("LogicalNotOperand", self)
        return LogicalNot.model_construct(operand=operand)

    @abstractmethod
    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        """Dispatch to the visitor method that handles this kind of leaf.

        Parameters
        ----------
        visitor : `PredicateVisitor`
            Visitor to invoke a method on.
        flags : `PredicateVisitFlags`
            Flags that provide information about where this leaf appears in
            the larger predicate tree.

        Returns
        -------
        result : `object`
            Forwarded result from the visitor.
        """
        raise NotImplementedError()

102 

103 

@final
class Predicate(QueryTreeBase):
    """A boolean column expression.

    Notes
    -----
    Predicate is the only class representing a boolean column expression that
    should be used outside of this module (though the objects it nests appear
    in its serialized form and hence are not fully private).  It provides
    several `classmethod` factories for constructing those nested types inside
    a `Predicate` instance, and `PredicateVisitor` subclasses should be used
    to process them.

    The expression is held in conjunctive normal form: an empty ``operands``
    tuple represents `True`, and an empty inner tuple represents `False`.
    """

    operands: PredicateOperands
    """Nested tuple of operands, with outer items combined via AND and inner
    items combined via OR.
    """

    @property
    def column_type(self) -> Literal["bool"]:
        """A string enumeration value representing the type of the column
        expression.
        """
        return "bool"

    @classmethod
    def from_bool(cls, value: bool) -> Predicate:
        """Construct a predicate that always evaluates to `True` or `False`.

        Parameters
        ----------
        value : `bool`
            Value the predicate should evaluate to.

        Returns
        -------
        predicate : `Predicate`
            Predicate that evaluates to the given boolean value.
        """
        if value:
            # An AND over no terms is True.
            operands: PredicateOperands = ()
        else:
            # A single OR over no terms is False.
            operands = ((),)
        return cls.model_construct(operands=operands)

    @classmethod
    def compare(cls, a: ColumnExpression, operator: ComparisonOperator, b: ColumnExpression) -> Predicate:
        """Construct a predicate representing a binary comparison between
        two non-boolean column expressions.

        Parameters
        ----------
        a : `ColumnExpression`
            First column expression in the comparison.
        operator : `str`
            Enumerated string representing the comparison operator to apply.
            May be any of "==", "!=", "<", ">", "<=", ">=", or "overlaps".
        b : `ColumnExpression`
            Second column expression in the comparison.

        Returns
        -------
        predicate : `Predicate`
            Predicate representing the comparison.
        """
        leaf = Comparison(a=a, operator=operator, b=b)
        return cls._from_leaf(leaf)

    @classmethod
    def is_null(cls, operand: ColumnExpression) -> Predicate:
        """Construct a predicate that tests whether a column expression is
        NULL.

        Parameters
        ----------
        operand : `ColumnExpression`
            Column expression to test.

        Returns
        -------
        predicate : `Predicate`
            Predicate representing the NULL check.
        """
        leaf = IsNull(operand=operand)
        return cls._from_leaf(leaf)

    @classmethod
    def in_container(cls, member: ColumnExpression, container: Iterable[ColumnExpression]) -> Predicate:
        """Construct a predicate that tests whether one column expression is
        a member of a container of other column expressions.

        Parameters
        ----------
        member : `ColumnExpression`
            Column expression that may be a member of the container.
        container : `~collections.abc.Iterable` [ `ColumnExpression` ]
            Container of column expressions to test for membership in.

        Returns
        -------
        predicate : `Predicate`
            Predicate representing the membership test.
        """
        # Materialize the iterable so the leaf holds an immutable tuple.
        items = tuple(container)
        return cls._from_leaf(InContainer(member=member, container=items))

    @classmethod
    def in_range(
        cls, member: ColumnExpression, start: int = 0, stop: int | None = None, step: int = 1
    ) -> Predicate:
        """Construct a predicate that tests whether an integer column
        expression is part of a strided range.

        Parameters
        ----------
        member : `ColumnExpression`
            Column expression that may be a member of the range.
        start : `int`, optional
            Beginning of the range, inclusive.
        stop : `int` or `None`, optional
            End of the range, exclusive.
        step : `int`, optional
            Offset between values in the range.

        Returns
        -------
        predicate : `Predicate`
            Predicate representing the membership test.
        """
        leaf = InRange(member=member, start=start, stop=stop, step=step)
        return cls._from_leaf(leaf)

    @classmethod
    def in_query(cls, member: ColumnExpression, column: ColumnExpression, query_tree: QueryTree) -> Predicate:
        """Construct a predicate that tests whether a column expression is
        present in a single-column projection of a query tree.

        Parameters
        ----------
        member : `ColumnExpression`
            Column expression that may be present in the query.
        column : `ColumnExpression`
            Column to project from the query.
        query_tree : `QueryTree`
            Query tree to select from.

        Returns
        -------
        predicate : `Predicate`
            Predicate representing the membership test.
        """
        leaf = InQuery(member=member, column=column, query_tree=query_tree)
        return cls._from_leaf(leaf)

    def gather_required_columns(self, columns: ColumnSet) -> None:
        """Add any columns required to evaluate this predicate to the given
        column set.

        Parameters
        ----------
        columns : `ColumnSet`
            Set of columns to modify in place.
        """
        # Flatten the AND-of-ORs nesting; leaves are visited in stored order.
        for leaf in itertools.chain.from_iterable(self.operands):
            leaf.gather_required_columns(columns)

    def logical_and(self, *args: Predicate) -> Predicate:
        """Construct a predicate representing the logical AND of this predicate
        and one or more others.

        Parameters
        ----------
        *args : `Predicate`
            Other predicates.

        Returns
        -------
        predicate : `Predicate`
            Predicate representing the logical AND.
        """
        combined = self.operands
        for other in args:
            combined = self._impl_and(combined, other.operands)
        # An empty OR group is False, and False AND anything is False, so
        # collapse the whole expression to the canonical False.
        if any(not or_group for or_group in combined):
            combined = ((),)
        return Predicate.model_construct(operands=combined)

    def logical_or(self, *args: Predicate) -> Predicate:
        """Construct a predicate representing the logical OR of this predicate
        and one or more others.

        Parameters
        ----------
        *args : `Predicate`
            Other predicates.

        Returns
        -------
        predicate : `Predicate`
            Predicate representing the logical OR.
        """
        combined = self.operands
        for other in args:
            combined = self._impl_or(combined, other.operands)
        return Predicate.model_construct(operands=combined)

    def logical_not(self) -> Predicate:
        """Construct a predicate representing the logical NOT of this
        predicate.

        Returns
        -------
        predicate : `Predicate`
            Predicate representing the logical NOT.
        """
        # NOT (g1 AND g2 AND ...) == (NOT g1) OR (NOT g2) OR ..., so start
        # from False and OR in the negation of each group.
        negated: PredicateOperands = ((),)
        for or_group in self.operands:
            # NOT (x OR y OR ...) == (NOT x) AND (NOT y) AND ..., so start
            # from True and AND in each inverted leaf.
            negated_group: PredicateOperands = ()
            for leaf in or_group:
                negated_group = self._impl_and(negated_group, ((leaf.invert(),),))
            negated = self._impl_or(negated, negated_group)
        return Predicate.model_construct(operands=negated)

    def __str__(self) -> str:
        # An AND over no terms is True.
        if not self.operands:
            return "True"
        # OR groups only need parentheses when they sit next to other AND
        # terms.
        needs_parens = len(self.operands) > 1
        and_terms = []
        for or_group in self.operands:
            if not or_group:
                and_terms.append("False")
            elif len(or_group) == 1:
                and_terms.append(str(or_group[0]))
            else:
                or_str = " OR ".join(str(operand) for operand in or_group)
                and_terms.append(f"({or_str})" if needs_parens else or_str)
        return " AND ".join(and_terms)

    def visit(self, visitor: PredicateVisitor[_A, _O, _L]) -> _A:
        """Invoke the visitor interface.

        Parameters
        ----------
        visitor : `PredicateVisitor`
            Visitor to invoke a method on.

        Returns
        -------
        result : `object`
            Forwarded result from the visitor.
        """
        return visitor._visit_logical_and(self.operands)

    @classmethod
    def _from_leaf(cls, leaf: PredicateLeaf) -> Predicate:
        # Wrap a single leaf as a one-element OR group.
        return cls._from_or_group((leaf,))

    @classmethod
    def _from_or_group(cls, or_group: tuple[PredicateLeaf, ...]) -> Predicate:
        # Wrap a single OR group as the only term of the outer AND.
        return Predicate.model_construct(operands=(or_group,))

    @classmethod
    def _impl_and(cls, a: PredicateOperands, b: PredicateOperands) -> PredicateOperands:
        # We could simplify cases where both sides have some of the same leaf
        # expressions; even using 'is' tests would simplify some cases where
        # converting to conjunctive normal form twice leads to a lot of
        # duplication, e.g. NOT ((A AND B) OR (C AND D)) or any kind of
        # double-negation.  Right now those cases seem pathological enough to
        # be not worth our time.
        if a is b:
            # X AND X is just X.
            return a
        return a + b

    @classmethod
    def _impl_or(cls, a: PredicateOperands, b: PredicateOperands) -> PredicateOperands:
        # Same comment re simplification as in _impl_and applies here.
        # Distribute OR over AND: every OR group of ``a`` is merged with every
        # OR group of ``b``.
        return tuple(a_group + b_group for a_group, b_group in itertools.product(a, b))

376 

377 

@final
class LogicalNot(PredicateLeafBase):
    """A boolean column expression that inverts its operand."""

    # Discriminator value for the `PredicateLeaf` union.
    predicate_type: Literal["not"] = "not"

    operand: LogicalNotOperand
    """Upstream boolean expression to invert."""

    def gather_required_columns(self, columns: ColumnSet) -> None:
        # Docstring inherited.
        # Negation needs exactly the columns its operand needs.
        self.operand.gather_required_columns(columns)

    def __str__(self) -> str:
        return f"NOT {self.operand}"

    def invert(self) -> LogicalNotOperand:
        # Docstring inherited.
        # Double negation cancels: hand back the wrapped leaf itself rather
        # than wrapping this object in another `LogicalNot`.
        return self.operand

    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        # Docstring inherited.
        return visitor._visit_logical_not(self.operand, flags)

401 

402 

@final
class IsNull(PredicateLeafBase):
    """A boolean column expression that tests whether its operand is NULL."""

    # Discriminator value for the `PredicateLeaf` union.
    predicate_type: Literal["is_null"] = "is_null"

    operand: ColumnExpression
    """Upstream expression to test."""

    def gather_required_columns(self, columns: ColumnSet) -> None:
        # Docstring inherited.
        # The NULL test needs exactly the columns its operand needs.
        self.operand.gather_required_columns(columns)

    def __str__(self) -> str:
        return f"{self.operand} IS NULL"

    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        # Docstring inherited.
        return visitor.visit_is_null(self.operand, flags)

422 

423 

@final
class Comparison(PredicateLeafBase):
    """A boolean column expression formed by comparing two non-boolean
    expressions.
    """

    # Discriminator value for the `PredicateLeaf` union.
    predicate_type: Literal["comparison"] = "comparison"

    a: ColumnExpression
    """Left-hand side expression for the comparison."""

    b: ColumnExpression
    """Right-hand side expression for the comparison."""

    operator: ComparisonOperator
    """Comparison operator."""

    def gather_required_columns(self, columns: ColumnSet) -> None:
        # Docstring inherited.
        # Both sides of the comparison contribute columns.
        for side in (self.a, self.b):
            side.gather_required_columns(columns)

    def __str__(self) -> str:
        # ``upper()`` only changes the textual "overlaps" operator; the
        # symbolic operators are unaffected.
        return f"{self.a} {self.operator.upper()} {self.b}"

    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        # Docstring inherited.
        return visitor.visit_comparison(self.a, self.operator, self.b, flags)

    @pydantic.model_validator(mode="after")
    def _validate_column_types(self) -> Comparison:
        # Both sides must have the same column type before the operator is
        # considered at all.
        column_type = self.a.column_type
        if column_type != self.b.column_type:
            raise InvalidQueryError(
                f"Column types for comparison {self} do not agree "
                f"({self.a.column_type}, {self.b.column_type})."
            )
        operator = self.operator
        if operator in ("==", "!="):
            # Equality tests are accepted for every column type.
            return self
        if operator in ("<", ">", ">=", "<=") and column_type in ("int", "string", "float", "datetime"):
            # Ordering tests are accepted only for these column types.
            return self
        if operator == "overlaps" and column_type in ("region", "timespan"):
            # Overlap tests are accepted only for regions and timespans.
            return self
        raise InvalidQueryError(
            f"Invalid column type {self.a.column_type} for operator {self.operator!r}."
        )

472 

473 

@final
class InContainer(PredicateLeafBase):
    """A boolean column expression that tests whether one expression is a
    member of an explicit sequence of other expressions.
    """

    # Discriminator value for the `PredicateLeaf` union.
    predicate_type: Literal["in_container"] = "in_container"

    member: ColumnExpression
    """Expression to test for membership."""

    container: tuple[ColumnExpression, ...]
    """Expressions representing the elements of the container."""

    def gather_required_columns(self, columns: ColumnSet) -> None:
        # Docstring inherited.
        # The tested expression comes first, then each container element in
        # stored order.
        for expression in (self.member, *self.container):
            expression.gather_required_columns(columns)

    def __str__(self) -> str:
        return f"{self.member} IN [{', '.join(str(item) for item in self.container)}]"

    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        # Docstring inherited.
        return visitor.visit_in_container(self.member, self.container, flags)

    @pydantic.model_validator(mode="after")
    def _validate(self) -> InContainer:
        member_type = self.member.column_type
        # Timespan and region members are rejected outright.
        if member_type in ("timespan", "region"):
            raise InvalidQueryError(
                f"Timespan or region column {self.member} may not be used in IN expressions."
            )
        # Every element must share the member's column type; an empty
        # container passes this check.
        if any(item.column_type != member_type for item in self.container):
            raise InvalidQueryError(f"Column types for membership test {self} do not agree.")
        return self

510 

511 

@final
class InRange(PredicateLeafBase):
    """A boolean column expression that tests whether its expression is
    included in an integer range.
    """

    # Discriminator value for the `PredicateLeaf` union.
    predicate_type: Literal["in_range"] = "in_range"

    member: ColumnExpression
    """Expression to test for membership."""

    start: int = 0
    """Inclusive lower bound for the range."""

    stop: int | None = None
    """Exclusive upper bound for the range."""

    step: int = 1
    """Difference between values in the range."""

    def gather_required_columns(self, columns: ColumnSet) -> None:
        # Docstring inherited.
        # The bounds are plain integers, so only the member needs columns.
        self.member.gather_required_columns(columns)

    def __str__(self) -> str:
        # Slice-like rendering: a zero start and a `None` stop are printed as
        # empty, and the step is shown only when it is not 1.
        s = f"{self.start if self.start else ''}:{self.stop if self.stop is not None else ''}"
        if self.step != 1:
            s = f"{s}:{self.step}"
        return f"{self.member} IN {s}"

    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        # Docstring inherited.
        return visitor.visit_in_range(self.member, self.start, self.stop, self.step, flags)

    @pydantic.model_validator(mode="after")
    def _validate(self) -> InRange:
        # Checks run in this order, so the first failing one decides which
        # message the caller sees.
        if self.member.column_type != "int":
            raise InvalidQueryError(f"Column {self.member} is not an integer.")
        if self.step < 1:
            raise InvalidQueryError("Range step must be >= 1.")
        # A `None` stop is exempt from the bound check.
        if self.stop is not None and self.stop < self.start:
            raise InvalidQueryError("Range stop must be >= start.")
        return self

554 

555 

@final
class InQuery(PredicateLeafBase):
    """A boolean column expression that tests whether its expression is
    included in a single-column projection of a relation.

    This is primarily intended to be used on dataset ID columns, but it may
    be useful for other columns as well.
    """

    # Discriminator value for the `PredicateLeaf` union.
    predicate_type: Literal["in_query"] = "in_query"

    member: ColumnExpression
    """Expression to test for membership."""

    column: ColumnExpression
    """Expression to extract from `query_tree`."""

    query_tree: QueryTree
    """Relation whose rows from `column` represent the container."""

    def gather_required_columns(self, columns: ColumnSet) -> None:
        # Docstring inherited.
        # We're only gathering columns from the query_tree this predicate is
        # attached to, not `self.column`, which belongs to `self.query_tree`.
        self.member.gather_required_columns(columns)

    def __str__(self) -> str:
        return f"{self.member} IN (query).{self.column}"

    def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
        # Docstring inherited.
        return visitor.visit_in_query_tree(self.member, self.column, self.query_tree, flags)

    @pydantic.model_validator(mode="after")
    def _validate_column_types(self) -> InQuery:
        member_type = self.member.column_type
        # Timespan and region members are rejected outright.
        if member_type in ("timespan", "region"):
            raise InvalidQueryError(
                f"Timespan or region column {self.member} may not be used in IN expressions."
            )
        # The member and the projected column must have the same type.
        if member_type != self.column.column_type:
            raise InvalidQueryError(
                f"Column types for membership test {self} do not agree "
                f"({self.member.column_type}, {self.column.column_type})."
            )

        # Imported here rather than at module level, where `ColumnSet` is
        # only imported under ``TYPE_CHECKING``.
        from ._column_set import ColumnSet

        # Start from the nested tree's own dimensions and let the projected
        # column add whatever it needs; any change in the dimensions means
        # the column asks for something the tree does not have.
        columns_required_in_tree = ColumnSet(self.query_tree.dimensions)
        self.column.gather_required_columns(columns_required_in_tree)
        if columns_required_in_tree.dimensions != self.query_tree.dimensions:
            raise InvalidQueryError(
                f"Column {self.column} requires dimensions {columns_required_in_tree.dimensions}, "
                f"but query tree only has {self.query_tree.dimensions}."
            )
        # Every dataset type the column reads from must be present in the
        # nested tree.
        if not columns_required_in_tree.dataset_fields.keys() <= self.query_tree.datasets.keys():
            raise InvalidQueryError(
                f"Column {self.column} requires dataset types "
                f"{set(columns_required_in_tree.dataset_fields.keys())} that are not present in query tree."
            )
        return self

616 

617 

# Leaf types that may be wrapped by `LogicalNot`: every concrete leaf except
# `LogicalNot` itself.
LogicalNotOperand: TypeAlias = Union[
    IsNull,
    Comparison,
    InContainer,
    InRange,
    InQuery,
]

# Any concrete leaf, discriminated on the ``predicate_type`` field that every
# leaf class declares.
PredicateLeaf: TypeAlias = Annotated[
    Union[LogicalNotOperand, LogicalNot], pydantic.Field(discriminator="predicate_type")
]

# Storage form of `Predicate.operands`: the outer tuple is combined with AND
# and each inner tuple with OR.
PredicateOperands: TypeAlias = tuple[tuple[PredicateLeaf, ...], ...]