Coverage for python / lsst / daf / butler / direct_query_driver / _sql_column_visitor.py: 15%

158 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-17 08:49 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27 

28from __future__ import annotations 

29 

30__all__ = ("SqlColumnVisitor",) 

31 

32import warnings 

33from typing import TYPE_CHECKING, Any 

34 

35import erfa 

36import sqlalchemy 

37 

38from .. import ddl 

39from .._exceptions import InvalidQueryError 

40from ..queries import tree as qt 

41from ..queries.visitors import ColumnExpressionVisitor, PredicateVisitFlags, PredicateVisitor 

42from ..timespan_database_representation import TimespanDatabaseRepresentation 

43 

44if TYPE_CHECKING: 

45 from ._driver import DirectQueryDriver 

46 from ._sql_builders import SqlColumns 

47 

48 

49class SqlColumnVisitor( 

50 ColumnExpressionVisitor[sqlalchemy.ColumnElement[Any] | TimespanDatabaseRepresentation], 

51 PredicateVisitor[ 

52 sqlalchemy.ColumnElement[bool], sqlalchemy.ColumnElement[bool], sqlalchemy.ColumnElement[bool] 

53 ], 

54): 

55 """A column expression visitor that constructs a `sqlalchemy.ColumnElement` 

56 expression tree. 

57 

58 Parameters 

59 ---------- 

60 columns : `QueryColumns` 

61 `QueryColumns` that provides SQL columns for column-reference 

62 expressions. 

63 driver : `QueryDriver` 

64 Driver used to construct nested queries for "in query" predicates. 

65 """ 

66 

67 def __init__(self, columns: SqlColumns, driver: DirectQueryDriver): 

68 self._driver = driver 

69 self._columns = columns 

70 

71 def visit_literal( 

72 self, expression: qt.ColumnLiteral 

73 ) -> sqlalchemy.ColumnElement[Any] | TimespanDatabaseRepresentation: 

74 # Docstring inherited. 

75 if expression.column_type == "timespan": 

76 return self._driver.db.getTimespanRepresentation().fromLiteral(expression.get_literal_value()) 

77 return sqlalchemy.literal( 

78 expression.get_literal_value(), type_=ddl.VALID_CONFIG_COLUMN_TYPES[expression.column_type] 

79 ) 

80 

81 def visit_dimension_key_reference( 

82 self, expression: qt.DimensionKeyReference 

83 ) -> sqlalchemy.ColumnElement[int | str]: 

84 # Docstring inherited. 

85 return self._columns.dimension_keys[expression.dimension.name][0] 

86 

87 def visit_dimension_field_reference( 

88 self, expression: qt.DimensionFieldReference 

89 ) -> sqlalchemy.ColumnElement[Any] | TimespanDatabaseRepresentation: 

90 # Docstring inherited. 

91 if expression.column_type == "timespan": 

92 return self._columns.timespans[expression.element.name] 

93 return self._columns.fields[expression.element.name][expression.field] 

94 

95 def visit_dataset_field_reference( 

96 self, expression: qt.DatasetFieldReference 

97 ) -> sqlalchemy.ColumnElement[Any] | TimespanDatabaseRepresentation: 

98 # Docstring inherited. 

99 if expression.column_type == "timespan": 

100 return self._columns.timespans[expression.dataset_type] 

101 return self._columns.fields[expression.dataset_type][expression.field] 

102 

103 def visit_unary_expression(self, expression: qt.UnaryExpression) -> sqlalchemy.ColumnElement[Any]: 

104 # Docstring inherited. 

105 match expression.operator: 

106 case "-": 

107 return -self.expect_scalar(expression.operand) 

108 case "begin_of": 

109 return self.expect_timespan(expression.operand).lower() 

110 case "end_of": 

111 return self.expect_timespan(expression.operand).upper() 

112 raise AssertionError(f"Invalid unary expression operator {expression.operator!r}.") 

113 

114 def visit_binary_expression(self, expression: qt.BinaryExpression) -> sqlalchemy.ColumnElement[Any]: 

115 # Docstring inherited. 

116 a = self.expect_scalar(expression.a) 

117 b = self.expect_scalar(expression.b) 

118 match expression.operator: 

119 case "+": 

120 return a + b 

121 case "-": 

122 return a - b 

123 case "*": 

124 return a * b 

125 case "/": 

126 return a / b 

127 case "%": 

128 return a % b 

129 raise AssertionError(f"Invalid binary expression operator {expression.operator!r}.") 

130 

131 def visit_reversed(self, expression: qt.Reversed) -> sqlalchemy.ColumnElement[Any]: 

132 # Docstring inherited. 

133 return self.expect_scalar(expression.operand).desc() 

134 

