Coverage for python / lsst / daf / butler / queries / expressions / parser / exprTree.py: 27%

269 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-18 08:43 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <https://www.gnu.org/licenses/>. 

27 

28"""Module which defines classes for intermediate representation of the 

29expression tree produced by parser. 

30 

31The purpose of the intermediate representation is to be able to generate 

32same expression as a part of SQL statement with the minimal changes. We 

33will need to be able to replace identifiers in original expression with 

34database-specific identifiers but everything else will probably be sent 

35to database directly. 

36""" 

37 

38from __future__ import annotations 

39 

# Public names of this module, i.e. what ``from ... import *`` exposes.
# NOTE(review): `BindName`, `GlobNode`, `LiteralNode` and `UuidLiteral` are
# defined in this module but are not listed here -- confirm that leaving
# them out is intentional.
__all__ = [
    "BinaryOp",
    "BoxNode",
    "CircleNode",
    "FunctionCall",
    "Identifier",
    "IsIn",
    "Node",
    "NumericLiteral",
    "Parens",
    "PointNode",
    "PolygonNode",
    "RangeLiteral",
    "RegionNode",
    "StringLiteral",
    "TimeLiteral",
    "TupleNode",
    "UnaryOp",
    "function_call",
]

60 

61# ------------------------------- 

62# Imports of standard modules -- 

63# ------------------------------- 

64from abc import ABC, abstractmethod 

65from typing import TYPE_CHECKING, Any 

66from uuid import UUID 

67 

68# ----------------------------- 

69# Imports for other modules -- 

70# ----------------------------- 

71 

72# ---------------------------------- 

73# Local non-exported definitions -- 

74# ---------------------------------- 

75 

76if TYPE_CHECKING: 

77 import astropy.time 

78 

79 from .treeVisitor import TreeVisitor 

80 

81 

def _strip_parens(expression: Node) -> Node:
    """Return the expression found after removing all enclosing parentheses.

    Parameters
    ----------
    expression : `Node`
        Expression which may be wrapped in any number of `Parens` nodes.

    Returns
    -------
    inner : `Node`
        The first nested node that is not a `Parens` instance; this is
        ``expression`` itself when it is not parenthesized.
    """
    inner = expression
    # Peel off one layer of parentheses at a time.
    while isinstance(inner, Parens):
        inner = inner.expr
    return inner

87 

88 

89# ------------------------ 

90# Exported definitions -- 

91# ------------------------ 

92 

93 

class Node(ABC):
    """Base class for all IR nodes of the expression tree.

    Every node keeps a tuple of its sub-nodes, which lets visiting code walk
    the whole tree without knowing the concrete type of each node.

    Parameters
    ----------
    children : `tuple` of `Node`, optional
        Sub-nodes of this node, may be empty or `None`.
    """

    def __init__(self, children: tuple[Node, ...] | None = None):
        # `None` and any empty sequence both mean "no children".
        if not children:
            children = ()
        self.children = tuple(children)

    @abstractmethod
    def visit(self, visitor: TreeVisitor) -> Any:
        """Implement Visitor pattern for parsed tree.

        Parameters
        ----------
        visitor : `TreeVisitor`
            Instance of visitor type.
        """

120 

121 

class LiteralNode(Node):
    """Intermediate base class for nodes representing literals of any kind.

    It adds nothing to `Node`; it exists so that literal nodes can be
    recognized with a single `isinstance` check.
    """

124 

125 

class BinaryOp(Node):
    """Node representing a binary operator.

    One class covers every binary operator, arithmetic and boolean alike.

    Parameters
    ----------
    lhs : `Node`
        Left-hand side of the operation.
    op : `str`
        Operator name, e.g. '+', 'OR'.
    rhs : `Node`
        Right-hand side of the operation.
    """

    def __init__(self, lhs: Node, op: str, rhs: Node):
        # Both operands are registered as children of this node.
        super().__init__((lhs, rhs))
        self.lhs, self.op, self.rhs = lhs, op, rhs

    def visit(self, visitor: TreeVisitor) -> Any:
        # Docstring inherited from Node.visit
        # Operands are visited left to right, then the operator itself.
        left = self.lhs.visit(visitor)
        right = self.rhs.visit(visitor)
        return visitor.visitBinaryOp(self.op, left, right, self)

    def __str__(self) -> str:
        return f"{self.lhs} {self.op} {self.rhs}"

