Coverage for tests/test_join.py: 10%

104 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-02-22 02:59 -0800

1# This file is part of daf_relation. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24import unittest 

25 

26from lsst.daf.relation import ( 

27 BinaryOperationRelation, 

28 ColumnError, 

29 ColumnExpression, 

30 EngineError, 

31 Join, 

32 Predicate, 

33 SortTerm, 

34 iteration, 

35 tests, 

36) 

37 

38 

class JoinTestCase(tests.RelationTestCase):
    """Tests for the Join operation and relations based on it."""

    def setUp(self) -> None:
        """Build the column tags, engine, and leaf relations shared by all tests.

        ``leaf_1`` has columns ``{a, b}`` and ``leaf_2`` has columns
        ``{a, c}``, so ``a`` is the only column the two leaves share.  Each
        leaf is backed by a three-row `iteration.RowSequence`.
        """
        self.a = tests.ColumnTag("a")
        self.b = tests.ColumnTag("b")
        self.c = tests.ColumnTag("c")
        self.engine = iteration.Engine(name="preferred")
        self.leaf_1 = self.engine.make_leaf(
            {self.a, self.b},
            payload=iteration.RowSequence(
                [{self.a: 0, self.b: 5}, {self.a: 1, self.b: 10}, {self.a: 2, self.b: 25}]
            ),
            name="leaf_1",
        )
        self.leaf_2 = self.engine.make_leaf(
            {self.a, self.c},
            payload=iteration.RowSequence(
                # Every row is keyed by exactly this leaf's declared columns
                # (a, c); the last row previously used ``b`` by mistake.
                [{self.a: 0, self.c: 15}, {self.a: 2, self.c: 20}, {self.a: 3, self.c: 0}]
            ),
            name="leaf_2",
        )

    def test_attributes(self) -> None:
        """Check that all Relation and PartialJoin attributes have the
        expected values.
        """
        relation = self.leaf_1.join(self.leaf_2)
        # Plain ``assert isinstance`` (rather than ``assertIsInstance``) so
        # that static type checkers narrow the type for the lines below.
        assert isinstance(relation, BinaryOperationRelation)
        self.assertEqual(relation.columns, {self.a, self.b, self.c})
        self.assertEqual(relation.engine, self.engine)
        self.assertEqual(relation.min_rows, 0)
        # Both operands hold three rows; 9 is presumably their product.
        self.assertEqual(relation.max_rows, 9)
        self.assertFalse(relation.is_locked)
        operation = relation.operation
        assert isinstance(operation, Join)
        self.assertEqual(operation.min_columns, {self.a})
        self.assertEqual(operation.max_columns, {self.a})
        self.assertEqual(operation.common_columns, {self.a})
        self.assertEqual(operation.predicate, Predicate.literal(True))
        # The same join expressed as a unary operation with leaf_1 held fixed.
        partial = Join().partial(self.leaf_1)
        self.assertEqual(partial.columns_required, frozenset())
        self.assert_relations_equal(partial.fixed, self.leaf_1)
        self.assertFalse(partial.is_count_dependent)
        self.assertFalse(partial.is_order_dependent)
        self.assertFalse(partial.is_count_invariant)
        self.assertFalse(partial.is_empty_invariant)
        self.assertEqual(partial.applied_columns(self.leaf_2), {self.a, self.b, self.c})
        self.assertEqual(partial.applied_min_rows(self.leaf_2), 0)
        self.assertEqual(partial.applied_max_rows(self.leaf_2), 9)

    def test_apply_failures(self) -> None:
        """Test failure modes of constructing and applying Join."""
        # Mismatched engines.
        new_engine = iteration.Engine(name="downstream")
        with self.assertRaises(EngineError):
            Join().apply(self.leaf_1.transferred_to(new_engine), self.leaf_2)
        # Predicate requires nonexistent columns.
        predicate = ColumnExpression.reference(tests.ColumnTag("d")).lt(ColumnExpression.literal(0))
        with self.assertRaises(ColumnError):
            Join(predicate=predicate).apply(self.leaf_1, self.leaf_2)
        with self.assertRaises(ColumnError):
            Join(predicate=predicate).partial(self.leaf_1).apply(self.leaf_2)
        with self.assertRaises(ColumnError):
            Join(predicate=predicate).partial(self.leaf_2).apply(self.leaf_1)
        # Bounds on columns internally inconsistent.
        with self.assertRaises(ColumnError):
            Join(min_columns=frozenset({self.a, self.b}), max_columns=frozenset({self.a}))
        # Minimum columns not satisfied.
        join = Join(min_columns=frozenset({self.a, self.b}))
        with self.assertRaises(ColumnError):
            join.apply(self.leaf_1, self.leaf_2)
        with self.assertRaises(ColumnError):
            join.apply(self.leaf_2, self.leaf_1)
        with self.assertRaises(ColumnError):
            join.partial(self.leaf_2)
        with self.assertRaises(ColumnError):
            join.partial(self.leaf_1).apply(self.leaf_2)
        # Common columns not satisfied.
        join = Join(min_columns=frozenset({self.a, self.b}), max_columns=frozenset({self.a, self.b}))
        with self.assertRaises(ColumnError):
            join.apply(self.leaf_1, self.leaf_2)
        with self.assertRaises(ColumnError):
            join.apply(self.leaf_2, self.leaf_1)

    def test_apply_simplify(self) -> None:
        """Test Join.apply simplifications."""
        join_identity = self.engine.make_join_identity_relation()
        # Joining with the join identity on either side must hand back the
        # very same relation object, not merely an equal one.
        self.assertIs(self.leaf_1.join(join_identity), self.leaf_1)
        self.assertIs(join_identity.join(self.leaf_1), self.leaf_1)

    def test_backtracking_apply(self) -> None:
        """Test `PartialJoin.apply` logic that involves reordering operations
        in the existing tree to perform the new operation in a preferred
        engine.
        """
        new_engine = iteration.Engine(name="downstream")
        d = tests.ColumnTag("d")
        expression = ColumnExpression.function(
            "__add__", ColumnExpression.reference(self.a), ColumnExpression.literal(5)
        )
        sort_terms = [SortTerm(ColumnExpression.reference(self.a))]
        predicate = ColumnExpression.reference(self.b).gt(ColumnExpression.literal(0))
        # Apply a bunch of operations in a new engine that a PartialJoin should
        # commute with.
        target = (
            self.leaf_1.transferred_to(new_engine)
            .with_calculated_column(d, expression)
            .with_rows_satisfying(predicate)
            .with_only_columns({self.a, d})
            .sorted(sort_terms)
        )
        # Apply a new PartialJoin with backtracking and see that it appears
        # before the transfer to the new engine, with adjustments as needed.
        relation = target.join(self.leaf_2)
        self.assert_relations_equal(
            relation,
            (
                self.leaf_1.join(self.leaf_2)
                .transferred_to(new_engine)
                .with_calculated_column(d, expression)
                .with_rows_satisfying(predicate)
                # The projection is expected to widen to keep the column
                # (c) contributed by the newly joined relation.
                .with_only_columns({self.a, self.c, d})
                .sorted(sort_terms)
            ),
        )

    def test_no_backtracking(self) -> None:
        """Test `PartialJoin.apply` logic that handles differing engines
        without reordering operations in the existing tree, as well as failures
        in that reordering.
        """
        new_engine = iteration.Engine(name="downstream")
        # Construct a relation tree we can't reorder when inserting a Join,
        # because there is a locked Materialization in the way.
        target = self.leaf_1.transferred_to(new_engine).materialized("lock")
        # We can automatically transfer (back) to the new relation's engine.
        self.assert_relations_equal(
            target.join(self.leaf_2, transfer=True),
            target.transferred_to(self.engine).join(self.leaf_2),
        )
        # Can't backtrack through a Deduplication.
        target = self.leaf_1.transferred_to(new_engine).without_duplicates()
        with self.assertRaises(EngineError):
            target.join(self.leaf_2)
        # Can't backtrack through a Slice, because it's order/count dependent.
        target = self.leaf_1.transferred_to(new_engine)[:2]
        with self.assertRaises(EngineError):
            target.join(self.leaf_2)

    def test_common_columns(self) -> None:
        """Test Join.applied_common_columns logic."""
        # A leaf that shares two columns (a, b) with leaf_1.  It gets its own
        # name so it cannot be confused with the fixture's ``leaf_2``.
        leaf_3 = self.engine.make_leaf(
            {self.a, self.b, self.c},
            payload=iteration.RowSequence(
                [{self.a: 0, self.b: 2, self.c: 15}, {self.a: 2, self.b: 4, self.c: 20}]
            ),
            name="leaf_3",
        )
        # With no min or max columns, common_columns is just the intersection
        # of the columns of the operands.
        self.assertEqual(Join().applied_common_columns(self.leaf_1, leaf_3), {self.a, self.b})
        # Check that max_columns is enforced.
        self.assertEqual(
            Join(max_columns=frozenset({self.a})).applied_common_columns(self.leaf_1, leaf_3), {self.a}
        )
        # Check that min_columns is enforced.
        with self.assertRaises(ColumnError):
            Join(min_columns=frozenset({self.c})).applied_common_columns(self.leaf_1, leaf_3)
        # Repeat last two checks with min_columns == max_columns.
        self.assertEqual(
            Join(min_columns=frozenset({self.a}), max_columns=frozenset({self.a})).applied_common_columns(
                self.leaf_1, leaf_3
            ),
            {self.a},
        )
        # NOTE(review): this check goes through ``apply`` rather than
        # ``applied_common_columns``; presumably deliberate -- confirm before
        # changing it.
        with self.assertRaises(ColumnError):
            Join(min_columns=frozenset({self.c}), max_columns=frozenset({self.c})).apply(self.leaf_1, leaf_3)

    def test_str(self) -> None:
        """Test str(Join), str(PartialJoin), and
        str(BinaryOperationRelation[Join]).
        """
        relation = self.leaf_1.join(self.leaf_2)
        self.assertEqual(str(relation), "leaf_1 ⋈ leaf_2")
        partial = Join().partial(self.leaf_1)
        self.assertEqual(str(partial), "⋈[leaf_1]")
        # Nested operations get parentheses, unless they're joins or leaves.
        leaf_3 = self.engine.make_leaf(
            {self.a, self.b},
            payload=iteration.RowSequence([{self.a: 3, self.b: 4}]),
            name="leaf_3",
        )
        self.assertEqual(str(relation.join(leaf_3)), "leaf_1 ⋈ leaf_2 ⋈ leaf_3")
        self.assertEqual(str(self.leaf_1.chain(leaf_3).join(self.leaf_2)), "(leaf_1 ∪ leaf_3) ⋈ leaf_2")

233 

234 

# Script entry point: run this module's tests when it is executed directly.
# Coverage note carried over from the report: the branch from this condition
# into its body was never taken, because the condition was never true during
# the measured run.
if __name__ == "__main__":
    unittest.main()