Coverage for tests/test_connections.py: 18%

85 statements  

coverage.py v7.2.3, created at 2023-04-19 04:01 -0700

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22"""Simple unit test for PipelineTaskConnections. 

23""" 

24 

25import unittest 

26 

27import lsst.pipe.base as pipeBase 

28import lsst.utils.tests 

29import pytest 

30 

31 

class TestConnectionsClass(unittest.TestCase):
    """Tests for declaring, configuring, and using PipelineTaskConnections."""

    def setUp(self):
        # Dimensions shared by every connection declared in these tests.
        # Set in setUp (not by overriding TestCase.__init__) so that each
        # test gets fresh fixture state in the standard unittest way.
        self.test_dims = ("a", "b")

    def testConnectionsDeclaration(self):
        """Test which connections-class declarations are accepted or rejected."""
        with pytest.raises(TypeError):
            # This should raise because this Connections class is created with
            # no dimensions.
            class TestConnections(pipeBase.PipelineTaskConnections):
                pass

        with pytest.raises(TypeError):
            # This should raise because this Connections class uses a
            # "{template}" placeholder without providing template defaults.
            class TestConnectionsTemplate(pipeBase.PipelineTaskConnections, dimensions=self.test_dims):
                field = pipeBase.connectionTypes.Input(
                    doc="Test", name="{template}test", dimensions=self.test_dims, storageClass="Dummy"
                )

        # This declaration should raise no exceptions.
        class TestConnectionsWithDimensions(pipeBase.PipelineTaskConnections, dimensions=self.test_dims):
            pass

        # This should not raise: the template placeholder has a default.
        class TestConnectionsWithTemplate(
            pipeBase.PipelineTaskConnections,
            dimensions=self.test_dims,
            defaultTemplates={"template": "working"},
        ):
            field = pipeBase.connectionTypes.Input(
                doc="Test", name="{template}test", dimensions=self.test_dims, storageClass="Dummy"
            )

    def testConnectionsOnConnectionsClass(self):
        """Test that each connection attribute lands in the matching category set."""

        class TestConnections(pipeBase.PipelineTaskConnections, dimensions=self.test_dims):
            initInput1 = pipeBase.connectionTypes.InitInput(
                doc="Test Init input", name="init_input", storageClass="Dummy"
            )
            initInput2 = pipeBase.connectionTypes.InitInput(
                doc="Test Init input", name="init_input2", storageClass="Dummy"
            )

            initOutput1 = pipeBase.connectionTypes.InitOutput(
                doc="Test Init output", name="init_output1", storageClass="Dummy"
            )
            initOutput2 = pipeBase.connectionTypes.InitOutput(
                doc="Test Init output", name="init_output2", storageClass="Dummy"
            )

            # Each connection gets its own dataset name; the original
            # copy-pasted duplicates ("input2" twice, "output" twice).
            input1 = pipeBase.connectionTypes.Input(
                doc="test input", name="input1", dimensions=self.test_dims, storageClass="Dummy"
            )
            input2 = pipeBase.connectionTypes.Input(
                doc="test input", name="input2", dimensions=self.test_dims, storageClass="Dummy"
            )

            prereqInputs1 = pipeBase.connectionTypes.PrerequisiteInput(
                doc="test input", name="pre_input1", dimensions=self.test_dims, storageClass="Dummy"
            )
            prereqInputs2 = pipeBase.connectionTypes.PrerequisiteInput(
                doc="test input", name="pre_input2", dimensions=self.test_dims, storageClass="Dummy"
            )

            output1 = pipeBase.connectionTypes.Output(
                doc="test output", name="output1", dimensions=self.test_dims, storageClass="Dummy"
            )
            output2 = pipeBase.connectionTypes.Output(
                doc="test output", name="output2", dimensions=self.test_dims, storageClass="Dummy"
            )

        # The category sets hold attribute names, not dataset names.
        self.assertEqual(TestConnections.initInputs, frozenset(["initInput1", "initInput2"]))
        self.assertEqual(TestConnections.initOutputs, frozenset(["initOutput1", "initOutput2"]))
        self.assertEqual(TestConnections.inputs, frozenset(["input1", "input2"]))
        self.assertEqual(TestConnections.prerequisiteInputs, frozenset(["prereqInputs1", "prereqInputs2"]))
        self.assertEqual(TestConnections.outputs, frozenset(["output1", "output2"]))

    def buildTestConnections(self):
        """Build a templated connections instance configured through a config.

        The config overrides the ``template`` value to ``"fromConfig"`` and
        renames ``field2`` to ``"field2FromConfig"``.

        Returns
        -------
        connections : `TestConnectionsWithTemplate`
            Connections instance constructed from the overridden config.
        """

        class TestConnectionsWithTemplate(
            pipeBase.PipelineTaskConnections,
            dimensions=self.test_dims,
            defaultTemplates={"template": "working"},
        ):
            field = pipeBase.connectionTypes.Input(
                doc="Test", name="{template}test", dimensions=self.test_dims, storageClass="Dummy"
            )
            field2 = pipeBase.connectionTypes.Output(
                doc="Test", name="field2Type", dimensions=self.test_dims, storageClass="Dummy", multiple=True
            )

            def adjustQuantum(self, datasetRefMap):
                # Test-only override exercised by testAdjustQuantum: reject a
                # ref map whose ``field`` entry has fewer than two elements.
                if len(datasetRefMap.field) < 2:
                    raise ValueError("This connection should have more than one entry")

        class TestConfig(pipeBase.PipelineTaskConfig, pipelineConnections=TestConnectionsWithTemplate):
            pass

        config = TestConfig()
        config.connections.template = "fromConfig"
        config.connections.field2 = "field2FromConfig"

        connections = TestConnectionsWithTemplate(config=config)
        return connections

    def testConnectionsInstantiation(self):
        """Test that config overrides of templates and dataset names are applied."""
        connections = self.buildTestConnections()
        self.assertEqual(connections.field.name, "fromConfigtest")
        self.assertEqual(connections.field2.name, "field2FromConfig")

    def testBuildDatasetRefs(self):
        """Test that quantum datasets are mapped onto connection attributes.

        The single-valued input yields one ref; the ``multiple=True`` output
        yields the full list.
        """
        connections = self.buildTestConnections()

        mockQuantum = pipeBase.Struct(
            inputs={"fromConfigtest": ["a"]}, outputs={"field2FromConfig": ["b", "c"]}
        )

        inputRefs, outputRefs = connections.buildDatasetRefs(mockQuantum)
        self.assertEqual(inputRefs.field, "a")
        self.assertEqual(outputRefs.field2, ["b", "c"])

    def testAdjustQuantum(self):
        """Test that the overridden adjustQuantum rejects a too-short input."""
        connections = self.buildTestConnections()
        mockQuantum = pipeBase.Struct(
            inputs={"fromConfigtest": ["a"]}, outputs={"field2FromConfig": ["b", "c"]}
        )
        inputRefs, outputRefs = connections.buildDatasetRefs(mockQuantum)
        with self.assertRaises(ValueError):
            connections.adjustQuantum(inputRefs)

    def testDimensionCheck(self):
        """Test that invalid ``dimensions`` values are rejected with TypeError.

        Both connections classes and individual dimensioned connections must
        reject a bare string (e.g. ``"a"`` instead of ``("a",)``) and a
        non-iterable value.
        """
        # These checks must target PipelineTaskConnections and spell the
        # ``doc`` keyword correctly. Otherwise a TypeError is raised for an
        # unrelated reason (unknown class keyword / unexpected argument) and
        # the assertion passes without exercising the dimensions validation.
        with self.assertRaises(TypeError):

            class TestConnectionsWithBrokenDimensionsStr(pipeBase.PipelineTaskConnections, dimensions="a"):
                pass

        with self.assertRaises(TypeError):

            class TestConnectionsWithBrokenDimensionsIter(pipeBase.PipelineTaskConnections, dimensions=2):
                pass

        with self.assertRaises(TypeError):
            pipeBase.connectionTypes.Output(
                doc="mock doc", dimensions="a", name="output", storageClass="mock"
            )

        with self.assertRaises(TypeError):
            pipeBase.connectionTypes.Output(
                doc="mock doc", dimensions=1, name="output", storageClass="mock"
            )

183 

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Run the inherited ``lsst.utils.tests.MemoryTestCase`` checks for this module.

    No tests are added here; presumably the base class supplies resource or
    leak checks (confirm against lsst.utils).
    """

186 

187 

def setup_module(module):
    """Module-setup hook: initialise the LSST test utilities.

    Parameters
    ----------
    module : `types.ModuleType`
        The test module being set up (unused; required by the hook signature).
    """
    lsst.utils.tests.init()

190 

191 

if __name__ == "__main__":
    # Script entry point: initialise the LSST test utilities, then hand
    # control to the unittest command-line runner.
    lsst.utils.tests.init()
    unittest.main()