156 

157 

class UnaryOp(Node):
    """Node representing a unary operator.

    One class covers every unary operator, arithmetic and boolean alike.

    Parameters
    ----------
    op : `str`
        Operator name, e.g. '+', 'NOT'.
    operand : `Node`
        Operand.
    """

    def __init__(self, op: str, operand: Node):
        # The operand is the only child of this node.
        super().__init__((operand,))
        self.op, self.operand = op, operand

    def visit(self, visitor: TreeVisitor) -> Any:
        # Docstring inherited from Node.visit
        # The operand is visited first, then the operator itself.
        visited = self.operand.visit(visitor)
        return visitor.visitUnaryOp(self.op, visited, self)

    def __str__(self) -> str:
        return f"{self.op} {self.operand}"

184 

185 

class StringLiteral(LiteralNode):
    """Node representing a string literal.

    Parameters
    ----------
    value : `str`
        Literal value.
    """

    def __init__(self, value: str):
        # Literals have no children.
        super().__init__()
        self.value = value

    def visit(self, visitor: TreeVisitor) -> Any:
        # Docstring inherited from Node.visit
        literal = self.value
        return visitor.visitStringLiteral(literal, self)

    def __str__(self) -> str:
        # The value is wrapped in single quotes as is, without escaping.
        return f"'{self.value}'"

205 

206 

class TimeLiteral(LiteralNode):
    """Node representing a time literal.

    Parameters
    ----------
    value : `astropy.time.Time`
        Literal time value.
    """

    def __init__(self, value: astropy.time.Time):
        # Literals have no children.
        super().__init__()
        self.value = value

    def visit(self, visitor: TreeVisitor) -> Any:
        # Docstring inherited from Node.visit
        literal = self.value
        return visitor.visitTimeLiteral(literal, self)

    def __str__(self) -> str:
        # Rendered as the quoted value with a ``T`` prefix.
        return f"T'{self.value}'"

226 

227 

class NumericLiteral(LiteralNode):
    """Node representing a numeric literal.

    The literal is not converted to a number; its text representation is
    stored as is.

    Parameters
    ----------
    value : `str`
        Literal value.
    """

    def __init__(self, value: str):
        # Literals have no children.
        super().__init__()
        self.value = value

    def visit(self, visitor: TreeVisitor) -> Any:
        # Docstring inherited from Node.visit
        literal = self.value
        return visitor.visitNumericLiteral(literal, self)

    def __str__(self) -> str:
        return f"{self.value}"

250 

251 

class UuidLiteral(LiteralNode):
    """Node representing a UUID literal.

    Parameters
    ----------
    value : `UUID`
        Literal value.
    """

    def __init__(self, value: UUID):
        # Literals have no children.
        super().__init__()
        self.value = value

    def visit(self, visitor: TreeVisitor) -> Any:
        # Docstring inherited from Node.visit
        literal = self.value
        return visitor.visitUuidLiteral(literal, self)

    def __str__(self) -> str:
        # Rendered as a call of the UUID() function with a quoted argument.
        return "UUID('{}')".format(self.value)

271 

272 

class Identifier(Node):
    """Node representing an identifier.

    The value of the identifier is its name, which may contain zero, one,
    or two dot characters.

    Parameters
    ----------
    name : `str`
        Identifier name.
    """

    def __init__(self, name: str):
        # Identifiers have no children.
        super().__init__()
        self.name = name

    def visit(self, visitor: TreeVisitor) -> Any:
        # Docstring inherited from Node.visit
        identifier = self.name
        return visitor.visitIdentifier(identifier, self)

    def __str__(self) -> str:
        return f"{self.name}"

295 

296 

