Coverage for python/lsst/daf/butler/registry/queries/expressions/_predicate.py: 12%

207 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-08-05 01:26 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ("make_string_expression_predicate", "ExpressionTypeError") 

24 

25import builtins 

26import datetime 

27import types 

28import warnings 

29from collections.abc import Mapping, Set 

30from typing import Any, cast 

31 

32import astropy.time 

33import astropy.utils.exceptions 

34from lsst.daf.relation import ( 

35 ColumnContainer, 

36 ColumnExpression, 

37 ColumnExpressionSequence, 

38 ColumnLiteral, 

39 ColumnTag, 

40 Predicate, 

41 sql, 

42) 

43 

44# We import the timespan module rather than types within it because match 

45# syntax uses qualified names with periods to distinguish literals from 

46# captures. 

47from ....core import ( 

48 ColumnTypeInfo, 

49 DataCoordinate, 

50 DatasetColumnTag, 

51 Dimension, 

52 DimensionGraph, 

53 DimensionKeyColumnTag, 

54 DimensionRecordColumnTag, 

55 DimensionUniverse, 

56 timespan, 

57) 

58from ..._exceptions import UserExpressionError, UserExpressionSyntaxError 

59from .categorize import ExpressionConstant, categorizeConstant, categorizeElementId 

60from .check import CheckVisitor 

61from .normalForm import NormalForm, NormalFormExpression 

62from .parser import Node, ParserYacc, TreeVisitor # type: ignore 

63 

64# As of astropy 4.2, the erfa interface is shipped independently and 

65# ErfaWarning is no longer an AstropyWarning 

66try: 

67 import erfa 

68except ImportError: 

69 erfa = None 

70 

71 

class ExpressionTypeError(TypeError):
    """Error raised when the operand types appearing in a query expression
    are incompatible with the operators or other syntax applied to them.
    """

76 

77 

def make_string_expression_predicate(
    string: str,
    dimensions: DimensionGraph,
    *,
    column_types: ColumnTypeInfo,
    bind: Mapping[str, Any] | None = None,
    data_id: DataCoordinate | None = None,
    defaults: DataCoordinate | None = None,
    dataset_type_name: str | None = None,
    allow_orphans: bool = False,
) -> tuple[Predicate | None, Mapping[str, Set[str]]]:
    """Create a predicate by parsing and analyzing a string expression.

    Parameters
    ----------
    string : `str`
        String to parse.
    dimensions : `DimensionGraph`
        The dimensions the query would include in the absence of this WHERE
        expression.
    column_types : `ColumnTypeInfo`
        Information about column types.
    bind : `~collections.abc.Mapping` [ `str`, `Any` ], optional
        Literal values referenced in the expression.
    data_id : `DataCoordinate`, optional
        A fully-expanded data ID identifying dimensions known in advance.
        If not provided, will be set to an empty data ID.
        ``dataId.hasRecords()`` must return `True`.
    defaults : `DataCoordinate`, optional
        A data ID containing default values for governor dimensions.  If
        not provided, an empty data ID is used.
    dataset_type_name : `str` or `None`, optional
        The name of the dataset type to assume for unqualified dataset
        columns, or `None` if there are no such identifiers.
    allow_orphans : `bool`, optional
        If `True`, permit expressions to refer to dimensions without
        providing a value for their governor dimensions (e.g. referring to
        a visit without an instrument).  Should be left to default to
        `False` in essentially all new code.

    Returns
    -------
    predicate : `lsst.daf.relation.column_expressions.Predicate` or `None`
        New predicate derived from the string expression, or `None` if the
        string is empty.
    governor_constraints : `~collections.abc.Mapping` [ `str` , \
            `~collections.abc.Set` ]
        Constraints on dimension values derived from the expression and data
        ID.

    Raises
    ------
    UserExpressionSyntaxError
        Raised if the string cannot be parsed.
    UserExpressionError
        Raised if the expression fails the consistency/completeness checks
        performed by `CheckVisitor`.
    RuntimeError
        Raised if a ``bind`` key shadows a dimension element or looks like a
        dimension column reference.
    """
    governor_constraints: dict[str, Set[str]] = {}
    if data_id is None:
        data_id = DataCoordinate.makeEmpty(dimensions.universe)
    if not string:
        # No expression: the only constraints come from governor values in
        # the given data ID.
        for dimension in data_id.graph.governors:
            governor_constraints[dimension.name] = {cast(str, data_id[dimension])}
        return None, governor_constraints
    try:
        parser = ParserYacc()
        tree = parser.parse(string)
    except Exception as exc:
        # Parser failures of any kind are reported uniformly as syntax
        # errors, chaining the original exception for diagnosis.
        raise UserExpressionSyntaxError(f"Failed to parse user expression {string!r}.") from exc
    if bind is None:
        bind = {}
    if bind:
        # Reject bind keys that would be ambiguous with dimension-element
        # names or dimension column references (``element.column``).
        for identifier in bind:
            if identifier in dimensions.universe.getStaticElements().names:
                raise RuntimeError(f"Bind parameter key {identifier!r} conflicts with a dimension element.")
            table, _, column = identifier.partition(".")
            if column and table in dimensions.universe.getStaticElements().names:
                raise RuntimeError(f"Bind parameter key {identifier!r} looks like a dimension column.")
    if defaults is None:
        defaults = DataCoordinate.makeEmpty(dimensions.universe)
    # Convert the expression to disjunctive normal form (ORs of ANDs).
    # That's potentially super expensive in the general case (where there's
    # a ton of nesting of ANDs and ORs).  That won't be the case for the
    # expressions we expect, and we actually use disjunctive normal instead
    # of conjunctive (i.e. ANDs of ORs) because I think the worst-case is
    # a long list of OR'd-together data IDs, which is already in or very
    # close to disjunctive normal form.
    expr = NormalFormExpression.fromTree(tree, NormalForm.DISJUNCTIVE)
    # Check the expression for consistency and completeness.
    visitor = CheckVisitor(data_id, dimensions, bind, defaults, allow_orphans=allow_orphans)
    try:
        summary = expr.visit(visitor)
    except UserExpressionError as err:
        # Report against the normalized form only when it differs from what
        # the user wrote, so the message stays readable in the common case.
        exprOriginal = str(tree)
        exprNormal = str(expr.toTree())
        if exprNormal == exprOriginal:
            msg = f'Error in query expression "{exprOriginal}": {err}'
        else:
            msg = f'Error in query expression "{exprOriginal}" (normalized to "{exprNormal}"): {err}'
        raise UserExpressionError(msg) from None
    # Keep only the constraints on governor dimensions; other dimension
    # constraints are enforced by the predicate itself.
    for dimension_name, values in summary.dimension_constraints.items():
        if dimension_name in dimensions.universe.getGovernorDimensions().names:
            governor_constraints[dimension_name] = cast(Set[str], values)
    # Convert the original (un-normalized) tree to a relation Predicate.
    converter = PredicateConversionVisitor(bind, dataset_type_name, dimensions.universe, column_types)
    predicate = tree.visit(converter)
    return predicate, governor_constraints