135 def visit_boolean_wrapper( 

136 self, value: qt.ColumnExpression, flags: PredicateVisitFlags 

137 ) -> sqlalchemy.ColumnElement[bool]: 

138 return self.expect_scalar(value) 

139 

140 def visit_comparison( 

141 self, 

142 a: qt.ColumnExpression, 

143 operator: qt.ComparisonOperator, 

144 b: qt.ColumnExpression, 

145 flags: PredicateVisitFlags, 

146 ) -> sqlalchemy.ColumnElement[bool]: 

147 # Docstring inherited. 

148 if operator == "overlaps": 

149 if values := (qt.is_one_timespan_and_one_datetime(a, b)): 

150 return self.expect_timespan(values.timespan).contains(self.expect_scalar(values.datetime)) 

151 elif values := (qt.is_one_timespan_and_one_ingest_date(a, b)): 

152 return self.expect_timespan(values.timespan).contains( 

153 self._convert_ingest_date_to_datetime(values.datetime) 

154 ) 

155 elif a.column_type == "timespan" and b.column_type == "timespan": 

156 return self.expect_timespan(a).overlaps(self.expect_timespan(b)) 

157 else: 

158 # Spatial overlaps should be transformed away by now. 

159 raise AssertionError( 

160 f"Unexpected types {a.column_type},{b.column_type} in overlaps operator." 

161 ) 

162 

163 if operator == "glob": 

164 # Second operand must be a string literal. 

165 if isinstance(b, qt._column_literal.StringColumnLiteral): 

166 pattern = b.value 

167 expression = self.expect_scalar(a) 

168 return self._driver.db.glob_expression(expression, pattern) 

169 else: 

170 raise AssertionError(f"Unexpected pattern type ({type(b)}) in glob operator.") 

171 

172 lhs = self.expect_scalar(a) 

173 rhs = self.expect_scalar(b) 

174 # Special case to handle awkward situation where ingest_date is not 

175 # always the same type as other datetime columns. 

176 if qt.is_one_datetime_and_one_ingest_date(a, b): 

177 if a.column_type == "datetime": 

178 lhs = self._convert_datetime_to_ingest_date(a) 

179 elif b.column_type == "datetime": 

180 rhs = self._convert_datetime_to_ingest_date(b) 

181 

182 match operator: 

183 case "==": 

184 return lhs == rhs 

185 case "!=": 

186 return lhs != rhs 

187 case "<": 

188 return lhs < rhs 

189 case ">": 

190 return lhs > rhs 

191 case "<=": 

192 return lhs <= rhs 

193 case ">=": 

194 return lhs >= rhs 

195 raise AssertionError(f"Invalid comparison operator {operator!r}.") 

196 

197 def visit_is_null( 

198 self, operand: qt.ColumnExpression, flags: PredicateVisitFlags 

199 ) -> sqlalchemy.ColumnElement[bool]: 

200 # Docstring inherited. 

201 if operand.column_type == "timespan": 

202 return self.expect_timespan(operand).isNull() 

203 return self.expect_scalar(operand) == sqlalchemy.null() 

204 

205 def visit_in_container( 

206 self, 

207 member: qt.ColumnExpression, 

208 container: tuple[qt.ColumnExpression, ...], 

209 flags: PredicateVisitFlags, 

210 ) -> sqlalchemy.ColumnElement[bool]: 

211 # Docstring inherited. 

212 return self.expect_scalar(member).in_([self.expect_scalar(item) for item in container]) 

213 

214 def visit_in_range( 

215 self, member: qt.ColumnExpression, start: int, stop: int | None, step: int, flags: PredicateVisitFlags 

216 ) -> sqlalchemy.ColumnElement[bool]: 

217 # Docstring inherited. 

218 sql_member = self.expect_scalar(member) 

219 if stop is None: 

220 target = sql_member >= sqlalchemy.literal(start) 

221 else: 

222 stop_inclusive = stop - 1 

223 if start == stop_inclusive: 

224 return sql_member == sqlalchemy.literal(start) 

225 else: 

226 target = sqlalchemy.sql.between( 

227 sql_member, 

228 sqlalchemy.literal(start), 

229 sqlalchemy.literal(stop_inclusive), 

230 ) 

231 if step != 1: 

232 return sqlalchemy.sql.and_( 

233 *[ 

234 target, 

235 sql_member % sqlalchemy.literal(step) == sqlalchemy.literal(start % step), 

236 ] 

237 ) 

238 else: 

239 return target 

240 

241 def visit_in_query_tree( 

242 self, 

243 member: qt.ColumnExpression, 

244 column: qt.ColumnExpression, 

245 query_tree: qt.QueryTree, 

246 flags: PredicateVisitFlags, 

247 ) -> sqlalchemy.ColumnElement[bool]: 

248 # Docstring inherited. 