class BindName(Node):
    """Node representing a bind name.

    The value of the bind is its name, which is a simple identifier.

    Parameters
    ----------
    name : `str`
        Bind name.
    """

    def __init__(self, name: str):
        # Bind names have no children.
        super().__init__()
        self.name = name

    def visit(self, visitor: TreeVisitor) -> Any:
        # Docstring inherited from Node.visit
        bind = self.name
        return visitor.visitBind(bind, self)

    def __str__(self) -> str:
        # Rendered as the name with a colon prefix.
        return f":{self.name}"

318 

319 

class RangeLiteral(LiteralNode):
    """Node representing a range literal appearing in an `IN` list.

    A range literal defines a range of integer numbers by its start and its
    (inclusive) end, plus an optional stride which defaults to 1.

    Parameters
    ----------
    start : `int`
        Start value of a range.
    stop : `int`
        End value of a range, inclusive, same or higher than ``start``.
    stride : `int` or `None`, optional
        Stride value, must be positive. `None` means that the stride was not
        specified; consumers are supposed to treat it the same way as
        stride=1, but some of them may want to know that the stride was
        missing from the literal.
    """

    def __init__(self, start: int, stop: int, stride: int | None = None):
        # Literals have no children.
        super().__init__()
        self.start, self.stop, self.stride = start, stop, stride

    def visit(self, visitor: TreeVisitor) -> Any:
        # Docstring inherited from Node.visit
        return visitor.visitRangeLiteral(self.start, self.stop, self.stride, self)

    def __str__(self) -> str:
        text = f"{self.start}..{self.stop}"
        # The stride suffix is only added for a truthy stride, so a missing
        # stride (`None`) leaves the bare ``start..stop`` form.
        if self.stride:
            text += f":{self.stride}"
        return text

353 

354 

class IsIn(Node):
    """Node representing an IN or NOT IN expression.

    Parameters
    ----------
    lhs : `Node`
        Left-hand side of the operation.
    values : `list` of `Node`
        List of values on the right side.
    not_in : `bool`
        If `True` then it is NOT IN expression, otherwise it is IN expression.

    Raises
    ------
    TypeError
        Raised if any of the values, after stripping parentheses, is not a
        literal, a bind name, or an identifier.
    """

    def __init__(self, lhs: Node, values: list[Node], not_in: bool = False):
        # Only literals and binds are accepted as values; simple identifiers
        # are allowed too because they can be used as binds.
        for value in values:
            stripped = _strip_parens(value)
            if isinstance(stripped, (LiteralNode, BindName, Identifier)):
                continue
            raise TypeError(f"Unsupported type of expression in IN operator: {stripped}")
        # Children are the left-hand side followed by all values.
        super().__init__((lhs, *values))
        self.lhs, self.values, self.not_in = lhs, values, not_in

    def visit(self, visitor: TreeVisitor) -> Any:
        # Docstring inherited from Node.visit
        # The left-hand side is visited before the values.
        visited_lhs = self.lhs.visit(visitor)
        visited_values = [item.visit(visitor) for item in self.values]
        return visitor.visitIsIn(visited_lhs, visited_values, self.not_in, self)

    def __str__(self) -> str:
        joined = ", ".join(map(str, self.values))
        prefix = "NOT " if self.not_in else ""
        return f"{self.lhs} {prefix}IN ({joined})"

392 

393 

class Parens(Node):
    """Node representing a parenthesized expression.

    Parameters
    ----------
    expr : `Node`
        Expression inside parentheses.
    """

    def __init__(self, expr: Node):
        # The wrapped expression is the only child of this node.
        super().__init__((expr,))
        self.expr = expr

    def visit(self, visitor: TreeVisitor) -> Any:
        # Docstring inherited from Node.visit
        # The inner expression is visited first.
        inner = self.expr.visit(visitor)
        return visitor.visitParens(inner, self)

    def __str__(self) -> str:
        return f"({self.expr})"

414 

415 