177 

178 

179VisitorResult = Predicate | ColumnExpression | ColumnContainer 

180 

181 

182class PredicateConversionVisitor(TreeVisitor[VisitorResult]): 

183 def __init__( 

184 self, 

185 bind: Mapping[str, Any], 

186 dataset_type_name: str | None, 

187 universe: DimensionUniverse, 

188 column_types: ColumnTypeInfo, 

189 ): 

190 self.bind = bind 

191 self.dataset_type_name = dataset_type_name 

192 self.universe = universe 

193 self.column_types = column_types 

194 

195 OPERATOR_MAP = { 

196 "=": "__eq__", 

197 "!=": "__ne__", 

198 "<": "__lt__", 

199 ">": "__gt__", 

200 "<=": "__le__", 

201 ">=": "__ge__", 

202 "+": "__add__", 

203 "-": "__sub__", 

204 "/": "__mul__", 

205 } 

206 

207 def to_datetime(self, time: astropy.time.Time) -> datetime.datetime: 

208 with warnings.catch_warnings(): 

209 warnings.simplefilter("ignore", category=astropy.utils.exceptions.AstropyWarning) 

210 if erfa is not None: 

211 warnings.simplefilter("ignore", category=erfa.ErfaWarning) 

212 return time.to_datetime() 

213 

214 def visitBinaryOp( 

215 self, operator: str, lhs: VisitorResult, rhs: VisitorResult, node: Node 

216 ) -> VisitorResult: 

217 # Docstring inherited. 

218 b = builtins 

219 match (operator, lhs, rhs): 

220 case ["OR", Predicate() as lhs, Predicate() as rhs]: 

221 return lhs.logical_or(rhs) 

222 case ["AND", Predicate() as lhs, Predicate() as rhs]: 

223 return lhs.logical_and(rhs) 

224 # Allow all comparisons between expressions of the same type for 

225 # sortable types. 

