Coverage for tests/test_column_expressions.py: 6%

203 statements  

« prev     ^ index     » next       coverage.py v7.2.3, created at 2023-04-26 13:07 +0000

1# This file is part of daf_relation. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24import unittest 

25from collections.abc import Mapping 

26 

27import sqlalchemy 

28from lsst.daf.relation import ( 

29 ColumnContainer, 

30 ColumnExpression, 

31 ColumnFunction, 

32 ColumnTag, 

33 LogicalAnd, 

34 LogicalOr, 

35 Predicate, 

36 flatten_logical_and, 

37 iteration, 

38 sql, 

39 tests, 

40) 

41 

42 

class ColumnExpressionTestCase(tests.RelationTestCase):
    """Tests for column expressions, column containers, and predicates,
    exercised through both the iteration engine and the SQL engine.
    """

    def test_operator_function(self) -> None:
        """Check a ColumnFunction built directly via its constructor, using a
        name found in the `operator` module.
        """
        tag = tests.ColumnTag("tag")
        difference = ColumnFunction(
            "__sub__",
            (ColumnExpression.reference(tag), ColumnExpression.literal(5)),
            dtype=int,
            supporting_engine_types=None,
        )
        # Attributes and simple accessors.  The expected arguments are built
        # from scratch so the comparison relies on expression equality rather
        # than object identity.
        self.assertEqual(difference.dtype, int)
        self.assertEqual(difference.supporting_engine_types, None)
        expected_args = (ColumnExpression.reference(tag), ColumnExpression.literal(5))
        self.assertEqual(difference.args, expected_args)
        self.assertEqual(difference.name, "__sub__")
        self.assertEqual(difference.columns_required, {tag})
        self.assertEqual(str(difference), "__sub__(tag, 5)")
        # Iteration engine: the converted expression is called with a row
        # mapping.
        row_engine = iteration.Engine()
        self.assertTrue(difference.is_supported_by(row_engine))
        evaluate = row_engine.convert_column_expression(difference)
        self.assertEqual(evaluate({tag: 3}), -2)
        self.assertEqual(evaluate({tag: 6}), 1)
        # SQL engine: the converted expression is compared via its string
        # form.
        sql_engine = sql.Engine[sqlalchemy.sql.ColumnElement]()
        self.assertTrue(difference.is_supported_by(sql_engine))
        bound_columns = {tag: sqlalchemy.schema.Column("tag")}
        self.check_sql_str("tag - 5", sql_engine.convert_column_expression(difference, bound_columns))

    def test_method(self) -> None:
        """Check a ColumnFunction built via ColumnExpression.method, using a
        name found on the object itself.
        """
        tag = tests.ColumnTag("tag")
        lowered = ColumnExpression.reference(tag).method("lower")
        self.assertEqual(lowered.dtype, None)
        self.assertEqual(lowered.supporting_engine_types, None)
        self.assertEqual(lowered.args, (ColumnExpression.reference(tag),))
        self.assertEqual(lowered.name, "lower")
        self.assertEqual(lowered.columns_required, {tag})
        row_engine = iteration.Engine()
        self.assertTrue(lowered.is_supported_by(row_engine))
        evaluate = row_engine.convert_column_expression(lowered)
        self.assertEqual(evaluate({tag: "MiXeDcAsE"}), "mixedcase")

    def test_engine_function(self) -> None:
        """Check a ColumnFunction built via ColumnExpression.function, using a
        name that refers to a callable registered on the engine.
        """

        def square(x: int) -> int:
            return x**2

        tag = tests.ColumnTag("tag")
        row_engine = iteration.Engine()
        row_engine.functions["test_function"] = square
        squared = ColumnExpression.function(
            "test_function",
            ColumnExpression.reference(tag),
            supporting_engine_types={iteration.Engine},
            dtype=int,
        )
        self.assertEqual(squared.dtype, int)
        self.assertEqual(squared.supporting_engine_types, (iteration.Engine,))
        self.assertEqual(squared.args, (ColumnExpression.reference(tag),))
        self.assertEqual(squared.name, "test_function")
        self.assertEqual(squared.columns_required, {tag})
        # Only the iteration engine type was declared as supporting this
        # function.
        self.assertTrue(squared.is_supported_by(row_engine))
        self.assertFalse(squared.is_supported_by(sql.Engine[sqlalchemy.sql.ColumnElement]()))
        evaluate = row_engine.convert_column_expression(squared)
        self.assertEqual(evaluate({tag: 3}), 9)
        self.assertEqual(evaluate({tag: 4}), 16)

    def test_operator_predicate_function(self) -> None:
        """Check PredicateFunction objects built via ColumnExpression factory
        methods, using names found in the `operator` module.
        """
        tag = tests.ColumnTag("tag")
        ref = ColumnExpression.reference(tag, dtype=int)
        zero = ColumnExpression.literal(0, dtype=int)
        comparisons = [ref.eq(zero), ref.ne(zero), ref.lt(zero), ref.le(zero), ref.gt(zero), ref.ge(zero)]
        count = len(comparisons)
        # Attributes and simple accessors, checked for all comparisons at once.
        self.assertEqual([p.dtype for p in comparisons], [bool] * count)
        self.assertEqual([p.supporting_engine_types for p in comparisons], [None] * count)
        self.assertEqual([p.args for p in comparisons], [(ref, zero)] * count)
        expected_names = ["__eq__", "__ne__", "__lt__", "__le__", "__gt__", "__ge__"]
        self.assertEqual([p.name for p in comparisons], expected_names)
        self.assertEqual([p.columns_required for p in comparisons], [{tag}] * count)
        expected_strings = ["tag=0", "tag≠0", "tag<0", "tag≤0", "tag>0", "tag≥0"]
        self.assertEqual([str(p) for p in comparisons], expected_strings)
        # Iteration engine: evaluate every comparison at zero, a positive
        # value, and a negative value.
        row_engine = iteration.Engine()
        self.assertEqual([p.is_supported_by(row_engine) for p in comparisons], [True] * count)
        evaluators = [row_engine.convert_predicate(p) for p in comparisons]
        self.assertEqual([check({tag: 0}) for check in evaluators], [True, False, False, True, False, True])
        self.assertEqual([check({tag: 1}) for check in evaluators], [False, True, False, False, True, True])
        self.assertEqual([check({tag: -1}) for check in evaluators], [False, True, True, True, False, False])
        # SQL engine: compare the string form of every converted comparison.
        sql_bind = sqlalchemy.schema.Column("tag")
        sql_engine = sql.Engine[sqlalchemy.sql.ColumnElement]()
        self.assertEqual([p.is_supported_by(sql_engine) for p in comparisons], [True] * count)
        sql_expressions: list[sqlalchemy.sql.ColumnElement] = [
            sql_engine.convert_predicate(p, {tag: sql_bind}) for p in comparisons
        ]
        expected_sql = [
            "tag = 0",
            "tag != 0",
            "tag < 0",
            "tag <= 0",
            "tag > 0",
            "tag >= 0",
        ]
        self.assertEqual([tests.to_sql_str(element) for element in sql_expressions], expected_sql)

    def test_column_expression_sequence(self) -> None:
        """Check ColumnExpressionSequence and ColumnInContainer."""
        a = tests.ColumnTag("a")
        container = ColumnContainer.sequence(
            [ColumnExpression.reference(a, dtype=int), ColumnExpression.literal(4, dtype=int)], dtype=int
        )
        self.assertEqual(container.dtype, int)
        # Expected items are rebuilt so the comparison relies on expression
        # equality rather than object identity.
        expected_items = [ColumnExpression.reference(a, dtype=int), ColumnExpression.literal(4, dtype=int)]
        self.assertEqual(list(container.items), expected_items)
        self.assertEqual(container.columns_required, {a})
        self.assertEqual(str(container), "[a, 4]")
        row_engine = iteration.Engine()
        sql_engine = sql.Engine[sqlalchemy.sql.ColumnElement]()
        self.assertTrue(container.is_supported_by(row_engine))
        self.assertTrue(container.is_supported_by(sql_engine))
        # Membership predicate over the sequence.
        b = tests.ColumnTag("b")
        membership = container.contains(ColumnExpression.reference(b))
        self.assertEqual(membership.dtype, bool)
        self.assertEqual(membership.columns_required, {a, b})
        self.assertEqual(str(membership), "b∈[a, 4]")
        self.assertTrue(membership.is_supported_by(row_engine))
        self.assertTrue(membership.is_supported_by(sql_engine))
        is_member = row_engine.convert_predicate(membership)
        self.assertEqual(is_member({a: 0, b: 0}), True)
        self.assertEqual(is_member({a: 0, b: 3}), False)
        self.assertEqual(is_member({a: 0, b: 4}), True)
        sql_columns = {a: sqlalchemy.schema.Column("a"), b: sqlalchemy.schema.Column("b")}
        self.check_sql_str("b IN (a, 4)", sql_engine.convert_predicate(membership, sql_columns))

    def test_range_literal(self) -> None:
        """Check ColumnRangeLiteral and ColumnInContainer."""
        # One single-value range, one contiguous range, and one strided range.
        literals = [
            ColumnContainer.range_literal(range(3, 4)),
            ColumnContainer.range_literal(range(3, 6)),
            ColumnContainer.range_literal(range(2, 11, 3)),
        ]
        count = len(literals)
        self.assertEqual([lit.dtype for lit in literals], [int] * count)
        self.assertEqual([lit.columns_required for lit in literals], [frozenset()] * count)
        self.assertEqual([str(lit) for lit in literals], ["[3:4:1]", "[3:6:1]", "[2:11:3]"])
        # Membership predicates over each range.
        a = tests.ColumnTag("a")
        memberships = [lit.contains(ColumnExpression.reference(a)) for lit in literals]
        self.assertEqual([m.dtype for m in memberships], [bool] * count)
        self.assertEqual([m.columns_required for m in memberships], [{a}] * count)
        self.assertEqual([str(m) for m in memberships], ["a∈[3:4:1]", "a∈[3:6:1]", "a∈[2:11:3]"])
        row_engine = iteration.Engine()
        self.assertEqual([m.is_supported_by(row_engine) for m in memberships], [True] * count)
        evaluators = [row_engine.convert_predicate(m) for m in memberships]
        self.assertEqual([check({a: 3}) for check in evaluators], [True, True, False])
        self.assertEqual([check({a: 5}) for check in evaluators], [False, True, True])
        self.assertEqual([check({a: 8}) for check in evaluators], [False, False, True])
        sql_engine = sql.Engine[sqlalchemy.sql.ColumnElement]()
        self.assertEqual([m.is_supported_by(sql_engine) for m in memberships], [True] * count)
        bind_a = sqlalchemy.schema.Column("a")
        sql_expressions = [sql_engine.convert_predicate(m, {a: bind_a}) for m in memberships]
        expected_sql = ["a = 3", "a BETWEEN 3 AND 5", "a BETWEEN 2 AND 10 AND a % 3 = 2"]
        self.assertEqual([tests.to_sql_str(element) for element in sql_expressions], expected_sql)

    def test_logical_operators(self) -> None:
        """Check predicate logical-operator expressions, Predicate.as_trivial,
        and flatten_logical_and.
        """
        a = tests.ColumnTag("a")
        b = tests.ColumnTag("b")
        always = Predicate.literal(True)
        never = Predicate.literal(False)
        a_positive = ColumnExpression.reference(a).gt(ColumnExpression.literal(0))
        b_flag = Predicate.reference(b)
        iteration_engine = iteration.Engine()
        sql_engine = sql.Engine[sqlalchemy.sql.ColumnElement]()
        # Attributes and simple accessors for predicate literals and
        # references.  Some as_trivial overloads are deliberately not checked:
        # MyPy can tell they always return None and complains if that return
        # value is used (so MyPy effectively covers that "test" for us, even
        # though coverage can't tell).
        self.assertEqual(always.columns_required, frozenset())
        self.assertEqual(never.columns_required, frozenset())
        self.assertEqual(b_flag.columns_required, {b})
        self.assertEqual(str(always), "True")
        self.assertEqual(str(never), "False")
        self.assertEqual(str(b_flag), "b")
        self.assertIs(always.as_trivial(), True)
        self.assertIs(never.as_trivial(), False)
        for simple in (always, never, b_flag):
            self.assertTrue(simple.is_supported_by(iteration_engine))
        for simple in (always, never, b_flag):
            self.assertTrue(simple.is_supported_by(sql_engine))
        # Factory methods for logical operators, including simplification.
        self.assertIs(always.logical_not().as_trivial(), False)
        self.assertIs(never.logical_not().as_trivial(), True)
        self.assertIs(always.logical_and(a_positive).as_trivial(), None)
        self.assertIs(never.logical_and(a_positive).as_trivial(), False)
        self.assertIs(always.logical_or(a_positive).as_trivial(), True)
        self.assertIs(never.logical_or(a_positive).as_trivial(), None)
        self.assertEqual(Predicate.logical_and(), always)
        self.assertEqual(Predicate.logical_or(), never)
        self.assertEqual(Predicate.logical_and(a_positive), a_positive)
        self.assertEqual(Predicate.logical_or(a_positive), a_positive)
        # Attributes and simple accessors for logical operators.
        not_a = a_positive.logical_not()
        both = a_positive.logical_and(b_flag)
        either = a_positive.logical_or(b_flag)
        compound_checks: list[tuple[Predicate, set[ColumnTag], str]] = [
            (not_a, {a}, "not (a>0)"),
            (both, {a, b}, "a>0 and b"),
            (either, {a, b}, "a>0 or b"),
        ]
        for compound, required, text in compound_checks:
            self.assertEqual(compound.columns_required, required)
            self.assertEqual(str(compound), text)
            self.assertIs(compound.as_trivial(), None)
            self.assertTrue(compound.is_supported_by(iteration_engine))
            self.assertTrue(compound.is_supported_by(sql_engine))
        # Iteration engine conversions.  Each case converts the predicate
        # afresh and evaluates it on one row.
        iteration_cases: list[tuple[Predicate, Mapping[ColumnTag, object], bool]] = [
            (always, {}, True),
            (never, {}, False),
            (b_flag, {b: True}, True),
            (b_flag, {b: False}, False),
            (not_a, {a: 1}, False),
            (not_a, {a: 0}, True),
            (both, {a: 1, b: True}, True),
            (both, {a: 1, b: False}, False),
            (both, {a: 0, b: True}, False),
            (both, {a: 0, b: False}, False),
            (either, {a: 1, b: True}, True),
            (either, {a: 1, b: False}, True),
            (either, {a: 0, b: True}, True),
            (either, {a: 0, b: False}, False),
        ]
        for predicate, row, expected in iteration_cases:
            self.assertEqual(iteration_engine.convert_predicate(predicate)(row), expected)
        # SQL engine conversions.
        sql_columns: Mapping[ColumnTag, sqlalchemy.sql.ColumnElement] = {
            a: sqlalchemy.schema.Column("a"),
            b: sqlalchemy.schema.Column("b"),
        }
        # An empty AND is expected to render like the True literal, and an
        # empty OR like the False literal.
        self.check_sql_str(
            "1",
            sql_engine.convert_predicate(always, sql_columns),
            sql_engine.convert_predicate(LogicalAnd(()), sql_columns),
        )
        self.check_sql_str(
            "0",
            sql_engine.convert_predicate(never, sql_columns),
            sql_engine.convert_predicate(LogicalOr(()), sql_columns),
        )
        # A single-operand AND or OR is expected to render like its operand.
        self.check_sql_str(
            "b",
            sql_engine.convert_predicate(b_flag, sql_columns),
            sql_engine.convert_predicate(LogicalAnd((b_flag,)), sql_columns),
            sql_engine.convert_predicate(LogicalOr((b_flag,)), sql_columns),
        )
        # Apparently SQLAlchemy does some simplifications of its own on NOT
        # operations.
        self.check_sql_str("a <= 0", sql_engine.convert_predicate(not_a, sql_columns))
        self.check_sql_str("a > 0 AND b", sql_engine.convert_predicate(both, sql_columns))
        self.check_sql_str("a > 0 OR b", sql_engine.convert_predicate(either, sql_columns))
        # flatten_logical_and.
        self.assertEqual(flatten_logical_and(always), [])
        self.assertIs(flatten_logical_and(never), False)
        self.assertEqual(flatten_logical_and(not_a), [not_a])
        self.assertEqual(flatten_logical_and(both), [a_positive, b_flag])
        self.assertEqual(flatten_logical_and(either), [either])
        self.assertEqual(flatten_logical_and(both.logical_and(always)), [a_positive, b_flag])
        self.assertEqual(flatten_logical_and(both.logical_and(never)), False)
        c = tests.ColumnTag("c")
        c_flag = Predicate.reference(c)
        self.assertEqual(flatten_logical_and(both.logical_and(c_flag)), [a_positive, b_flag, c_flag])

334 

335 

# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()