class TupleNode(Node):
    """Node representing a tuple, a sequence of parenthesized expressions.

    Tuples are used to represent time ranges. For now the parser supports
    tuples with two items, though this class can represent a different
    number of items in a sequence.

    Parameters
    ----------
    items : `tuple` of `Node`
        Expressions inside parentheses.
    """

    def __init__(self, items: tuple[Node, ...]):
        # All items are children of this node.
        super().__init__(items)
        self.items = items

    def visit(self, visitor: TreeVisitor) -> Any:
        # Docstring inherited from Node.visit
        # Items are visited in order and their results passed as a tuple.
        visited = tuple(element.visit(visitor) for element in self.items)
        return visitor.visitTupleNode(visited, self)

    def __str__(self) -> str:
        return "(" + ", ".join(map(str, self.items)) + ")"

441 

442 

class FunctionCall(Node):
    """Node representing a function call.

    Parameters
    ----------
    function : `str`
        Name of the function.
    args : `list` [ `Node` ]
        Arguments passed to function.
    """

    def __init__(self, function: str, args: list[Node]):
        # All arguments are children of this node.
        super().__init__(tuple(args))
        self.name = function
        # A copy is kept so that later changes to the caller's sequence do
        # not affect this node.
        self.args = args[:]

    def visit(self, visitor: TreeVisitor) -> Any:
        # Docstring inherited from Node.visit
        # Arguments are visited in order before the call itself.
        visited = [argument.visit(visitor) for argument in self.args]
        return visitor.visitFunctionCall(self.name, visited, self)

    def __str__(self) -> str:
        joined = ", ".join(map(str, self.args))
        return "{}({})".format(self.name, joined)

467 

468 

class PointNode(Node):
    """Node representing a point, an (ra, dec) pair.

    Parameters
    ----------
    ra : `Node`
        Node representing ra value.
    dec : `Node`
        Node representing dec value.
    """

    def __init__(self, ra: Node, dec: Node):
        # Both coordinates are children of this node.
        super().__init__((ra, dec))
        self.ra, self.dec = ra, dec

    def visit(self, visitor: TreeVisitor) -> Any:
        # Docstring inherited from Node.visit
        # Coordinates are visited in (ra, dec) order.
        visited_ra = self.ra.visit(visitor)
        visited_dec = self.dec.visit(visitor)
        return visitor.visitPointNode(visited_ra, visited_dec, self)

    def __str__(self) -> str:
        return "POINT({}, {})".format(self.ra, self.dec)

493 

494 

class CircleNode(Node):
    """Node representing a circle, an (ra, dec, radius) triple.

    Parameters
    ----------
    ra : `Node`
        Node representing circle center ra value.
    dec : `Node`
        Node representing circle center dec value.
    radius : `Node`
        Node representing circle radius value.
    """

    def __init__(self, ra: Node, dec: Node, radius: Node):
        # Center coordinates and radius are all children of this node.
        super().__init__((ra, dec, radius))
        self.ra, self.dec, self.radius = ra, dec, radius

    def visit(self, visitor: TreeVisitor) -> Any:
        # Docstring inherited from Node.visit
        # Parameters are visited in (ra, dec, radius) order.
        visited_ra = self.ra.visit(visitor)
        visited_dec = self.dec.visit(visitor)
        visited_radius = self.radius.visit(visitor)
        return visitor.visitCircleNode(visited_ra, visited_dec, visited_radius, self)

    def __str__(self) -> str:
        return "CIRCLE({}, {}, {})".format(self.ra, self.dec, self.radius)

523 

524 