226 case [ 

227 "=" | "!=" | "<" | ">" | "<=" | ">=", 

228 ColumnExpression( 

229 dtype=b.int | b.float | b.str | astropy.time.Time | datetime.datetime 

230 ) as lhs, 

231 ColumnExpression() as rhs, 

232 ] if lhs.dtype is rhs.dtype: 

233 return lhs.predicate_method(self.OPERATOR_MAP[operator], rhs) 

234 # Allow comparisons between datetime expressions and 

235 # astropy.time.Time literals/binds (only), by coercing the 

236 # astropy.time.Time version to datetime. 

237 case [ 

238 "=" | "!=" | "<" | ">" | "<=" | ">=", 

239 ColumnLiteral(dtype=astropy.time.Time) as lhs, 

240 ColumnExpression(dtype=datetime.datetime) as rhs, 

241 ]: 

242 lhs = ColumnLiteral(self.to_datetime(lhs.value), datetime.datetime) 

243 return lhs.predicate_method(self.OPERATOR_MAP[operator], rhs) 

244 case [ 

245 "=" | "!=" | "<" | ">" | "<=" | ">=", 

246 ColumnExpression(dtype=datetime.datetime) as lhs, 

247 ColumnLiteral(dtype=astropy.time.Time) as rhs, 

248 ]: 

249 rhs = ColumnLiteral(self.to_datetime(rhs.value), datetime.datetime) 

250 return lhs.predicate_method(self.OPERATOR_MAP[operator], rhs) 

251 # Allow comparisons between astropy.time.Time expressions and 

252 # datetime literals/binds, by coercing the 

253 # datetime literals to astropy.time.Time (in UTC scale). 

254 case [ 

255 "=" | "!=" | "<" | ">" | "<=" | ">=", 

256 ColumnLiteral(dtype=datetime.datetime) as lhs, 

257 ColumnExpression(dtype=astropy.time.Time) as rhs, 

258 ]: 

259 lhs = ColumnLiteral(astropy.time.Time(lhs.value, scale="utc"), astropy.time.Time) 

260 return lhs.predicate_method(self.OPERATOR_MAP[operator], rhs) 

261 case [ 

262 "=" | "!=" | "<" | ">" | "<=" | ">=", 

263 ColumnExpression(dtype=astropy.time.Time) as lhs, 

264 ColumnLiteral(dtype=datetime.datetime) as rhs, 

265 ]: 

266 rhs = ColumnLiteral(astropy.time.Time(rhs.value, scale="utc"), astropy.time.Time) 

267 return lhs.predicate_method(self.OPERATOR_MAP[operator], rhs) 

268 # Allow equality comparisons with None/NULL. We don't have an 'IS' 

269 # operator. 

270 case ["=" | "!=", ColumnExpression(dtype=types.NoneType) as lhs, ColumnExpression() as rhs]: 

271 return lhs.predicate_method(self.OPERATOR_MAP[operator], rhs) 

272 case ["=" | "!=", ColumnExpression() as lhs, ColumnExpression(dtype=types.NoneType) as rhs]: 

273 return lhs.predicate_method(self.OPERATOR_MAP[operator], rhs) 

274 # Comparisions between Time and Timespan need have the Timespan on 

275 # the lhs, since that (actually TimespanDatabaseRepresentation) is 

276 # what actually has the methods. 

277 case [ 

278 "<", 

279 ColumnExpression(dtype=astropy.time.Time) as lhs, 

280 ColumnExpression(dtype=timespan.Timespan) as rhs, 

281 ]: 

282 return rhs.predicate_method(self.OPERATOR_MAP[">"], lhs) 

283 case [ 

284 ">", 

285 ColumnExpression(dtype=astropy.time.Time) as lhs, 

286 ColumnExpression(dtype=timespan.Timespan) as rhs, 

287 ]: 

288 return rhs.predicate_method(self.OPERATOR_MAP["<"], lhs) 

289 # Enable other comparisons between times and Timespans (many of the 

290 # combinations matched by this branch will have already been 

291 # covered by a preceding branch). 

292 case [ 

293 "<" | ">", 

294 ColumnExpression(dtype=timespan.Timespan | astropy.time.Time) as lhs, 

295 ColumnExpression(dtype=timespan.Timespan | astropy.time.Time) as rhs, 

296 ]: 

297 return lhs.predicate_method(self.OPERATOR_MAP[operator], rhs) 

298 # Enable "overlaps" operations between timespans, and between times 

299 # and timespans. The latter resolve to the `Timespan.contains` or 

300 # `TimespanDatabaseRepresentation.contains` methods, but we use 

301 # OVERLAPS in the string expression language to keep that simple. 