249 columns = qt.ColumnSet(self._driver.universe.empty) 

250 column.gather_required_columns(columns) 

251 builder = self._driver.build_query(query_tree, columns) 

252 if builder.postprocessing: 

253 raise NotImplementedError( 

254 "Right-hand side subquery in IN expression would require postprocessing." 

255 ) 

256 select_builder = builder.finish_nested() 

257 subquery_visitor = SqlColumnVisitor(select_builder.joins, self._driver) 

258 select_builder.joins.special["_MEMBER"] = subquery_visitor.expect_scalar(column) 

259 select_builder.columns = qt.ColumnSet(self._driver.universe.empty) 

260 subquery_select = select_builder.select(postprocessing=None) 

261 sql_member = self.expect_scalar(member) 

262 return sql_member.in_(subquery_select) 

263 

264 def apply_logical_and( 

265 self, originals: qt.PredicateOperands, results: tuple[sqlalchemy.ColumnElement[bool], ...] 

266 ) -> sqlalchemy.ColumnElement[bool]: 

267 # Docstring inherited. 

268 match len(results): 

269 case 0: 

270 return sqlalchemy.true() 

271 case 1: 

272 return results[0] 

273 case _: 

274 return sqlalchemy.and_(*results) 

275 

276 def apply_logical_or( 

277 self, 

278 originals: tuple[qt.PredicateLeaf, ...], 

279 results: tuple[sqlalchemy.ColumnElement[bool], ...], 

280 flags: PredicateVisitFlags, 

281 ) -> sqlalchemy.ColumnElement[bool]: 

282 # Docstring inherited. 

283 match len(results): 

284 case 0: 

285 return sqlalchemy.false() 

286 case 1: 

287 return results[0] 

288 case _: 

289 return sqlalchemy.or_(*results) 

290 

291 def apply_logical_not( 

292 self, original: qt.PredicateLeaf, result: sqlalchemy.ColumnElement[bool], flags: PredicateVisitFlags 

293 ) -> sqlalchemy.ColumnElement[bool]: 

294 # Docstring inherited. 

295 return sqlalchemy.not_(result) 

296 

297 def expect_scalar(self, expression: qt.OrderExpression) -> sqlalchemy.ColumnElement[Any]: 

298 result = expression.visit(self) 

299 assert isinstance(result, sqlalchemy.ColumnElement) 

300 return result 

301 

302 def _convert_datetime_to_ingest_date(self, expression: qt.ColumnExpression) -> sqlalchemy.ColumnElement: 

303 assert expression.column_type == "datetime" 

304 if self._driver.managers.datasets.ingest_date_dtype() == sqlalchemy.TIMESTAMP: 

305 # Datasets manager v1 schema has "datasets" table's "ingest_date" 

306 # column as TIMESTAMP, but the rest of the database schema and 

307 # query system uses integer TAI nanoseconds. So we have to convert 

308 # the nanoseconds value to a timestamp. 

309 # Note that this loses precision. 

310 if expression.expression_type != "datetime": 

311 # Conversion between TAI and UTC can't be done in the database, 

312 # so we are only able to handle literal values here. 

313 raise InvalidQueryError("Only literal date-time values can be compared with ingest date.") 

314 

315 # The conversion from TAI to UTC can trigger a warning for dates 

316 # in the future so catch those warnings. 

317 with warnings.catch_warnings(): 

318 warnings.simplefilter("ignore", category=erfa.ErfaWarning) 

319 dt = expression.value.utc.to_datetime() 

320 return sqlalchemy.literal(dt) 

321 else: 

322 # For v2 schema, ingest_date uses TAI nanoseconds like everything 

323 # else, so no conversion is required. 

324 return self.expect_scalar(expression) 

325 

326 def _convert_ingest_date_to_datetime(self, expression: qt.ColumnExpression) -> sqlalchemy.ColumnElement: 

327 assert expression.column_type == "ingest_date" 

328 if self._driver.managers.datasets.ingest_date_dtype() == sqlalchemy.TIMESTAMP: 

329 # Datasets manager v1 schema has "datasets" table's "ingest_date" 

330 # column as TIMESTAMP, but the rest of the database schema and 

331 # query system uses integer TAI nanoseconds. 

332 raise InvalidQueryError( 

333 "Expressions involving ingest date are not supported by this database schema." 

334 ) 

335 else: 

336 # For v2 schema, ingest_date uses TAI nanoseconds like everything 

337 # else, so no conversion is required. 

338 return self.expect_scalar(expression) 

339 

340 def expect_timespan(self, expression: qt.ColumnExpression) -> TimespanDatabaseRepresentation: 

341 result = expression.visit(self) 

342 assert isinstance(result, TimespanDatabaseRepresentation) 

343 return result