Coverage for python/lsst/daf/butler/registry/queries/expressions/_predicate.py: 12%

207 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-10-02 08:00 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27from __future__ import annotations 

28 

29__all__ = ("make_string_expression_predicate", "ExpressionTypeError") 

30 

31import builtins 

32import datetime 

33import types 

34import warnings 

35from collections.abc import Mapping, Set 

36from typing import Any, cast 

37 

38import astropy.time 

39import astropy.utils.exceptions 

40from lsst.daf.relation import ( 

41 ColumnContainer, 

42 ColumnExpression, 

43 ColumnExpressionSequence, 

44 ColumnLiteral, 

45 ColumnTag, 

46 Predicate, 

47 sql, 

48) 

49 

50# We import the timespan module rather than types within it because match 

51# syntax uses qualified names with periods to distinguish literals from 

52# captures. 

53from ....core import ( 

54 ColumnTypeInfo, 

55 DataCoordinate, 

56 DatasetColumnTag, 

57 Dimension, 

58 DimensionGraph, 

59 DimensionKeyColumnTag, 

60 DimensionRecordColumnTag, 

61 DimensionUniverse, 

62 timespan, 

63) 

64from ..._exceptions import UserExpressionError, UserExpressionSyntaxError 

65from .categorize import ExpressionConstant, categorizeConstant, categorizeElementId 

66from .check import CheckVisitor 

67from .normalForm import NormalForm, NormalFormExpression 

68from .parser import Node, ParserYacc, TreeVisitor # type: ignore 

69 

70# As of astropy 4.2, the erfa interface is shipped independently and 

71# ErfaWarning is no longer an AstropyWarning 

try:
    import erfa
except ImportError:
    # erfa is an optional dependency; leave the name bound to None so code
    # below can test "erfa is not None" before touching ErfaWarning.
    erfa = None

76 

77 

class ExpressionTypeError(TypeError):
    """Exception raised when the operand types appearing in a query
    expression are incompatible with the operators or other syntax
    applied to them.
    """

82 

83 

def make_string_expression_predicate(
    string: str,
    dimensions: DimensionGraph,
    *,
    column_types: ColumnTypeInfo,
    bind: Mapping[str, Any] | None = None,
    data_id: DataCoordinate | None = None,
    defaults: DataCoordinate | None = None,
    dataset_type_name: str | None = None,
    allow_orphans: bool = False,
) -> tuple[Predicate | None, Mapping[str, Set[str]]]:
    """Create a predicate by parsing and analyzing a string expression.

    Parameters
    ----------
    string : `str`
        String to parse.
    dimensions : `DimensionGraph`
        The dimensions the query would include in the absence of this WHERE
        expression.
    column_types : `ColumnTypeInfo`
        Information about column types.
    bind : `~collections.abc.Mapping` [ `str`, `Any` ], optional
        Literal values referenced in the expression.
    data_id : `DataCoordinate`, optional
        A fully-expanded data ID identifying dimensions known in advance.
        If not provided, will be set to an empty data ID.
        ``dataId.hasRecords()`` must return `True`.
    defaults : `DataCoordinate`, optional
        A data ID containing default values for governor dimensions,
        passed to the expression checker.
    dataset_type_name : `str` or `None`, optional
        The name of the dataset type to assume for unqualified dataset
        columns, or `None` if there are no such identifiers.
    allow_orphans : `bool`, optional
        If `True`, permit expressions to refer to dimensions without
        providing a value for their governor dimensions (e.g. referring to
        a visit without an instrument). Should be left to default to
        `False` in essentially all new code.

    Returns
    -------
    predicate : `lsst.daf.relation.column_expressions.Predicate` or `None`
        New predicate derived from the string expression, or `None` if the
        string is empty.
    governor_constraints : `~collections.abc.Mapping` [ `str` , \
            `~collections.abc.Set` ]
        Constraints on dimension values derived from the expression and data
        ID.

    Raises
    ------
    UserExpressionSyntaxError
        Raised if the string cannot be parsed.
    UserExpressionError
        Raised if the expression is inconsistent or incomplete.
    RuntimeError
        Raised if a bind parameter key conflicts with a dimension element
        or looks like a dimension column reference.
    """
    governor_constraints: dict[str, Set[str]] = {}
    if data_id is None:
        data_id = DataCoordinate.makeEmpty(dimensions.universe)
    if not string:
        # Empty expression: the only constraints are those implied by the
        # governor values already present in the data ID.
        for dimension in data_id.graph.governors:
            governor_constraints[dimension.name] = {cast(str, data_id[dimension])}
        return None, governor_constraints
    try:
        parser = ParserYacc()
        tree = parser.parse(string)
    except Exception as exc:
        raise UserExpressionSyntaxError(f"Failed to parse user expression {string!r}.") from exc
    if bind is None:
        bind = {}
    if bind:
        # Reject bind keys that would shadow dimension elements or look like
        # dimension-record column references, since those would be ambiguous.
        for identifier in bind:
            if identifier in dimensions.universe.getStaticElements().names:
                raise RuntimeError(f"Bind parameter key {identifier!r} conflicts with a dimension element.")
            table, _, column = identifier.partition(".")
            if column and table in dimensions.universe.getStaticElements().names:
                raise RuntimeError(f"Bind parameter key {identifier!r} looks like a dimension column.")
    if defaults is None:
        defaults = DataCoordinate.makeEmpty(dimensions.universe)
    # Convert the expression to disjunctive normal form (ORs of ANDs).
    # That's potentially super expensive in the general case (where there's
    # a ton of nesting of ANDs and ORs).  That won't be the case for the
    # expressions we expect, and we actually use disjunctive normal instead
    # of conjunctive (i.e. ANDs of ORs) because I think the worst-case is
    # a long list of OR'd-together data IDs, which is already in or very
    # close to disjunctive normal form.
    expr = NormalFormExpression.fromTree(tree, NormalForm.DISJUNCTIVE)
    # Check the expression for consistency and completeness.
    visitor = CheckVisitor(data_id, dimensions, bind, defaults, allow_orphans=allow_orphans)
    try:
        summary = expr.visit(visitor)
    except UserExpressionError as err:
        # Report the expression as the user wrote it; include the normalized
        # form only if it differs, to aid debugging.
        exprOriginal = str(tree)
        exprNormal = str(expr.toTree())
        if exprNormal == exprOriginal:
            msg = f'Error in query expression "{exprOriginal}": {err}'
        else:
            msg = f'Error in query expression "{exprOriginal}" (normalized to "{exprNormal}"): {err}'
        raise UserExpressionError(msg) from None
    # Extract constraints on governor dimensions found by the checker.
    for dimension_name, values in summary.dimension_constraints.items():
        if dimension_name in dimensions.universe.getGovernorDimensions().names:
            governor_constraints[dimension_name] = cast(Set[str], values)
    converter = PredicateConversionVisitor(bind, dataset_type_name, dimensions.universe, column_types)
    predicate = tree.visit(converter)
    return predicate, governor_constraints

