Coverage for tests/test_column_expressions.py: 6%
203 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-01-07 10:05 +0000
« prev ^ index » next coverage.py v6.5.0, created at 2023-01-07 10:05 +0000
1# This file is part of daf_relation.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22from __future__ import annotations
24import unittest
25from collections.abc import Mapping
27import sqlalchemy
28from lsst.daf.relation import (
29 ColumnContainer,
30 ColumnExpression,
31 ColumnFunction,
32 ColumnTag,
33 LogicalAnd,
34 LogicalOr,
35 Predicate,
36 flatten_logical_and,
37 iteration,
38 sql,
39 tests,
40)
class ColumnExpressionTestCase(tests.RelationTestCase):
    """Tests for column expressions, column containers, and predicates.

    Each test builds expression objects through the public constructors or
    factory methods, checks their simple attributes and string forms, and
    then converts them with the iteration engine and (where supported) the
    SQL engine, checking the converted results.
    """

    def test_operator_function(self) -> None:
        """Test ColumnFunction via its constructor and a name found in the
        `operator` module.
        """
        tag = tests.ColumnTag("tag")
        expression = ColumnFunction(
            "__sub__",
            (ColumnExpression.reference(tag), ColumnExpression.literal(5)),
            dtype=int,
            supporting_engine_types=None,
        )
        self.assertEqual(expression.dtype, int)
        self.assertIsNone(expression.supporting_engine_types)
        self.assertEqual(expression.args, (ColumnExpression.reference(tag), ColumnExpression.literal(5)))
        self.assertEqual(expression.name, "__sub__")
        self.assertEqual(expression.columns_required, {tag})
        self.assertEqual(
            str(expression),
            "__sub__(tag, 5)",
        )
        iteration_engine = iteration.Engine()
        self.assertTrue(expression.is_supported_by(iteration_engine))
        # Named ``func`` (not ``callable``) to avoid shadowing the builtin.
        func = iteration_engine.convert_column_expression(expression)
        self.assertEqual(func({tag: 3}), -2)
        self.assertEqual(func({tag: 6}), 1)
        sql_engine = sql.Engine[sqlalchemy.sql.ColumnElement]()
        self.assertTrue(expression.is_supported_by(sql_engine))
        sql_expression = sql_engine.convert_column_expression(
            expression, {tag: sqlalchemy.schema.Column("tag")}
        )
        self.check_sql_str("tag - 5", sql_expression)

    def test_method(self) -> None:
        """Test ColumnFunction via ColumnExpression.method and a name found
        on the object itself.
        """
        tag = tests.ColumnTag("tag")
        expression = ColumnExpression.reference(tag).method("lower")
        self.assertIsNone(expression.dtype)
        self.assertIsNone(expression.supporting_engine_types)
        self.assertEqual(expression.args, (ColumnExpression.reference(tag),))
        self.assertEqual(expression.name, "lower")
        self.assertEqual(expression.columns_required, {tag})
        engine = iteration.Engine()
        self.assertTrue(expression.is_supported_by(engine))
        func = engine.convert_column_expression(expression)
        self.assertEqual(func({tag: "MiXeDcAsE"}), "mixedcase")

    def test_engine_function(self) -> None:
        """Test ColumnFunction via ColumnExpression.function and a name
        that references a callable held by the engine.
        """
        tag = tests.ColumnTag("tag")
        engine = iteration.Engine()
        engine.functions["test_function"] = lambda x: x**2
        expression = ColumnExpression.function(
            "test_function",
            ColumnExpression.reference(tag),
            supporting_engine_types={iteration.Engine},
            dtype=int,
        )
        self.assertEqual(expression.dtype, int)
        self.assertEqual(expression.supporting_engine_types, (iteration.Engine,))
        self.assertEqual(expression.args, (ColumnExpression.reference(tag),))
        self.assertEqual(expression.name, "test_function")
        self.assertEqual(expression.columns_required, {tag})
        self.assertTrue(expression.is_supported_by(engine))
        # Only the iteration engine was listed as supporting this function.
        self.assertFalse(expression.is_supported_by(sql.Engine[sqlalchemy.sql.ColumnElement]()))
        func = engine.convert_column_expression(expression)
        self.assertEqual(func({tag: 3}), 9)
        self.assertEqual(func({tag: 4}), 16)

    def test_operator_predicate_function(self) -> None:
        """Test PredicateFunction via ColumnExpression factory methods and
        names found in the `operator` module.
        """
        tag = tests.ColumnTag("tag")
        ref = ColumnExpression.reference(tag, dtype=int)
        zero = ColumnExpression.literal(0, dtype=int)
        expressions = [ref.eq(zero), ref.ne(zero), ref.lt(zero), ref.le(zero), ref.gt(zero), ref.ge(zero)]
        self.assertEqual([x.dtype for x in expressions], [bool] * 6)
        self.assertEqual([x.supporting_engine_types for x in expressions], [None] * 6)
        self.assertEqual([x.args for x in expressions], [(ref, zero)] * 6)
        self.assertEqual(
            [x.name for x in expressions], ["__eq__", "__ne__", "__lt__", "__le__", "__gt__", "__ge__"]
        )
        self.assertEqual([x.columns_required for x in expressions], [{tag}] * 6)
        self.assertEqual(
            [str(x) for x in expressions],
            ["tag=0", "tag≠0", "tag<0", "tag≤0", "tag>0", "tag≥0"],
        )
        iteration_engine = iteration.Engine()
        self.assertEqual([x.is_supported_by(iteration_engine) for x in expressions], [True] * 6)
        callables = [iteration_engine.convert_predicate(x) for x in expressions]
        self.assertEqual([c({tag: 0}) for c in callables], [True, False, False, True, False, True])
        self.assertEqual([c({tag: 1}) for c in callables], [False, True, False, False, True, True])
        self.assertEqual([c({tag: -1}) for c in callables], [False, True, True, True, False, False])
        sql_column = sqlalchemy.schema.Column("tag")
        sql_engine = sql.Engine[sqlalchemy.sql.ColumnElement]()
        self.assertEqual([x.is_supported_by(sql_engine) for x in expressions], [True] * 6)
        sql_expressions: list[sqlalchemy.sql.ColumnElement] = [
            sql_engine.convert_predicate(x, {tag: sql_column}) for x in expressions
        ]
        self.assertEqual(
            [tests.to_sql_str(s) for s in sql_expressions],
            [
                "tag = 0",
                "tag != 0",
                "tag < 0",
                "tag <= 0",
                "tag > 0",
                "tag >= 0",
            ],
        )

    def test_column_expression_sequence(self) -> None:
        """Test ColumnExpressionSequence and ColumnInContainer."""
        a = tests.ColumnTag("a")
        seq = ColumnContainer.sequence(
            [ColumnExpression.reference(a, dtype=int), ColumnExpression.literal(4, dtype=int)], dtype=int
        )
        self.assertEqual(seq.dtype, int)
        self.assertEqual(
            list(seq.items),
            [ColumnExpression.reference(a, dtype=int), ColumnExpression.literal(4, dtype=int)],
        )
        self.assertEqual(seq.columns_required, {a})
        self.assertEqual(str(seq), "[a, 4]")
        iteration_engine = iteration.Engine()
        sql_engine = sql.Engine[sqlalchemy.sql.ColumnElement]()
        self.assertTrue(seq.is_supported_by(iteration_engine))
        self.assertTrue(seq.is_supported_by(sql_engine))
        b = tests.ColumnTag("b")
        contains = seq.contains(ColumnExpression.reference(b))
        self.assertEqual(contains.dtype, bool)
        # The membership test needs both the tested column and the columns
        # referenced inside the container.
        self.assertEqual(contains.columns_required, {a, b})
        self.assertEqual(str(contains), "b∈[a, 4]")
        self.assertTrue(contains.is_supported_by(iteration_engine))
        self.assertTrue(contains.is_supported_by(sql_engine))
        predicate = iteration_engine.convert_predicate(contains)
        self.assertEqual(predicate({a: 0, b: 0}), True)
        self.assertEqual(predicate({a: 0, b: 3}), False)
        self.assertEqual(predicate({a: 0, b: 4}), True)
        column_a = sqlalchemy.schema.Column("a")
        column_b = sqlalchemy.schema.Column("b")
        self.check_sql_str(
            "b IN (a, 4)", sql_engine.convert_predicate(contains, {a: column_a, b: column_b})
        )

    def test_range_literal(self) -> None:
        """Test ColumnRangeLiteral and ColumnInContainer."""
        ranges = [
            ColumnContainer.range_literal(range(3, 4)),
            ColumnContainer.range_literal(range(3, 6)),
            ColumnContainer.range_literal(range(2, 11, 3)),
        ]
        self.assertEqual([r.dtype for r in ranges], [int] * 3)
        self.assertEqual([r.columns_required for r in ranges], [frozenset()] * 3)
        self.assertEqual([str(r) for r in ranges], ["[3:4:1]", "[3:6:1]", "[2:11:3]"])
        a = tests.ColumnTag("a")
        contains = [r.contains(ColumnExpression.reference(a)) for r in ranges]
        self.assertEqual([c.dtype for c in contains], [bool] * 3)
        self.assertEqual([c.columns_required for c in contains], [{a}] * 3)
        self.assertEqual([str(c) for c in contains], ["a∈[3:4:1]", "a∈[3:6:1]", "a∈[2:11:3]"])
        iteration_engine = iteration.Engine()
        self.assertEqual([c.is_supported_by(iteration_engine) for c in contains], [True] * 3)
        callables = [iteration_engine.convert_predicate(c) for c in contains]
        self.assertEqual([c({a: 3}) for c in callables], [True, True, False])
        self.assertEqual([c({a: 5}) for c in callables], [False, True, True])
        self.assertEqual([c({a: 8}) for c in callables], [False, False, True])
        sql_engine = sql.Engine[sqlalchemy.sql.ColumnElement]()
        self.assertEqual([c.is_supported_by(sql_engine) for c in contains], [True] * 3)
        column_a = sqlalchemy.schema.Column("a")
        sql_expressions = [sql_engine.convert_predicate(c, {a: column_a}) for c in contains]
        # Single-element ranges become equality, unit-step ranges become
        # BETWEEN, and strided ranges add a modulo constraint.
        self.assertEqual(
            [tests.to_sql_str(s) for s in sql_expressions],
            ["a = 3", "a BETWEEN 3 AND 5", "a BETWEEN 2 AND 10 AND a % 3 = 2"],
        )

    def test_logical_operators(self) -> None:
        """Test predicate logical operator expressions, Predicate.as_literal,
        and flatten_logical_and.
        """
        a = tests.ColumnTag("a")
        b = tests.ColumnTag("b")
        t = Predicate.literal(True)
        f = Predicate.literal(False)
        x = ColumnExpression.reference(a).gt(ColumnExpression.literal(0))
        y = Predicate.reference(b)
        iteration_engine = iteration.Engine()
        sql_engine = sql.Engine[sqlalchemy.sql.ColumnElement]()
        # Check attributes and simple accessors for predicate literals and
        # references. Some as_trivial overloads are not checked because MyPy
        # can tell they always return None and complains if we try to use that
        # return value (which also means MyPy takes care of that "test" for us,
        # even though coverage can't tell).
        self.assertEqual(t.columns_required, frozenset())
        self.assertEqual(f.columns_required, frozenset())
        self.assertEqual(y.columns_required, {b})
        self.assertEqual(str(t), "True")
        self.assertEqual(str(f), "False")
        self.assertEqual(str(y), "b")
        self.assertIs(t.as_trivial(), True)
        self.assertIs(f.as_trivial(), False)
        self.assertTrue(t.is_supported_by(iteration_engine))
        self.assertTrue(f.is_supported_by(iteration_engine))
        self.assertTrue(y.is_supported_by(iteration_engine))
        self.assertTrue(t.is_supported_by(sql_engine))
        self.assertTrue(f.is_supported_by(sql_engine))
        self.assertTrue(y.is_supported_by(sql_engine))
        # Test factory methods for logical operators, including simplification.
        self.assertIs(t.logical_not().as_trivial(), False)
        self.assertIs(f.logical_not().as_trivial(), True)
        self.assertIsNone(t.logical_and(x).as_trivial())
        self.assertIs(f.logical_and(x).as_trivial(), False)
        self.assertIs(t.logical_or(x).as_trivial(), True)
        self.assertIsNone(f.logical_or(x).as_trivial())
        self.assertEqual(Predicate.logical_and(), t)
        self.assertEqual(Predicate.logical_or(), f)
        self.assertEqual(Predicate.logical_and(x), x)
        self.assertEqual(Predicate.logical_or(x), x)
        # Test attributes and simple accessors for logical operators.
        not_x = x.logical_not()
        self.assertEqual(not_x.columns_required, {a})
        self.assertEqual(str(not_x), "not (a>0)")
        self.assertIsNone(not_x.as_trivial())
        self.assertTrue(not_x.is_supported_by(iteration_engine))
        self.assertTrue(not_x.is_supported_by(sql_engine))
        x_and_y = x.logical_and(y)
        self.assertEqual(x_and_y.columns_required, {a, b})
        self.assertEqual(str(x_and_y), "a>0 and b")
        self.assertIsNone(x_and_y.as_trivial())
        self.assertTrue(x_and_y.is_supported_by(iteration_engine))
        self.assertTrue(x_and_y.is_supported_by(sql_engine))
        x_or_y = x.logical_or(y)
        self.assertEqual(x_or_y.columns_required, {a, b})
        self.assertEqual(str(x_or_y), "a>0 or b")
        self.assertIsNone(x_or_y.as_trivial())
        self.assertTrue(x_or_y.is_supported_by(iteration_engine))
        self.assertTrue(x_or_y.is_supported_by(sql_engine))
        # Test iteration engine conversions.
        self.assertEqual(iteration_engine.convert_predicate(t)({}), True)
        self.assertEqual(iteration_engine.convert_predicate(f)({}), False)
        self.assertEqual(iteration_engine.convert_predicate(y)({b: True}), True)
        self.assertEqual(iteration_engine.convert_predicate(y)({b: False}), False)
        self.assertEqual(iteration_engine.convert_predicate(not_x)({a: 1}), False)
        self.assertEqual(iteration_engine.convert_predicate(not_x)({a: 0}), True)
        self.assertEqual(iteration_engine.convert_predicate(x_and_y)({a: 1, b: True}), True)
        self.assertEqual(iteration_engine.convert_predicate(x_and_y)({a: 1, b: False}), False)
        self.assertEqual(iteration_engine.convert_predicate(x_and_y)({a: 0, b: True}), False)
        self.assertEqual(iteration_engine.convert_predicate(x_and_y)({a: 0, b: False}), False)
        self.assertEqual(iteration_engine.convert_predicate(x_or_y)({a: 1, b: True}), True)
        self.assertEqual(iteration_engine.convert_predicate(x_or_y)({a: 1, b: False}), True)
        self.assertEqual(iteration_engine.convert_predicate(x_or_y)({a: 0, b: True}), True)
        self.assertEqual(iteration_engine.convert_predicate(x_or_y)({a: 0, b: False}), False)
        # Test SQL engine conversions.
        columns: Mapping[ColumnTag, sqlalchemy.sql.ColumnElement] = {
            a: sqlalchemy.schema.Column("a"),
            b: sqlalchemy.schema.Column("b"),
        }
        # Empty AND/OR must convert to the same SQL as the True/False literals.
        self.check_sql_str(
            "1",
            sql_engine.convert_predicate(t, columns),
            sql_engine.convert_predicate(LogicalAnd(()), columns),
        )
        self.check_sql_str(
            "0",
            sql_engine.convert_predicate(f, columns),
            sql_engine.convert_predicate(LogicalOr(()), columns),
        )
        # Single-operand AND/OR must convert to just that operand.
        self.check_sql_str(
            "b",
            sql_engine.convert_predicate(y, columns),
            sql_engine.convert_predicate(LogicalAnd((y,)), columns),
            sql_engine.convert_predicate(LogicalOr((y,)), columns),
        )
        # SQLAlchemy applies its own simplification to NOT operations, so
        # "not (a > 0)" is rendered as "a <= 0".
        self.check_sql_str("a <= 0", sql_engine.convert_predicate(not_x, columns))
        self.check_sql_str("a > 0 AND b", sql_engine.convert_predicate(x_and_y, columns))
        self.check_sql_str("a > 0 OR b", sql_engine.convert_predicate(x_or_y, columns))
        # Test flatten_logical_and.  A literal False must come back as the
        # ``False`` singleton (not merely a falsey value), so use assertIs.
        self.assertEqual(flatten_logical_and(t), [])
        self.assertIs(flatten_logical_and(f), False)
        self.assertEqual(flatten_logical_and(not_x), [not_x])
        self.assertEqual(flatten_logical_and(x_and_y), [x, y])
        self.assertEqual(flatten_logical_and(x_or_y), [x_or_y])
        self.assertEqual(flatten_logical_and(x_and_y.logical_and(t)), [x, y])
        self.assertIs(flatten_logical_and(x_and_y.logical_and(f)), False)
        c = tests.ColumnTag("c")
        z = Predicate.reference(c)
        self.assertEqual(flatten_logical_and(x_and_y.logical_and(z)), [x, y, z])
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()