302 case [ 

303 "OVERLAPS", 

304 ColumnExpression(dtype=timespan.Timespan) as lhs, 

305 ColumnExpression(dtype=timespan.Timespan) as rhs, 

306 ]: 

307 return lhs.predicate_method("overlaps", rhs) 

308 case [ 

309 "OVERLAPS", 

310 ColumnExpression(dtype=timespan.Timespan) as lhs, 

311 ColumnExpression(dtype=astropy.time.Time) as rhs, 

312 ]: 

313 return lhs.predicate_method("overlaps", rhs) 

314 case [ 

315 "OVERLAPS", 

316 ColumnExpression(dtype=astropy.time.Time) as lhs, 

317 ColumnExpression(dtype=timespan.Timespan) as rhs, 

318 ]: 

319 return rhs.predicate_method("overlaps", lhs) 

320 # Enable arithmetic operators on numeric types, without any type 

321 # coercion or broadening. 

322 case [ 

323 "+" | "-" | "*", 

324 ColumnExpression(dtype=b.int | b.float) as lhs, 

325 ColumnExpression() as rhs, 

326 ] if lhs.dtype is rhs.dtype: 

327 return lhs.method(self.OPERATOR_MAP[operator], rhs, dtype=lhs.dtype) 

328 case ["/", ColumnExpression(dtype=b.float) as lhs, ColumnExpression(dtype=b.float) as rhs]: 

329 return lhs.method("__truediv__", rhs, dtype=b.float) 

330 case ["/", ColumnExpression(dtype=b.int) as lhs, ColumnExpression(dtype=b.int) as rhs]: 

331 # SQLAlchemy maps Python's '/' (__truediv__) operator directly 

332 # to SQL's '/', despite those being defined differently for 

333 # integers. Our expression language uses the SQL definition, 

334 # and we only care about these expressions being evaluated in 

335 # SQL right now, but we still want to guard against it being 

336 # evaluated in Python and producing a surprising answer, so we 

337 # mark it as being supported only by a SQL engine. 

338 return lhs.method( 

339 "__truediv__", 

340 rhs, 

341 dtype=b.int, 

342 supporting_engine_types={sql.Engine}, 

343 ) 

344 case ["%", ColumnExpression(dtype=b.int) as lhs, ColumnExpression(dtype=b.int) as rhs]: 

345 return lhs.method("__mod__", rhs, dtype=b.int) 

346 assert ( 

347 lhs.dtype is not None and rhs.dtype is not None 

348 ), "Expression converter should not yield untyped nodes." 

349 raise ExpressionTypeError( 

350 f"Invalid types {lhs.dtype.__name__}, {rhs.dtype.__name__} for binary operator {operator!r} " 

351 f"in expression {node!s}." 

352 ) 

353 

354 def visitIdentifier(self, name: str, node: Node) -> VisitorResult: 

355 # Docstring inherited. 

356 if name in self.bind: 

357 value = self.bind[name] 

358 if isinstance(value, list | tuple | Set): 

359 elements = [] 

360 all_dtypes = set() 

361 for item in value: 

362 dtype = type(item) 

363 all_dtypes.add(dtype) 

364 elements.append(ColumnExpression.literal(item, dtype=dtype)) 

365 if len(all_dtypes) > 1: 

366 raise ExpressionTypeError( 

367 f"Mismatched types in bind iterable: {value} has a mix of {all_dtypes}." 

368 ) 

369 elif not elements: 

370 # Empty container 

371 return ColumnContainer.sequence([]) 

372 else: 

373 (dtype,) = all_dtypes 

374 return ColumnContainer.sequence(elements, dtype=dtype) 

375 return ColumnExpression.literal(value, dtype=type(value)) 

376 tag: ColumnTag 

377 match categorizeConstant(name): 

378 case ExpressionConstant.INGEST_DATE: 

379 assert self.dataset_type_name is not None 

380 tag = DatasetColumnTag(self.dataset_type_name, "ingest_date") 

381 return ColumnExpression.reference(tag, self.column_types.ingest_date_pytype) 

382 case ExpressionConstant.NULL: 

383 return ColumnExpression.literal(None, type(None)) 

384 case None: 

385 pass 

386 case _: 

387 raise AssertionError("Check for enum values should be exhaustive.") 

388 element, column = categorizeElementId(self.universe, name) 

389 if column is not None: 

390 tag = DimensionRecordColumnTag(element.name, column) 

391 dtype = ( 

392 timespan.Timespan 

393 if column == timespan.TimespanDatabaseRepresentation.NAME 

394 else element.RecordClass.fields.standard[column].getPythonType() 

395 ) 