183 

184 

# Union of everything the conversion visitor can produce for a tree node:
# boolean predicates, scalar column expressions, and column containers.
VisitorResult = Predicate | ColumnExpression | ColumnContainer

186 

187 

188class PredicateConversionVisitor(TreeVisitor[VisitorResult]): 

189 def __init__( 

190 self, 

191 bind: Mapping[str, Any], 

192 dataset_type_name: str | None, 

193 universe: DimensionUniverse, 

194 column_types: ColumnTypeInfo, 

195 ): 

196 self.bind = bind 

197 self.dataset_type_name = dataset_type_name 

198 self.universe = universe 

199 self.column_types = column_types 

200 

201 OPERATOR_MAP = { 

202 "=": "__eq__", 

203 "!=": "__ne__", 

204 "<": "__lt__", 

205 ">": "__gt__", 

206 "<=": "__le__", 

207 ">=": "__ge__", 

208 "+": "__add__", 

209 "-": "__sub__", 

210 "/": "__mul__", 

211 } 

212 

213 def to_datetime(self, time: astropy.time.Time) -> datetime.datetime: 

214 with warnings.catch_warnings(): 

215 warnings.simplefilter("ignore", category=astropy.utils.exceptions.AstropyWarning) 

216 if erfa is not None: 

217 warnings.simplefilter("ignore", category=erfa.ErfaWarning) 

218 return time.to_datetime() 

219 

220 def visitBinaryOp( 

221 self, operator: str, lhs: VisitorResult, rhs: VisitorResult, node: Node 

222 ) -> VisitorResult: 

223 # Docstring inherited. 

224 b = builtins 

225 match (operator, lhs, rhs): 

226 case ["OR", Predicate() as lhs, Predicate() as rhs]: 

227 return lhs.logical_or(rhs) 

228 case ["AND", Predicate() as lhs, Predicate() as rhs]: 

229 return lhs.logical_and(rhs) 

230 # Allow all comparisons between expressions of the same type for 

231 # sortable types. 

