Coverage for tests/test_column_expressions.py: 5%

201 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-07-21 09:39 +0000

1# This file is part of daf_relation. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24import unittest 

25from collections.abc import Mapping 

26 

27import sqlalchemy 

28from lsst.daf.relation import ( 

29 ColumnContainer, 

30 ColumnExpression, 

31 ColumnFunction, 

32 ColumnTag, 

33 LogicalAnd, 

34 LogicalOr, 

35 Predicate, 

36 flatten_logical_and, 

37 iteration, 

38 sql, 

39 tests, 

40) 

41 

42 

class ColumnExpressionTestCase(tests.RelationTestCase):
    """Test column expressions.

    Each test builds column expressions, containers, or predicates, checks
    their attributes and string forms, and then exercises them with the
    iteration engine and (in most tests) the SQL engine.
    """

    def test_operator_function(self) -> None:
        """Test ColumnFunction via its constructor and a name found in the
        `operator` module.
        """
        tag = tests.ColumnTag("tag")
        expression = ColumnFunction(
            "__sub__",
            (ColumnExpression.reference(tag), ColumnExpression.literal(5)),
            dtype=int,
            supporting_engine_types=None,
        )
        # Constructor arguments should be reflected by the attributes.
        self.assertEqual(expression.dtype, int)
        self.assertEqual(expression.supporting_engine_types, None)
        self.assertEqual(expression.args, (ColumnExpression.reference(tag), ColumnExpression.literal(5)))
        self.assertEqual(expression.name, "__sub__")
        self.assertEqual(expression.columns_required, {tag})
        self.assertEqual(
            str(expression),
            "__sub__(tag, 5)",
        )
        # Iteration engine: the converted expression is called with a row
        # mapping and should compute ``row[tag] - 5``.
        iteration_engine = iteration.Engine()
        self.assertTrue(expression.is_supported_by(iteration_engine))
        # Named ``evaluate`` rather than ``callable`` so the builtin is not
        # shadowed.
        evaluate = iteration_engine.convert_column_expression(expression)
        self.assertEqual(evaluate({tag: 3}), -2)
        self.assertEqual(evaluate({tag: 6}), 1)
        # SQL engine: the same expression should render as a subtraction.
        sql_engine = sql.Engine[sqlalchemy.sql.ColumnElement]()
        self.assertTrue(expression.is_supported_by(sql_engine))
        sql_expression = sql_engine.convert_column_expression(
            expression, {tag: sqlalchemy.schema.Column("tag")}
        )
        self.check_sql_str("tag - 5", sql_expression)

    def test_method(self) -> None:
        """Test ColumnFunction via ColumnExpression.method and a name found
        on the object itself.
        """
        tag = tests.ColumnTag("tag")
        expression = ColumnExpression.reference(tag).method("lower")
        # No dtype or engine restriction was given, so both should be None.
        self.assertEqual(expression.dtype, None)
        self.assertEqual(expression.supporting_engine_types, None)
        self.assertEqual(expression.args, (ColumnExpression.reference(tag),))
        self.assertEqual(expression.name, "lower")
        self.assertEqual(expression.columns_required, {tag})
        engine = iteration.Engine()
        self.assertTrue(expression.is_supported_by(engine))
        # Named ``evaluate`` rather than ``callable`` so the builtin is not
        # shadowed.
        evaluate = engine.convert_column_expression(expression)
        self.assertEqual(evaluate({tag: "MiXeDcAsE"}), "mixedcase")

    def test_engine_function(self) -> None:
        """Test ColumnFunction via ColumnExpression.function and a name
        that references a callable held by the engine.
        """
        tag = tests.ColumnTag("tag")
        engine = iteration.Engine()
        # Register the function on this engine instance only.
        engine.functions["test_function"] = lambda x: x**2
        expression = ColumnExpression.function(
            "test_function",
            ColumnExpression.reference(tag),
            supporting_engine_types={iteration.Engine},
            dtype=int,
        )
        self.assertEqual(expression.dtype, int)
        # A set was passed in above; the attribute is expected to compare
        # equal to a one-element tuple.
        self.assertEqual(expression.supporting_engine_types, (iteration.Engine,))
        self.assertEqual(expression.args, (ColumnExpression.reference(tag),))
        self.assertEqual(expression.name, "test_function")
        self.assertEqual(expression.columns_required, {tag})
        # Only the iteration engine type was declared as supporting this
        # function, so the SQL engine must report it as unsupported.
        self.assertTrue(expression.is_supported_by(engine))
        self.assertFalse(expression.is_supported_by(sql.Engine[sqlalchemy.sql.ColumnElement]()))
        # Named ``evaluate`` rather than ``callable`` so the builtin is not
        # shadowed.
        evaluate = engine.convert_column_expression(expression)
        self.assertEqual(evaluate({tag: 3}), 9)
        self.assertEqual(evaluate({tag: 4}), 16)

    def test_operator_predicate_function(self) -> None:
        """Test PredicateFunction via ColumnExpression factory methods and
        names found in the `operator` module.
        """
        tag = tests.ColumnTag("tag")
        ref = ColumnExpression.reference(tag, dtype=int)
        zero = ColumnExpression.literal(0, dtype=int)
        # One predicate per comparison operator; every list-valued assertion
        # below follows this same order: eq, ne, lt, le, gt, ge.
        expressions = [ref.eq(zero), ref.ne(zero), ref.lt(zero), ref.le(zero), ref.gt(zero), ref.ge(zero)]
        self.assertEqual([x.dtype for x in expressions], [bool] * 6)
        self.assertEqual([x.supporting_engine_types for x in expressions], [None] * 6)
        self.assertEqual([x.args for x in expressions], [(ref, zero)] * 6)
        self.assertEqual(
            [x.name for x in expressions], ["__eq__", "__ne__", "__lt__", "__le__", "__gt__", "__ge__"]
        )
        self.assertEqual([x.columns_required for x in expressions], [{tag}] * 6)
        self.assertEqual(
            [str(x) for x in expressions],
            ["tag=0", "tag≠0", "tag<0", "tag≤0", "tag>0", "tag≥0"],
        )
        # Iteration engine: evaluate each predicate at zero, above zero, and
        # below zero.
        iteration_engine = iteration.Engine()
        self.assertEqual([x.is_supported_by(iteration_engine) for x in expressions], [True] * 6)
        callables = [iteration_engine.convert_predicate(x) for x in expressions]
        self.assertEqual([c({tag: 0}) for c in callables], [True, False, False, True, False, True])
        self.assertEqual([c({tag: 1}) for c in callables], [False, True, False, False, True, True])
        self.assertEqual([c({tag: -1}) for c in callables], [False, True, True, True, False, False])
        # SQL engine: each predicate should render as the matching SQL
        # comparison.
        sql_bind = sqlalchemy.schema.Column("tag")
        sql_engine = sql.Engine[sqlalchemy.sql.ColumnElement]()
        self.assertEqual([x.is_supported_by(sql_engine) for x in expressions], [True] * 6)
        sql_expressions: list[sqlalchemy.sql.ColumnElement] = [
            sql_engine.convert_predicate(x, {tag: sql_bind}) for x in expressions
        ]
        self.assertEqual(
            [tests.to_sql_str(s) for s in sql_expressions],
            [
                "tag = 0",
                "tag != 0",
                "tag < 0",
                "tag <= 0",
                "tag > 0",
                "tag >= 0",
            ],
        )

    def test_column_expression_sequence(self) -> None:
        """Test ColumnExpressionSequence and ColumnInContainer."""
        a = tests.ColumnTag("a")
        # A two-item sequence mixing a column reference and a literal.
        seq = ColumnContainer.sequence(
            [ColumnExpression.reference(a, dtype=int), ColumnExpression.literal(4, dtype=int)], dtype=int
        )
        self.assertEqual(seq.dtype, int)
        self.assertEqual(
            list(seq.items),
            [ColumnExpression.reference(a, dtype=int), ColumnExpression.literal(4, dtype=int)],
        )
        self.assertEqual(seq.columns_required, {a})
        self.assertEqual(str(seq), "[a, 4]")
        iteration_engine = iteration.Engine()
        sql_engine = sql.Engine[sqlalchemy.sql.ColumnElement]()
        self.assertTrue(seq.is_supported_by(iteration_engine))
        self.assertTrue(seq.is_supported_by(sql_engine))
        # Membership predicate: "is column b one of the sequence's items?"
        b = tests.ColumnTag("b")
        contains = seq.contains(ColumnExpression.reference(b))
        self.assertEqual(contains.dtype, bool)
        self.assertEqual(contains.columns_required, {a, b})
        self.assertEqual(str(contains), "b∈[a, 4]")
        self.assertTrue(contains.is_supported_by(iteration_engine))
        self.assertTrue(contains.is_supported_by(sql_engine))
        # Named ``evaluate`` rather than ``callable`` so the builtin is not
        # shadowed.  With a == 0 the sequence items are 0 and 4.
        evaluate = iteration_engine.convert_predicate(contains)
        self.assertEqual(evaluate({a: 0, b: 0}), True)
        self.assertEqual(evaluate({a: 0, b: 3}), False)
        self.assertEqual(evaluate({a: 0, b: 4}), True)
        bind_a = sqlalchemy.schema.Column("a")
        bind_b = sqlalchemy.schema.Column("b")
        self.check_sql_str("b IN (a, 4)", sql_engine.convert_predicate(contains, {a: bind_a, b: bind_b}))

    def test_range_literal(self) -> None:
        """Test ColumnRangeLiteral and ColumnInContainer."""
        # Three shapes of range: a single value, a contiguous run, and a
        # strided run (2, 5, 8).
        ranges = [
            ColumnContainer.range_literal(range(3, 4)),
            ColumnContainer.range_literal(range(3, 6)),
            ColumnContainer.range_literal(range(2, 11, 3)),
        ]
        self.assertEqual([r.dtype for r in ranges], [int] * 3)
        self.assertEqual([r.columns_required for r in ranges], [frozenset()] * 3)
        self.assertEqual([str(r) for r in ranges], ["[3:4:1]", "[3:6:1]", "[2:11:3]"])
        a = tests.ColumnTag("a")
        contains = [r.contains(ColumnExpression.reference(a)) for r in ranges]
        self.assertEqual([c.dtype for c in contains], [bool] * 3)
        self.assertEqual([c.columns_required for c in contains], [{a}] * 3)
        self.assertEqual([str(c) for c in contains], ["a∈[3:4:1]", "a∈[3:6:1]", "a∈[2:11:3]"])
        # Iteration engine: probe each membership predicate with 3, 5 and 8.
        iteration_engine = iteration.Engine()
        self.assertEqual([c.is_supported_by(iteration_engine) for c in contains], [True] * 3)
        callables = [iteration_engine.convert_predicate(c) for c in contains]
        self.assertEqual([c({a: 3}) for c in callables], [True, True, False])
        self.assertEqual([c({a: 5}) for c in callables], [False, True, True])
        self.assertEqual([c({a: 8}) for c in callables], [False, False, True])
        # SQL engine: expected renderings are an equality, a BETWEEN, and a
        # BETWEEN combined with a modulo test for the strided range.
        sql_engine = sql.Engine[sqlalchemy.sql.ColumnElement]()
        self.assertEqual([c.is_supported_by(sql_engine) for c in contains], [True] * 3)
        bind_a = sqlalchemy.schema.Column("a")
        sql_expressions = [sql_engine.convert_predicate(c, {a: bind_a}) for c in contains]
        self.assertEqual(
            [tests.to_sql_str(s) for s in sql_expressions],
            ["a = 3", "a BETWEEN 3 AND 5", "a BETWEEN 2 AND 10 AND a % 3 = 2"],
        )

    def test_logical_operators(self) -> None:
        """Test predicate logical operator expressions, Predicate.as_literal,
        and flatten_logical_and.
        """
        a = tests.ColumnTag("a")
        b = tests.ColumnTag("b")
        # Building blocks: the two literal predicates, a comparison on
        # column a, and a boolean reference to column b.
        t = Predicate.literal(True)
        f = Predicate.literal(False)
        x = ColumnExpression.reference(a).gt(ColumnExpression.literal(0))
        y = Predicate.reference(b)
        iteration_engine = iteration.Engine()
        sql_engine = sql.Engine[sqlalchemy.sql.ColumnElement]()
        # Check attributes and simple accessors for predicate literals and
        # references.  Some as_trivial overloads are not checked because MyPy
        # can tell they always return None and complains if we try to use that
        # return value (which also means MyPy takes care of that "test" for us,
        # even though coverage can't tell).
        self.assertEqual(t.columns_required, frozenset())
        self.assertEqual(f.columns_required, frozenset())
        self.assertEqual(y.columns_required, {b})
        self.assertEqual(str(t), "True")
        self.assertEqual(str(f), "False")
        self.assertEqual(str(y), "b")
        self.assertIs(t.as_trivial(), True)
        self.assertIs(f.as_trivial(), False)
        self.assertTrue(t.is_supported_by(iteration_engine))
        self.assertTrue(f.is_supported_by(iteration_engine))
        self.assertTrue(y.is_supported_by(iteration_engine))
        self.assertTrue(t.is_supported_by(sql_engine))
        self.assertTrue(f.is_supported_by(sql_engine))
        self.assertTrue(y.is_supported_by(sql_engine))
        # Test factory methods for logical operators, including simplification.
        self.assertIs(t.logical_not().as_trivial(), False)
        self.assertIs(f.logical_not().as_trivial(), True)
        self.assertIs(t.logical_and(x).as_trivial(), None)
        self.assertIs(f.logical_and(x).as_trivial(), False)
        self.assertIs(t.logical_or(x).as_trivial(), True)
        self.assertIs(f.logical_or(x).as_trivial(), None)
        self.assertEqual(Predicate.logical_and(), t)
        self.assertEqual(Predicate.logical_or(), f)
        self.assertEqual(Predicate.logical_and(x), x)
        self.assertEqual(Predicate.logical_or(x), x)
        # Test attributes and simple accessors for logical operators.
        not_x = x.logical_not()
        self.assertEqual(not_x.columns_required, {a})
        self.assertEqual(str(not_x), "not (a>0)")
        self.assertIs(not_x.as_trivial(), None)
        self.assertTrue(not_x.is_supported_by(iteration_engine))
        self.assertTrue(not_x.is_supported_by(sql_engine))
        x_and_y = x.logical_and(y)
        self.assertEqual(x_and_y.columns_required, {a, b})
        self.assertEqual(str(x_and_y), "a>0 and b")
        self.assertIs(x_and_y.as_trivial(), None)
        self.assertTrue(x_and_y.is_supported_by(iteration_engine))
        self.assertTrue(x_and_y.is_supported_by(sql_engine))
        x_or_y = x.logical_or(y)
        self.assertEqual(x_or_y.columns_required, {a, b})
        self.assertEqual(str(x_or_y), "a>0 or b")
        self.assertIs(x_or_y.as_trivial(), None)
        self.assertTrue(x_or_y.is_supported_by(iteration_engine))
        self.assertTrue(x_or_y.is_supported_by(sql_engine))
        # Test iteration engine conversions.
        self.assertEqual(iteration_engine.convert_predicate(t)({}), True)
        self.assertEqual(iteration_engine.convert_predicate(f)({}), False)
        self.assertEqual(iteration_engine.convert_predicate(y)({b: True}), True)
        self.assertEqual(iteration_engine.convert_predicate(y)({b: False}), False)
        self.assertEqual(iteration_engine.convert_predicate(not_x)({a: 1}), False)
        self.assertEqual(iteration_engine.convert_predicate(not_x)({a: 0}), True)
        # Full truth tables for AND and OR over (a > 0, b).
        self.assertEqual(iteration_engine.convert_predicate(x_and_y)({a: 1, b: True}), True)
        self.assertEqual(iteration_engine.convert_predicate(x_and_y)({a: 1, b: False}), False)
        self.assertEqual(iteration_engine.convert_predicate(x_and_y)({a: 0, b: True}), False)
        self.assertEqual(iteration_engine.convert_predicate(x_and_y)({a: 0, b: False}), False)
        self.assertEqual(iteration_engine.convert_predicate(x_or_y)({a: 1, b: True}), True)
        self.assertEqual(iteration_engine.convert_predicate(x_or_y)({a: 1, b: False}), True)
        self.assertEqual(iteration_engine.convert_predicate(x_or_y)({a: 0, b: True}), True)
        self.assertEqual(iteration_engine.convert_predicate(x_or_y)({a: 0, b: False}), False)
        # Test SQL engine conversions.
        columns: Mapping[ColumnTag, sqlalchemy.sql.ColumnElement] = {
            a: sqlalchemy.schema.Column("a"),
            b: sqlalchemy.schema.Column("b"),
        }
        # An empty AND should render the same as the True literal...
        self.check_sql_str(
            "1",
            sql_engine.convert_predicate(t, columns),
            sql_engine.convert_predicate(LogicalAnd(()), columns),
        )
        # ...and an empty OR the same as the False literal.
        self.check_sql_str(
            "0",
            sql_engine.convert_predicate(f, columns),
            sql_engine.convert_predicate(LogicalOr(()), columns),
        )
        # Single-operand AND/OR should render as just that operand.
        self.check_sql_str(
            "b",
            sql_engine.convert_predicate(y, columns),
            sql_engine.convert_predicate(LogicalAnd((y,)), columns),
            sql_engine.convert_predicate(LogicalOr((y,)), columns),
        )
        # Apparently SQLAlchemy does some simplifications of its own on NOT
        # operations.
        self.check_sql_str("a <= 0", sql_engine.convert_predicate(not_x, columns))
        self.check_sql_str("a > 0 AND b", sql_engine.convert_predicate(x_and_y, columns))
        self.check_sql_str("a > 0 OR b", sql_engine.convert_predicate(x_or_y, columns))
        # Test flatten_logical_and
        self.assertEqual(flatten_logical_and(t), [])
        self.assertIs(flatten_logical_and(f), False)
        self.assertEqual(flatten_logical_and(not_x), [not_x])
        self.assertEqual(flatten_logical_and(x_and_y), [x, y])
        self.assertEqual(flatten_logical_and(x_or_y), [x_or_y])
        self.assertEqual(flatten_logical_and(x_and_y.logical_and(t)), [x, y])
        self.assertEqual(flatten_logical_and(x_and_y.logical_and(f)), False)
        c = tests.ColumnTag("c")
        z = Predicate.reference(c)
        self.assertEqual(flatten_logical_and(x_and_y.logical_and(z)), [x, y, z])

336 

337 

# Run this module's tests when it is executed directly as a script.
if __name__ == "__main__":
    unittest.main()