Coverage for tests/test_column_expressions.py: 5%
201 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-07-21 09:39 +0000
1# This file is part of daf_relation.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22from __future__ import annotations
24import unittest
25from collections.abc import Mapping
27import sqlalchemy
28from lsst.daf.relation import (
29 ColumnContainer,
30 ColumnExpression,
31 ColumnFunction,
32 ColumnTag,
33 LogicalAnd,
34 LogicalOr,
35 Predicate,
36 flatten_logical_and,
37 iteration,
38 sql,
39 tests,
40)
class ColumnExpressionTestCase(tests.RelationTestCase):
    """Test column expressions.

    Each test builds expressions, containers, or predicates from
    `lsst.daf.relation`, checks their simple attributes and string forms,
    and then converts them with both the iteration engine and the SQL
    engine to check the results of evaluation and the generated SQL.
    """

    def test_operator_function(self) -> None:
        """Test ColumnFunction via its constructor and a name found in the
        `operator` module.
        """
        tag = tests.ColumnTag("tag")
        expression = ColumnFunction(
            "__sub__",
            (ColumnExpression.reference(tag), ColumnExpression.literal(5)),
            dtype=int,
            supporting_engine_types=None,
        )
        self.assertEqual(expression.dtype, int)
        self.assertIsNone(expression.supporting_engine_types)
        self.assertEqual(expression.args, (ColumnExpression.reference(tag), ColumnExpression.literal(5)))
        self.assertEqual(expression.name, "__sub__")
        self.assertEqual(expression.columns_required, {tag})
        self.assertEqual(
            str(expression),
            "__sub__(tag, 5)",
        )
        # Iteration engine: the converted expression is called with a row
        # mapping and should compute ``row[tag] - 5``.
        iteration_engine = iteration.Engine()
        self.assertTrue(expression.is_supported_by(iteration_engine))
        evaluate = iteration_engine.convert_column_expression(expression)
        self.assertEqual(evaluate({tag: 3}), -2)
        self.assertEqual(evaluate({tag: 6}), 1)
        # SQL engine: the same expression should render as a subtraction.
        sql_engine = sql.Engine[sqlalchemy.sql.ColumnElement]()
        self.assertTrue(expression.is_supported_by(sql_engine))
        sql_expression = sql_engine.convert_column_expression(
            expression, {tag: sqlalchemy.schema.Column("tag")}
        )
        self.check_sql_str("tag - 5", sql_expression)

    def test_method(self) -> None:
        """Test ColumnFunction via ColumnExpression.method and a name found
        on the object itself.
        """
        tag = tests.ColumnTag("tag")
        expression = ColumnExpression.reference(tag).method("lower")
        self.assertIsNone(expression.dtype)
        self.assertIsNone(expression.supporting_engine_types)
        self.assertEqual(expression.args, (ColumnExpression.reference(tag),))
        self.assertEqual(expression.name, "lower")
        self.assertEqual(expression.columns_required, {tag})
        engine = iteration.Engine()
        self.assertTrue(expression.is_supported_by(engine))
        evaluate = engine.convert_column_expression(expression)
        self.assertEqual(evaluate({tag: "MiXeDcAsE"}), "mixedcase")

    def test_engine_function(self) -> None:
        """Test ColumnFunction via ColumnExpression.function and a name
        that references a callable held by the engine.
        """
        tag = tests.ColumnTag("tag")
        engine = iteration.Engine()
        engine.functions["test_function"] = lambda x: x**2
        expression = ColumnExpression.function(
            "test_function",
            ColumnExpression.reference(tag),
            supporting_engine_types={iteration.Engine},
            dtype=int,
        )
        self.assertEqual(expression.dtype, int)
        # The set passed to the factory is expected to be exposed as a tuple.
        self.assertEqual(expression.supporting_engine_types, (iteration.Engine,))
        self.assertEqual(expression.args, (ColumnExpression.reference(tag),))
        self.assertEqual(expression.name, "test_function")
        self.assertEqual(expression.columns_required, {tag})
        self.assertTrue(expression.is_supported_by(engine))
        # Only iteration.Engine was listed as supporting, so a SQL engine
        # must report the expression as unsupported.
        self.assertFalse(expression.is_supported_by(sql.Engine[sqlalchemy.sql.ColumnElement]()))
        evaluate = engine.convert_column_expression(expression)
        self.assertEqual(evaluate({tag: 3}), 9)
        self.assertEqual(evaluate({tag: 4}), 16)

    def test_operator_predicate_function(self) -> None:
        """Test PredicateFunction via ColumnExpression factory methods and
        names found in the `operator` module.
        """
        tag = tests.ColumnTag("tag")
        ref = ColumnExpression.reference(tag, dtype=int)
        zero = ColumnExpression.literal(0, dtype=int)
        # All later expectation lists follow this order: eq, ne, lt, le, gt,
        # ge.
        expressions = [ref.eq(zero), ref.ne(zero), ref.lt(zero), ref.le(zero), ref.gt(zero), ref.ge(zero)]
        self.assertEqual([x.dtype for x in expressions], [bool] * 6)
        self.assertEqual([x.supporting_engine_types for x in expressions], [None] * 6)
        self.assertEqual([x.args for x in expressions], [(ref, zero)] * 6)
        self.assertEqual(
            [x.name for x in expressions], ["__eq__", "__ne__", "__lt__", "__le__", "__gt__", "__ge__"]
        )
        self.assertEqual([x.columns_required for x in expressions], [{tag}] * 6)
        self.assertEqual(
            [str(x) for x in expressions],
            ["tag=0", "tag≠0", "tag<0", "tag≤0", "tag>0", "tag≥0"],
        )
        # Iteration engine: evaluate every comparison against zero for a
        # value equal to, greater than, and less than zero.
        iteration_engine = iteration.Engine()
        self.assertEqual([x.is_supported_by(iteration_engine) for x in expressions], [True] * 6)
        callables = [iteration_engine.convert_predicate(x) for x in expressions]
        self.assertEqual([c({tag: 0}) for c in callables], [True, False, False, True, False, True])
        self.assertEqual([c({tag: 1}) for c in callables], [False, True, False, False, True, True])
        self.assertEqual([c({tag: -1}) for c in callables], [False, True, True, True, False, False])
        # SQL engine: check the rendered form of every comparison.
        sql_bind = sqlalchemy.schema.Column("tag")
        sql_engine = sql.Engine[sqlalchemy.sql.ColumnElement]()
        self.assertEqual([x.is_supported_by(sql_engine) for x in expressions], [True] * 6)
        sql_expressions: list[sqlalchemy.sql.ColumnElement] = [
            sql_engine.convert_predicate(x, {tag: sql_bind}) for x in expressions
        ]
        self.assertEqual(
            [tests.to_sql_str(s) for s in sql_expressions],
            [
                "tag = 0",
                "tag != 0",
                "tag < 0",
                "tag <= 0",
                "tag > 0",
                "tag >= 0",
            ],
        )

    def test_column_expression_sequence(self) -> None:
        """Test ColumnExpressionSequence and ColumnInContainer."""
        a = tests.ColumnTag("a")
        seq = ColumnContainer.sequence(
            [ColumnExpression.reference(a, dtype=int), ColumnExpression.literal(4, dtype=int)], dtype=int
        )
        self.assertEqual(seq.dtype, int)
        self.assertEqual(
            list(seq.items),
            [ColumnExpression.reference(a, dtype=int), ColumnExpression.literal(4, dtype=int)],
        )
        self.assertEqual(seq.columns_required, {a})
        self.assertEqual(str(seq), "[a, 4]")
        iteration_engine = iteration.Engine()
        sql_engine = sql.Engine[sqlalchemy.sql.ColumnElement]()
        self.assertTrue(seq.is_supported_by(iteration_engine))
        self.assertTrue(seq.is_supported_by(sql_engine))
        # Membership test of column ``b`` in the sequence ``[a, 4]``.
        b = tests.ColumnTag("b")
        contains = seq.contains(ColumnExpression.reference(b))
        self.assertEqual(contains.dtype, bool)
        self.assertEqual(contains.columns_required, {a, b})
        self.assertEqual(str(contains), "b∈[a, 4]")
        self.assertTrue(contains.is_supported_by(iteration_engine))
        self.assertTrue(contains.is_supported_by(sql_engine))
        evaluate = iteration_engine.convert_predicate(contains)
        self.assertEqual(evaluate({a: 0, b: 0}), True)  # b matches the value of column a
        self.assertEqual(evaluate({a: 0, b: 3}), False)  # b matches neither item
        self.assertEqual(evaluate({a: 0, b: 4}), True)  # b matches the literal 4
        bind_a = sqlalchemy.schema.Column("a")
        bind_b = sqlalchemy.schema.Column("b")
        self.check_sql_str("b IN (a, 4)", sql_engine.convert_predicate(contains, {a: bind_a, b: bind_b}))

    def test_range_literal(self) -> None:
        """Test ColumnRangeLiteral and ColumnInContainer."""
        # The three ranges contain {3}, {3, 4, 5}, and {2, 5, 8} respectively.
        ranges = [
            ColumnContainer.range_literal(range(3, 4)),
            ColumnContainer.range_literal(range(3, 6)),
            ColumnContainer.range_literal(range(2, 11, 3)),
        ]
        self.assertEqual([r.dtype for r in ranges], [int] * 3)
        self.assertEqual([r.columns_required for r in ranges], [frozenset()] * 3)
        self.assertEqual([str(r) for r in ranges], ["[3:4:1]", "[3:6:1]", "[2:11:3]"])
        a = tests.ColumnTag("a")
        contains = [r.contains(ColumnExpression.reference(a)) for r in ranges]
        self.assertEqual([c.dtype for c in contains], [bool] * 3)
        self.assertEqual([c.columns_required for c in contains], [{a}] * 3)
        self.assertEqual([str(c) for c in contains], ["a∈[3:4:1]", "a∈[3:6:1]", "a∈[2:11:3]"])
        iteration_engine = iteration.Engine()
        self.assertEqual([c.is_supported_by(iteration_engine) for c in contains], [True] * 3)
        callables = [iteration_engine.convert_predicate(c) for c in contains]
        self.assertEqual([c({a: 3}) for c in callables], [True, True, False])
        self.assertEqual([c({a: 5}) for c in callables], [False, True, True])
        self.assertEqual([c({a: 8}) for c in callables], [False, False, True])
        sql_engine = sql.Engine[sqlalchemy.sql.ColumnElement]()
        self.assertEqual([c.is_supported_by(sql_engine) for c in contains], [True] * 3)
        bind_a = sqlalchemy.schema.Column("a")
        sql_expressions = [sql_engine.convert_predicate(c, {a: bind_a}) for c in contains]
        # Expected SQL: a single-element range becomes an equality, a unit-step
        # range becomes BETWEEN with an inclusive upper bound, and a stepped
        # range adds a modulo condition.
        self.assertEqual(
            [tests.to_sql_str(s) for s in sql_expressions],
            ["a = 3", "a BETWEEN 3 AND 5", "a BETWEEN 2 AND 10 AND a % 3 = 2"],
        )

    def test_logical_operators(self) -> None:
        """Test predicate logical operator expressions, Predicate.as_trivial,
        and flatten_logical_and.
        """
        a = tests.ColumnTag("a")
        b = tests.ColumnTag("b")
        t = Predicate.literal(True)
        f = Predicate.literal(False)
        x = ColumnExpression.reference(a).gt(ColumnExpression.literal(0))
        y = Predicate.reference(b)
        iteration_engine = iteration.Engine()
        sql_engine = sql.Engine[sqlalchemy.sql.ColumnElement]()
        # Check attributes and simple accessors for predicate literals and
        # references.  Some as_trivial overloads are not checked because MyPy
        # can tell they always return None and complains if we try to use that
        # return value (which also means MyPy takes care of that "test" for us,
        # even though coverage can't tell).
        self.assertEqual(t.columns_required, frozenset())
        self.assertEqual(f.columns_required, frozenset())
        self.assertEqual(y.columns_required, {b})
        self.assertEqual(str(t), "True")
        self.assertEqual(str(f), "False")
        self.assertEqual(str(y), "b")
        self.assertIs(t.as_trivial(), True)
        self.assertIs(f.as_trivial(), False)
        self.assertTrue(t.is_supported_by(iteration_engine))
        self.assertTrue(f.is_supported_by(iteration_engine))
        self.assertTrue(y.is_supported_by(iteration_engine))
        self.assertTrue(t.is_supported_by(sql_engine))
        self.assertTrue(f.is_supported_by(sql_engine))
        self.assertTrue(y.is_supported_by(sql_engine))
        # Test factory methods for logical operators, including simplification.
        self.assertIs(t.logical_not().as_trivial(), False)
        self.assertIs(f.logical_not().as_trivial(), True)
        self.assertIsNone(t.logical_and(x).as_trivial())
        self.assertIs(f.logical_and(x).as_trivial(), False)
        self.assertIs(t.logical_or(x).as_trivial(), True)
        self.assertIsNone(f.logical_or(x).as_trivial())
        self.assertEqual(Predicate.logical_and(), t)
        self.assertEqual(Predicate.logical_or(), f)
        self.assertEqual(Predicate.logical_and(x), x)
        self.assertEqual(Predicate.logical_or(x), x)
        # Test attributes and simple accessors for logical operators.
        not_x = x.logical_not()
        self.assertEqual(not_x.columns_required, {a})
        self.assertEqual(str(not_x), "not (a>0)")
        self.assertIsNone(not_x.as_trivial())
        self.assertTrue(not_x.is_supported_by(iteration_engine))
        self.assertTrue(not_x.is_supported_by(sql_engine))
        x_and_y = x.logical_and(y)
        self.assertEqual(x_and_y.columns_required, {a, b})
        self.assertEqual(str(x_and_y), "a>0 and b")
        self.assertIsNone(x_and_y.as_trivial())
        self.assertTrue(x_and_y.is_supported_by(iteration_engine))
        self.assertTrue(x_and_y.is_supported_by(sql_engine))
        x_or_y = x.logical_or(y)
        self.assertEqual(x_or_y.columns_required, {a, b})
        self.assertEqual(str(x_or_y), "a>0 or b")
        self.assertIsNone(x_or_y.as_trivial())
        self.assertTrue(x_or_y.is_supported_by(iteration_engine))
        self.assertTrue(x_or_y.is_supported_by(sql_engine))
        # Test iteration engine conversions.
        self.assertEqual(iteration_engine.convert_predicate(t)({}), True)
        self.assertEqual(iteration_engine.convert_predicate(f)({}), False)
        self.assertEqual(iteration_engine.convert_predicate(y)({b: True}), True)
        self.assertEqual(iteration_engine.convert_predicate(y)({b: False}), False)
        self.assertEqual(iteration_engine.convert_predicate(not_x)({a: 1}), False)
        self.assertEqual(iteration_engine.convert_predicate(not_x)({a: 0}), True)
        self.assertEqual(iteration_engine.convert_predicate(x_and_y)({a: 1, b: True}), True)
        self.assertEqual(iteration_engine.convert_predicate(x_and_y)({a: 1, b: False}), False)
        self.assertEqual(iteration_engine.convert_predicate(x_and_y)({a: 0, b: True}), False)
        self.assertEqual(iteration_engine.convert_predicate(x_and_y)({a: 0, b: False}), False)
        self.assertEqual(iteration_engine.convert_predicate(x_or_y)({a: 1, b: True}), True)
        self.assertEqual(iteration_engine.convert_predicate(x_or_y)({a: 1, b: False}), True)
        self.assertEqual(iteration_engine.convert_predicate(x_or_y)({a: 0, b: True}), True)
        self.assertEqual(iteration_engine.convert_predicate(x_or_y)({a: 0, b: False}), False)
        # Test SQL engine conversions.
        columns: Mapping[ColumnTag, sqlalchemy.sql.ColumnElement] = {
            a: sqlalchemy.schema.Column("a"),
            b: sqlalchemy.schema.Column("b"),
        }
        # An empty AND should render the same as the True literal.
        self.check_sql_str(
            "1",
            sql_engine.convert_predicate(t, columns),
            sql_engine.convert_predicate(LogicalAnd(()), columns),
        )
        # An empty OR should render the same as the False literal.
        self.check_sql_str(
            "0",
            sql_engine.convert_predicate(f, columns),
            sql_engine.convert_predicate(LogicalOr(()), columns),
        )
        # A single-operand AND or OR should render as just its operand.
        self.check_sql_str(
            "b",
            sql_engine.convert_predicate(y, columns),
            sql_engine.convert_predicate(LogicalAnd((y,)), columns),
            sql_engine.convert_predicate(LogicalOr((y,)), columns),
        )
        # Apparently SQLAlchemy does some simplifications of its own on NOT
        # operations.
        self.check_sql_str("a <= 0", sql_engine.convert_predicate(not_x, columns))
        self.check_sql_str("a > 0 AND b", sql_engine.convert_predicate(x_and_y, columns))
        self.check_sql_str("a > 0 OR b", sql_engine.convert_predicate(x_or_y, columns))
        # Test flatten_logical_and.  A result of `False` is checked by
        # identity so that an empty list (the result for a trivially-true
        # predicate) can never be mistaken for it.
        self.assertEqual(flatten_logical_and(t), [])
        self.assertIs(flatten_logical_and(f), False)
        self.assertEqual(flatten_logical_and(not_x), [not_x])
        self.assertEqual(flatten_logical_and(x_and_y), [x, y])
        self.assertEqual(flatten_logical_and(x_or_y), [x_or_y])
        self.assertEqual(flatten_logical_and(x_and_y.logical_and(t)), [x, y])
        self.assertIs(flatten_logical_and(x_and_y.logical_and(f)), False)
        c = tests.ColumnTag("c")
        z = Predicate.reference(c)
        self.assertEqual(flatten_logical_and(x_and_y.logical_and(z)), [x, y, z])
# Run the test cases in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()