Coverage for python/lsst/daf/butler/direct_query_driver/_sql_column_visitor.py: 16%

123 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-18 09:55 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27 

28from __future__ import annotations 

29 

30__all__ = ("SqlColumnVisitor",) 

31 

32from typing import TYPE_CHECKING, Any 

33 

34import sqlalchemy 

35 

36from .. import ddl 

37from ..queries import tree as qt 

38from ..queries.visitors import ColumnExpressionVisitor, PredicateVisitFlags, PredicateVisitor 

39from ..timespan_database_representation import TimespanDatabaseRepresentation 

40 

41if TYPE_CHECKING: 

42 from ._driver import DirectQueryDriver 

43 from ._query_builder import QueryJoiner 

44 

45 

46class SqlColumnVisitor( 

47 ColumnExpressionVisitor[sqlalchemy.ColumnElement[Any] | TimespanDatabaseRepresentation], 

48 PredicateVisitor[ 

49 sqlalchemy.ColumnElement[bool], sqlalchemy.ColumnElement[bool], sqlalchemy.ColumnElement[bool] 

50 ], 

51): 

52 """A column expression visitor that constructs a `sqlalchemy.ColumnElement` 

53 expression tree. 

54 

55 Parameters 

56 ---------- 

57 joiner : `QueryJoiner` 

58 `QueryJoiner` that provides SQL columns for column-reference 

59 expressions. 

60 driver : `QueryDriver` 

61 Driver used to construct nested queries for "in query" predicates. 

62 """ 

63 

64 def __init__(self, joiner: QueryJoiner, driver: DirectQueryDriver): 

65 self._driver = driver 

66 self._joiner = joiner 

67 

68 def visit_literal( 

69 self, expression: qt.ColumnLiteral 

70 ) -> sqlalchemy.ColumnElement[Any] | TimespanDatabaseRepresentation: 

71 # Docstring inherited. 

72 if expression.column_type == "timespan": 

73 return self._driver.db.getTimespanRepresentation().fromLiteral(expression.get_literal_value()) 

74 return sqlalchemy.literal( 

75 expression.get_literal_value(), type_=ddl.VALID_CONFIG_COLUMN_TYPES[expression.column_type] 

76 ) 

77 

78 def visit_dimension_key_reference( 

79 self, expression: qt.DimensionKeyReference 

80 ) -> sqlalchemy.ColumnElement[int | str]: 

81 # Docstring inherited. 

82 return self._joiner.dimension_keys[expression.dimension.name][0] 

83 

84 def visit_dimension_field_reference( 

85 self, expression: qt.DimensionFieldReference 

86 ) -> sqlalchemy.ColumnElement[Any] | TimespanDatabaseRepresentation: 

87 # Docstring inherited. 

88 if expression.column_type == "timespan": 

89 return self._joiner.timespans[expression.element.name] 

90 return self._joiner.fields[expression.element.name][expression.field] 

91 

92 def visit_dataset_field_reference( 

93 self, expression: qt.DatasetFieldReference 

94 ) -> sqlalchemy.ColumnElement[Any] | TimespanDatabaseRepresentation: 

95 # Docstring inherited. 

96 if expression.column_type == "timespan": 

97 return self._joiner.timespans[expression.dataset_type] 

98 return self._joiner.fields[expression.dataset_type][expression.field] 

99 

100 def visit_unary_expression(self, expression: qt.UnaryExpression) -> sqlalchemy.ColumnElement[Any]: 

101 # Docstring inherited. 

102 match expression.operator: 

103 case "-": 

104 return -self.expect_scalar(expression.operand) 

105 case "begin_of": 

106 return self.expect_timespan(expression.operand).lower() 

107 case "end_of": 

108 return self.expect_timespan(expression.operand).upper() 

109 raise AssertionError(f"Invalid unary expression operator {expression.operator!r}.") 

110 

111 def visit_binary_expression(self, expression: qt.BinaryExpression) -> sqlalchemy.ColumnElement[Any]: 

112 # Docstring inherited. 

113 a = self.expect_scalar(expression.a) 

114 b = self.expect_scalar(expression.b) 

115 match expression.operator: 

116 case "+": 

117 return a + b 

118 case "-": 

119 return a - b 

120 case "*": 

121 return a * b 

122 case "/": 

123 return a / b 

124 case "%": 

125 return a % b 

126 raise AssertionError(f"Invalid binary expression operator {expression.operator!r}.") 

127 

128 def visit_reversed(self, expression: qt.Reversed) -> sqlalchemy.ColumnElement[Any]: 

129 # Docstring inherited. 

130 return self.expect_scalar(expression.operand).desc() 

131 

132 def visit_comparison( 

133 self, 

134 a: qt.ColumnExpression, 

135 operator: qt.ComparisonOperator, 

136 b: qt.ColumnExpression, 

137 flags: PredicateVisitFlags, 

138 ) -> sqlalchemy.ColumnElement[bool]: 

139 # Docstring inherited. 

140 if operator == "overlaps": 

141 assert a.column_type == "timespan", "Spatial overlaps should be transformed away by now." 

142 return self.expect_timespan(a).overlaps(self.expect_timespan(b)) 