class BoxNode(Node):
    """Node representing box region in ADQL notation (ra, dec, width, height).

    Parameters
    ----------
    ra : `Node`
        Node representing box center ra value.
    dec : `Node`
        Node representing box center dec value.
    width : `Node`
        Node representing box ra width value.
    height : `Node`
        Node representing box dec height value.
    """

    def __init__(self, ra: Node, dec: Node, width: Node, height: Node):
        # Center coordinates and both sizes are children of this node.
        super().__init__((ra, dec, width, height))
        self.ra, self.dec = ra, dec
        self.width, self.height = width, height

    def visit(self, visitor: TreeVisitor) -> Any:
        # Docstring inherited from Node.visit
        # Parameters are visited in (ra, dec, width, height) order.
        visited_ra = self.ra.visit(visitor)
        visited_dec = self.dec.visit(visitor)
        visited_width = self.width.visit(visitor)
        visited_height = self.height.visit(visitor)
        return visitor.visitBoxNode(visited_ra, visited_dec, visited_width, visited_height, self)

    def __str__(self) -> str:
        return "BOX({}, {}, {}, {})".format(self.ra, self.dec, self.width, self.height)

557 

558 

class PolygonNode(Node):
    """Node representing polygon region in ADQL notation.

    Parameters
    ----------
    vertices : `list` [`tuple` [`Node`, `Node`]]
        Vertices of the polygon, each one a pair of nodes for the ra and
        dec values.
    """

    def __init__(self, vertices: list[tuple[Node, Node]]):
        # Children are the vertex pairs flattened into a single tuple:
        # (ra0, dec0, ra1, dec1, ...).
        flattened = sum(vertices, ())
        super().__init__(flattened)
        self.vertices = vertices

    def visit(self, visitor: TreeVisitor) -> Any:
        # Docstring inherited from Node.visit
        # For every vertex ra is visited before dec.
        visited = [(first.visit(visitor), second.visit(visitor)) for first, second in self.vertices]
        return visitor.visitPolygonNode(visited, self)

    def __str__(self) -> str:
        # Uses the flattened children, so coordinates come out as a flat list.
        joined = ", ".join(map(str, self.children))
        return f"POLYGON({joined})"

580 

581 

class RegionNode(Node):
    """Node representing region using IVOA SIAv2 POS notation.

    Parameters
    ----------
    pos : `Node`
        IVOA SIAv2 POS string representation of a region.
    """

    def __init__(self, pos: Node):
        # The POS node is the only child of this node.
        super().__init__((pos,))
        self.pos = pos

    def visit(self, visitor: TreeVisitor) -> Any:
        # Docstring inherited from Node.visit
        # The POS node is visited first.
        visited = self.pos.visit(visitor)
        return visitor.visitRegionNode(visited, self)

    def __str__(self) -> str:
        return "REGION({})".format(self.pos)

602 

603 

class GlobNode(Node):
    """Node representing a call to GLOB(expression, pattern) function.

    Parameters
    ----------
    expression : `Node`
        Node representing expression matched against pattern, typically a
        column-like thing.
    pattern : `Node`
        Node representing a pattern, this must be either a `StringLiteral` or
        `BindName`.
    """

    def __init__(self, expression: Identifier, pattern: StringLiteral | BindName):
        # Children are the expression followed by the pattern.
        super().__init__((expression, pattern))
        self.expression, self.pattern = expression, pattern

    def visit(self, visitor: TreeVisitor) -> Any:
        # Docstring inherited from Node.visit
        # The expression is visited before the pattern.
        visited_expression = self.expression.visit(visitor)
        visited_pattern = self.pattern.visit(visitor)
        return visitor.visitGlobNode(visited_expression, visited_pattern, self)

    def __str__(self) -> str:
        return "GLOB({}, {})".format(self.expression, self.pattern)

630 

631 

def _is_numeric_expression(node: Node) -> bool:
    """Check whether a node is accepted where a numeric value is expected.

    Parameters
    ----------
    node : `Node`
        Node to check.

    Returns
    -------
    accepted : `bool`
        `True` if the node is a numeric literal, a bind name, an identifier,
        or a binary/unary operation (expressions are supported too).
    """
    return isinstance(node, NumericLiteral | BindName | Identifier | BinaryOp | UnaryOp)