232 case [ 

233 "=" | "!=" | "<" | ">" | "<=" | ">=", 

234 ColumnExpression( 

235 dtype=b.int | b.float | b.str | astropy.time.Time | datetime.datetime 

236 ) as lhs, 

237 ColumnExpression() as rhs, 

238 ] if lhs.dtype is rhs.dtype: 

239 return lhs.predicate_method(self.OPERATOR_MAP[operator], rhs) 

240 # Allow comparisons between datetime expressions and 

241 # astropy.time.Time literals/binds (only), by coercing the 

242 # astropy.time.Time version to datetime. 

243 case [ 

244 "=" | "!=" | "<" | ">" | "<=" | ">=", 

245 ColumnLiteral(dtype=astropy.time.Time) as lhs, 

246 ColumnExpression(dtype=datetime.datetime) as rhs, 

247 ]: 

248 lhs = ColumnLiteral(self.to_datetime(lhs.value), datetime.datetime) 

249 return lhs.predicate_method(self.OPERATOR_MAP[operator], rhs) 

250 case [ 

251 "=" | "!=" | "<" | ">" | "<=" | ">=", 

252 ColumnExpression(dtype=datetime.datetime) as lhs, 

253 ColumnLiteral(dtype=astropy.time.Time) as rhs, 

254 ]: 

255 rhs = ColumnLiteral(self.to_datetime(rhs.value), datetime.datetime) 

256 return lhs.predicate_method(self.OPERATOR_MAP[operator], rhs) 

257 # Allow comparisons between astropy.time.Time expressions and 

258 # datetime literals/binds, by coercing the 

259 # datetime literals to astropy.time.Time (in UTC scale). 

260 case [ 

261 "=" | "!=" | "<" | ">" | "<=" | ">=", 

262 ColumnLiteral(dtype=datetime.datetime) as lhs, 

263 ColumnExpression(dtype=astropy.time.Time) as rhs, 

264 ]: 

265 lhs = ColumnLiteral(astropy.time.Time(lhs.value, scale="utc"), astropy.time.Time) 

266 return lhs.predicate_method(self.OPERATOR_MAP[operator], rhs) 

267 case [ 

268 "=" | "!=" | "<" | ">" | "<=" | ">=", 

269 ColumnExpression(dtype=astropy.time.Time) as lhs, 

270 ColumnLiteral(dtype=datetime.datetime) as rhs, 

271 ]: 

272 rhs = ColumnLiteral(astropy.time.Time(rhs.value, scale="utc"), astropy.time.Time) 

273 return lhs.predicate_method(self.OPERATOR_MAP[operator], rhs) 

274 # Allow equality comparisons with None/NULL. We don't have an 'IS' 

275 # operator. 

276 case ["=" | "!=", ColumnExpression(dtype=types.NoneType) as lhs, ColumnExpression() as rhs]: 

277 return lhs.predicate_method(self.OPERATOR_MAP[operator], rhs) 

278 case ["=" | "!=", ColumnExpression() as lhs, ColumnExpression(dtype=types.NoneType) as rhs]: 

279 return lhs.predicate_method(self.OPERATOR_MAP[operator], rhs) 

280 # Comparisions between Time and Timespan need have the Timespan on 

281 # the lhs, since that (actually TimespanDatabaseRepresentation) is 

282 # what actually has the methods. 

283 case [ 

284 "<", 

285 ColumnExpression(dtype=astropy.time.Time) as lhs, 

286 ColumnExpression(dtype=timespan.Timespan) as rhs, 

287 ]: 

288 return rhs.predicate_method(self.OPERATOR_MAP[">"], lhs) 

289 case [ 

290 ">", 

291 ColumnExpression(dtype=astropy.time.Time) as lhs, 

292 ColumnExpression(dtype=timespan.Timespan) as rhs, 

293 ]: 

294 return rhs.predicate_method(self.OPERATOR_MAP["<"], lhs) 

295 # Enable other comparisons between times and Timespans (many of the 

296 # combinations matched by this branch will have already been 

297 # covered by a preceding branch). 

298 case [ 

299 "<" | ">", 

300 ColumnExpression(dtype=timespan.Timespan | astropy.time.Time) as lhs, 

301 ColumnExpression(dtype=timespan.Timespan | astropy.time.Time) as rhs, 

302 ]: 

303 return lhs.predicate_method(self.OPERATOR_MAP[operator], rhs) 

304 # Enable "overlaps" operations between timespans, and between times 

305 # and timespans. The latter resolve to the `Timespan.contains` or 

306 # `TimespanDatabaseRepresentation.contains` methods, but we use 

307 # OVERLAPS in the string expression language to keep that simple. 