143 lhs = self.expect_scalar(a) 

144 rhs = self.expect_scalar(b) 

145 match operator: 

146 case "==": 

147 return lhs == rhs 

148 case "!=": 

149 return lhs != rhs 

150 case "<": 

151 return lhs < rhs 

152 case ">": 

153 return lhs > rhs 

154 case "<=": 

155 return lhs <= rhs 

156 case ">=": 

157 return lhs >= rhs 

158 raise AssertionError(f"Invalid comparison operator {operator!r}.") 

159 

160 def visit_is_null( 

161 self, operand: qt.ColumnExpression, flags: PredicateVisitFlags 

162 ) -> sqlalchemy.ColumnElement[bool]: 

163 # Docstring inherited. 

164 if operand.column_type == "timespan": 

165 return self.expect_timespan(operand).isNull() 

166 return self.expect_scalar(operand) == sqlalchemy.null() 

167 

168 def visit_in_container( 

169 self, 

170 member: qt.ColumnExpression, 

171 container: tuple[qt.ColumnExpression, ...], 

172 flags: PredicateVisitFlags, 

173 ) -> sqlalchemy.ColumnElement[bool]: 

174 # Docstring inherited. 

175 return self.expect_scalar(member).in_([self.expect_scalar(item) for item in container]) 

176 

177 def visit_in_range( 

178 self, member: qt.ColumnExpression, start: int, stop: int | None, step: int, flags: PredicateVisitFlags 

179 ) -> sqlalchemy.ColumnElement[bool]: 

180 # Docstring inherited. 

181 sql_member = self.expect_scalar(member) 

182 if stop is None: 

183 target = sql_member >= sqlalchemy.literal(start) 

184 else: 

185 stop_inclusive = stop - 1 

186 if start == stop_inclusive: 

187 return sql_member == sqlalchemy.literal(start) 

188 else: 

189 target = sqlalchemy.sql.between( 

190 sql_member, 

191 sqlalchemy.literal(start), 

192 sqlalchemy.literal(stop_inclusive), 

193 ) 

194 if step != 1: 

195 return sqlalchemy.sql.and_( 

196 *[ 

197 target, 

198 sql_member % sqlalchemy.literal(step) == sqlalchemy.literal(start % step), 

199 ] 

200 ) 

201 else: 

202 return target 

203 

204 def visit_in_query_tree( 

205 self, 

206 member: qt.ColumnExpression, 

207 column: qt.ColumnExpression, 

208 query_tree: qt.QueryTree, 

209 flags: PredicateVisitFlags, 

210 ) -> sqlalchemy.ColumnElement[bool]: 

211 # Docstring inherited. 

212 columns = qt.ColumnSet(self._driver.universe.empty.as_group()) 

213 column.gather_required_columns(columns) 

214 _, builder = self._driver.build_query(query_tree, columns) 

215 if builder.postprocessing: 

216 raise NotImplementedError( 

217 "Right-hand side subquery in IN expression would require postprocessing." 

218 ) 

219 subquery_visitor = SqlColumnVisitor(builder.joiner, self._driver) 

220 builder.joiner.special["_MEMBER"] = subquery_visitor.expect_scalar(column) 

221 builder.columns = qt.ColumnSet(self._driver.universe.empty.as_group()) 

222 subquery_select = builder.select() 

223 sql_member = self.expect_scalar(member) 

224 return sql_member.in_(subquery_select) 

225 

226 def apply_logical_and( 

227 self, originals: qt.PredicateOperands, results: tuple[sqlalchemy.ColumnElement[bool], ...] 

228 ) -> sqlalchemy.ColumnElement[bool]: 

229 # Docstring inherited. 

230 match len(results): 

231 case 0: 

232 return sqlalchemy.true() 

233 case 1: 

234 return results[0] 

235 case _: 

236 return sqlalchemy.and_(*results) 

237 

238 def apply_logical_or( 

239 self, 

240 originals: tuple[qt.PredicateLeaf, ...], 

241 results: tuple[sqlalchemy.ColumnElement[bool], ...], 

242 flags: PredicateVisitFlags, 

243 ) -> sqlalchemy.ColumnElement[bool]: 

244 # Docstring inherited. 

245 match len(results): 

246 case 0: 

247 return sqlalchemy.false() 

248 case 1: 

249 return results[0] 

250 case _: 

251 return sqlalchemy.or_(*results) 

252 

253 def apply_logical_not( 

254 self, original: qt.PredicateLeaf, result: sqlalchemy.ColumnElement[bool], flags: PredicateVisitFlags 

255 ) -> sqlalchemy.ColumnElement[bool]: 

256 # Docstring inherited. 

257 return sqlalchemy.not_(result) 

258 

259 def expect_scalar(self, expression: qt.OrderExpression) -> sqlalchemy.ColumnElement[Any]: 

260 result = expression.visit(self) 

261 assert isinstance(result, sqlalchemy.ColumnElement) 

262 return result 

263 

264 def expect_timespan(self, expression: qt.ColumnExpression) -> TimespanDatabaseRepresentation: 

265 result = expression.visit(self) 

266 assert isinstance(result, TimespanDatabaseRepresentation) 

267 return result