def function_call(function: str, args: list[Node]) -> Node:
    """Return node representing function calls.

    Parameters
    ----------
    function : `str`
        Name of the function, matched case-insensitively against the names
        of the supported functions.
    args : `list` [ `Node` ]
        Arguments passed to function.

    Returns
    -------
    node : `Node`
        A specialized node (`PointNode`, `CircleNode`, `BoxNode`,
        `PolygonNode`, `RegionNode`, `GlobNode` or `UuidLiteral`) for the
        supported functions, a generic `FunctionCall` for everything else.

    Raises
    ------
    ValueError
        Raised if a supported function receives a wrong number of arguments,
        if POINT/CIRCLE/BOX/POLYGON/REGION arguments have unsupported node
        types, or if the UUID string cannot be parsed.
    TypeError
        Raised if GLOB or UUID arguments have unsupported node types.

    Notes
    -----
    Our parser supports arbitrary functions with arbitrary list of parameters.
    For now we only need a small set of functions, and to simplify
    implementation of visitors we define special type of node for each
    supported function. This method makes a special `Node` instance for those
    supported functions, and generic `FunctionCall` instance for all other
    functions. Tree visitors will most likely raise an error when visiting
    `FunctionCall` nodes.
    """
    function_name = function.upper()

    if function_name == "POINT":
        if len(args) != 2:
            raise ValueError("POINT requires two arguments (ra, dec)")
        return PointNode(*args)

    if function_name == "CIRCLE":
        if len(args) != 3:
            raise ValueError("CIRCLE requires three arguments (ra, dec, radius)")
        # Check types of arguments, we want to support expressions too.
        for name, arg in zip(("ra", "dec", "radius"), args, strict=True):
            if not _is_numeric_expression(arg):
                raise ValueError(f"CIRCLE {name} argument must be either numeric expression or bind value.")
        return CircleNode(*args)

    if function_name == "BOX":
        if len(args) != 4:
            # The message used to name CIRCLE here, a copy-paste slip.
            raise ValueError("BOX requires four arguments (ra, dec, width, height)")
        # Check types of arguments, we want to support expressions too.
        for name, arg in zip(("ra", "dec", "width", "height"), args, strict=True):
            if not _is_numeric_expression(arg):
                raise ValueError(f"BOX {name} argument must be either numeric expression or bind value.")
        return BoxNode(*args)

    if function_name == "POLYGON":
        if len(args) % 2 != 0:
            raise ValueError("POLYGON requires even number of arguments")
        if len(args) < 6:
            raise ValueError("POLYGON requires at least three vertices")
        # Check types of arguments, we want to support expressions too.
        if not all(_is_numeric_expression(arg) for arg in args):
            raise ValueError("POLYGON argument must be either numeric expression or bind value.")
        # Pair up consecutive arguments as (ra, dec) vertices.
        vertices = list(zip(args[::2], args[1::2]))
        return PolygonNode(vertices)

    if function_name == "REGION":
        if len(args) != 1:
            raise ValueError("REGION requires a single string argument")
        if not isinstance(args[0], StringLiteral | BindName | Identifier):
            raise ValueError("REGION argument must be either a string or a bind value")
        return RegionNode(args[0])

    if function_name == "GLOB":
        if len(args) != 2:
            # Argument order in the message matches the order used below.
            raise ValueError("GLOB requires two arguments (expression, pattern)")
        expression, pattern = (_strip_parens(arg) for arg in args)
        if not isinstance(expression, Identifier):
            raise TypeError("glob() first argument must be an identifier")
        if not isinstance(pattern, StringLiteral | BindName):
            raise TypeError("glob() second argument must be a string or a bind name (prefixed with colon)")
        return GlobNode(expression, pattern)

    if function_name == "UUID":
        if len(args) != 1:
            raise ValueError("UUID() requires a single arguments (uuid-string)")
        argument = _strip_parens(args[0])
        # Potentially we could allow BindName as argument but it's too
        # complicated and people should use bind with UUID instead.
        if not isinstance(argument, StringLiteral):
            raise TypeError("UUID() argument must be a string literal")
        # This will raise ValueError if string is not good.
        return UuidLiteral(UUID(argument.value))

    # Generic function call; the name is kept as it was given by the caller.
    return FunctionCall(function, args)