308 case [ 

309 "OVERLAPS", 

310 ColumnExpression(dtype=timespan.Timespan) as lhs, 

311 ColumnExpression(dtype=timespan.Timespan) as rhs, 

312 ]: 

313 return lhs.predicate_method("overlaps", rhs) 

314 case [ 

315 "OVERLAPS", 

316 ColumnExpression(dtype=timespan.Timespan) as lhs, 

317 ColumnExpression(dtype=astropy.time.Time) as rhs, 

318 ]: 

319 return lhs.predicate_method("overlaps", rhs) 

320 case [ 

321 "OVERLAPS", 

322 ColumnExpression(dtype=astropy.time.Time) as lhs, 

323 ColumnExpression(dtype=timespan.Timespan) as rhs, 

324 ]: 

325 return rhs.predicate_method("overlaps", lhs) 

326 # Enable arithmetic operators on numeric types, without any type 

327 # coercion or broadening. 

328 case [ 

329 "+" | "-" | "*", 

330 ColumnExpression(dtype=b.int | b.float) as lhs, 

331 ColumnExpression() as rhs, 

332 ] if lhs.dtype is rhs.dtype: 

333 return lhs.method(self.OPERATOR_MAP[operator], rhs, dtype=lhs.dtype) 

334 case ["/", ColumnExpression(dtype=b.float) as lhs, ColumnExpression(dtype=b.float) as rhs]: 

335 return lhs.method("__truediv__", rhs, dtype=b.float) 

336 case ["/", ColumnExpression(dtype=b.int) as lhs, ColumnExpression(dtype=b.int) as rhs]: 

337 # SQLAlchemy maps Python's '/' (__truediv__) operator directly 

338 # to SQL's '/', despite those being defined differently for 

339 # integers. Our expression language uses the SQL definition, 

340 # and we only care about these expressions being evaluated in 

341 # SQL right now, but we still want to guard against it being 

342 # evaluated in Python and producing a surprising answer, so we 

343 # mark it as being supported only by a SQL engine. 

344 return lhs.method( 

345 "__truediv__", 

346 rhs, 

347 dtype=b.int, 

348 supporting_engine_types={sql.Engine}, 

349 ) 

350 case ["%", ColumnExpression(dtype=b.int) as lhs, ColumnExpression(dtype=b.int) as rhs]: 

351 return lhs.method("__mod__", rhs, dtype=b.int) 

352 assert ( 

353 lhs.dtype is not None and rhs.dtype is not None 

354 ), "Expression converter should not yield untyped nodes." 

355 raise ExpressionTypeError( 

356 f"Invalid types {lhs.dtype.__name__}, {rhs.dtype.__name__} for binary operator {operator!r} " 

357 f"in expression {node!s}." 

358 ) 

359 

360 def visitIdentifier(self, name: str, node: Node) -> VisitorResult: 

361 # Docstring inherited. 

362 if name in self.bind: 

363 value = self.bind[name] 

364 if isinstance(value, list | tuple | Set): 

365 elements = [] 

366 all_dtypes = set() 

367 for item in value: 

368 dtype = type(item) 

369 all_dtypes.add(dtype) 

370 elements.append(ColumnExpression.literal(item, dtype=dtype)) 

371 if len(all_dtypes) > 1: 

372 raise ExpressionTypeError( 

373 f"Mismatched types in bind iterable: {value} has a mix of {all_dtypes}." 

374 ) 

375 elif not elements: 

376 # Empty container 

377 return ColumnContainer.sequence([]) 

378 else: 

379 (dtype,) = all_dtypes 

380 return ColumnContainer.sequence(elements, dtype=dtype) 

381 return ColumnExpression.literal(value, dtype=type(value)) 

382 tag: ColumnTag 

383 match categorizeConstant(name): 

384 case ExpressionConstant.INGEST_DATE: 

385 assert self.dataset_type_name is not None 

386 tag = DatasetColumnTag(self.dataset_type_name, "ingest_date") 

387 return ColumnExpression.reference(tag, self.column_types.ingest_date_pytype) 

388 case ExpressionConstant.NULL: 

389 return ColumnExpression.literal(None, type(None)) 

390 case None: 

391 pass 

392 case _: 

393 raise AssertionError("Check for enum values should be exhaustive.") 

394 element, column = categorizeElementId(self.universe, name) 

395 if column is not None: 

396 tag = DimensionRecordColumnTag(element.name, column) 

397 dtype = ( 

398 timespan.Timespan 

399 if column == timespan.TimespanDatabaseRepresentation.NAME 

400 else element.RecordClass.fields.standard[column].getPythonType() 

401 ) 