396 return ColumnExpression.reference(tag, dtype) 

397 else: 

398 tag = DimensionKeyColumnTag(element.name) 

399 assert isinstance(element, Dimension) 

400 return ColumnExpression.reference(tag, element.primaryKey.getPythonType()) 

401 

402 def visitIsIn( 

403 self, lhs: VisitorResult, values: list[VisitorResult], not_in: bool, node: Node 

404 ) -> VisitorResult: 

405 # Docstring inherited. 

406 clauses: list[Predicate] = [] 

407 items: list[ColumnExpression] = [] 

408 assert isinstance(lhs, ColumnExpression), "LHS of IN guaranteed to be scalar by parser." 

409 for rhs_item in values: 

410 match rhs_item: 

411 case ColumnExpressionSequence( 

412 items=rhs_items, dtype=rhs_dtype 

413 ) if rhs_dtype is None or rhs_dtype == lhs.dtype: 

414 items.extend(rhs_items) 

415 case ColumnContainer(dtype=lhs.dtype): 

416 clauses.append(rhs_item.contains(lhs)) 

417 case ColumnExpression(dtype=lhs.dtype): 

418 items.append(rhs_item) 

419 case _: 

420 raise ExpressionTypeError( 

421 f"Invalid type {rhs_item.dtype} for element in {lhs.dtype} IN expression '{node}'." 

422 ) 

423 if items: 

424 clauses.append(ColumnContainer.sequence(items, dtype=lhs.dtype).contains(lhs)) 

425 result = Predicate.logical_or(*clauses) 

426 if not_in: 

427 result = result.logical_not() 

428 return result 

429 

430 def visitNumericLiteral(self, value: str, node: Node) -> VisitorResult: 

431 # Docstring inherited. 

432 try: 

433 return ColumnExpression.literal(int(value), dtype=int) 

434 except ValueError: 

435 return ColumnExpression.literal(float(value), dtype=float) 

436 

437 def visitParens(self, expression: VisitorResult, node: Node) -> VisitorResult: 

438 # Docstring inherited. 

439 return expression 

440 

441 def visitPointNode(self, ra: VisitorResult, dec: VisitorResult, node: Node) -> VisitorResult: 

442 # Docstring inherited. 

443 

444 # this is a placeholder for future extension, we enabled syntax but 

445 # do not support actual use just yet. 

446 raise NotImplementedError("POINT() function is not supported yet") 

447 

448 def visitRangeLiteral(self, start: int, stop: int, stride: int | None, node: Node) -> VisitorResult: 

449 # Docstring inherited. 

450 return ColumnContainer.range_literal(range(start, stop + 1, stride or 1)) 

451 

452 def visitStringLiteral(self, value: str, node: Node) -> VisitorResult: 

453 # Docstring inherited. 

454 return ColumnExpression.literal(value, dtype=str) 

455 

456 def visitTimeLiteral(self, value: astropy.time.Time, node: Node) -> VisitorResult: 

457 # Docstring inherited. 

458 return ColumnExpression.literal(value, dtype=astropy.time.Time) 

459 

460 def visitTupleNode(self, items: tuple[VisitorResult, ...], node: Node) -> VisitorResult: 

461 # Docstring inherited. 

462 match items: 

463 case [ 

464 ColumnLiteral(value=begin, dtype=astropy.time.Time | types.NoneType), 

465 ColumnLiteral(value=end, dtype=astropy.time.Time | types.NoneType), 

466 ]: 

467 return ColumnExpression.literal(timespan.Timespan(begin, end), dtype=timespan.Timespan) 

468 raise ExpressionTypeError( 

469 f'Invalid type(s) ({items[0].dtype}, {items[1].dtype}) in timespan tuple "{node}" ' 

470 '(Note that date/time strings must be preceded by "T" to be recognized).' 

471 ) 

472 

473 def visitUnaryOp(self, operator: str, operand: VisitorResult, node: Node) -> VisitorResult: 

474 # Docstring inherited. 

475 match (operator, operand): 

476 case ["NOT", Predicate() as operand]: 

477 return operand.logical_not() 

478 case ["+", ColumnExpression(dtype=builtins.int | builtins.float) as operand]: 

479 return operand.method("__pos__") 

480 case ["-", ColumnExpression(dtype=builtins.int | builtins.float) as operand]: 

481 return operand.method("__neg__") 

482 raise ExpressionTypeError( 

483 f"Unary operator {operator!r} is not valid for operand of type {operand.dtype!s} in {node!s}." 

484 )