402 return ColumnExpression.reference(tag, dtype) 

403 else: 

404 tag = DimensionKeyColumnTag(element.name) 

405 assert isinstance(element, Dimension) 

406 return ColumnExpression.reference(tag, element.primaryKey.getPythonType()) 

407 

408 def visitIsIn( 

409 self, lhs: VisitorResult, values: list[VisitorResult], not_in: bool, node: Node 

410 ) -> VisitorResult: 

411 # Docstring inherited. 

412 clauses: list[Predicate] = [] 

413 items: list[ColumnExpression] = [] 

414 assert isinstance(lhs, ColumnExpression), "LHS of IN guaranteed to be scalar by parser." 

415 for rhs_item in values: 

416 match rhs_item: 

417 case ColumnExpressionSequence( 

418 items=rhs_items, dtype=rhs_dtype 

419 ) if rhs_dtype is None or rhs_dtype == lhs.dtype: 

420 items.extend(rhs_items) 

421 case ColumnContainer(dtype=lhs.dtype): 

422 clauses.append(rhs_item.contains(lhs)) 

423 case ColumnExpression(dtype=lhs.dtype): 

424 items.append(rhs_item) 

425 case _: 

426 raise ExpressionTypeError( 

427 f"Invalid type {rhs_item.dtype} for element in {lhs.dtype} IN expression '{node}'." 

428 ) 

429 if items: 

430 clauses.append(ColumnContainer.sequence(items, dtype=lhs.dtype).contains(lhs)) 

431 result = Predicate.logical_or(*clauses) 

432 if not_in: 

433 result = result.logical_not() 

434 return result 

435 

436 def visitNumericLiteral(self, value: str, node: Node) -> VisitorResult: 

437 # Docstring inherited. 

438 try: 

439 return ColumnExpression.literal(int(value), dtype=int) 

440 except ValueError: 

441 return ColumnExpression.literal(float(value), dtype=float) 

442 

443 def visitParens(self, expression: VisitorResult, node: Node) -> VisitorResult: 

444 # Docstring inherited. 

445 return expression 

446 

447 def visitPointNode(self, ra: VisitorResult, dec: VisitorResult, node: Node) -> VisitorResult: 

448 # Docstring inherited. 

449 

450 # this is a placeholder for future extension, we enabled syntax but 

451 # do not support actual use just yet. 

452 raise NotImplementedError("POINT() function is not supported yet") 

453 

454 def visitRangeLiteral(self, start: int, stop: int, stride: int | None, node: Node) -> VisitorResult: 

455 # Docstring inherited. 

456 return ColumnContainer.range_literal(range(start, stop + 1, stride or 1)) 

457 

458 def visitStringLiteral(self, value: str, node: Node) -> VisitorResult: 

459 # Docstring inherited. 

460 return ColumnExpression.literal(value, dtype=str) 

461 

462 def visitTimeLiteral(self, value: astropy.time.Time, node: Node) -> VisitorResult: 

463 # Docstring inherited. 

464 return ColumnExpression.literal(value, dtype=astropy.time.Time) 

465 

466 def visitTupleNode(self, items: tuple[VisitorResult, ...], node: Node) -> VisitorResult: 

467 # Docstring inherited. 

468 match items: 

469 case [ 

470 ColumnLiteral(value=begin, dtype=astropy.time.Time | types.NoneType), 

471 ColumnLiteral(value=end, dtype=astropy.time.Time | types.NoneType), 

472 ]: 

473 return ColumnExpression.literal(timespan.Timespan(begin, end), dtype=timespan.Timespan) 

474 raise ExpressionTypeError( 

475 f'Invalid type(s) ({items[0].dtype}, {items[1].dtype}) in timespan tuple "{node}" ' 

476 '(Note that date/time strings must be preceded by "T" to be recognized).' 

477 ) 

478 

479 def visitUnaryOp(self, operator: str, operand: VisitorResult, node: Node) -> VisitorResult: 

480 # Docstring inherited. 

481 match (operator, operand): 

482 case ["NOT", Predicate() as operand]: 

483 return operand.logical_not() 

484 case ["+", ColumnExpression(dtype=builtins.int | builtins.float) as operand]: 

485 return operand.method("__pos__") 

486 case ["-", ColumnExpression(dtype=builtins.int | builtins.float) as operand]: 

487 return operand.method("__neg__") 

488 raise ExpressionTypeError( 

489 f"Unary operator {operator!r} is not valid for operand of type {operand.dtype!s} in {node!s}." 

490 )