Coverage for tests/test_connections.py: 21%

Shortcuts on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

85 statements  

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22"""Simple unit test for PipelineTaskConnections. 

23""" 

24 

25import unittest 

26import pytest 

27 

28import lsst.utils.tests 

29import lsst.pipe.base as pipeBase 

30 

31 

class TestConnectionsClass(unittest.TestCase):
    """Tests for declaring, configuring, and using `PipelineTaskConnections`
    subclasses.
    """

    # Dimensions shared by every connections class declared in these tests.
    # Stored as an immutable class attribute instead of overriding
    # TestCase.__init__; ``self.test_dims`` resolves exactly as before.
    test_dims = ("a", "b")

    def testConnectionsDeclaration(self):
        """Tests the declaration of a Connections Class
        """
        with pytest.raises(TypeError):
            # This should raise because this Connections class is created with
            # no dimensions
            class TestConnections(pipeBase.PipelineTaskConnections):
                pass

        with pytest.raises(TypeError):
            # This should raise because this Connections class is created
            # without template defaults
            class TestConnectionsTemplate(pipeBase.PipelineTaskConnections, dimensions=self.test_dims):
                field = pipeBase.connectionTypes.Input(doc="Test", name="{template}test",
                                                       dimensions=self.test_dims,
                                                       storageClass='Dummy')

        # This declaration should raise no exceptions
        class TestConnectionsWithDimensions(pipeBase.PipelineTaskConnections, dimensions=self.test_dims):
            pass

        # This should not raise
        class TestConnectionsWithTemplate(pipeBase.PipelineTaskConnections, dimensions=self.test_dims,
                                          defaultTemplates={"template": "working"}):
            field = pipeBase.connectionTypes.Input(doc="Test", name="{template}test",
                                                   dimensions=self.test_dims,
                                                   storageClass='Dummy')

    def testConnectionsOnConnectionsClass(self):
        """Tests that each connection attribute is recorded, by attribute
        name, in the matching category set on the connections class.
        """
        class TestConnections(pipeBase.PipelineTaskConnections, dimensions=self.test_dims):
            # Every connection gets a distinct dataset type name so that the
            # fixture does not accidentally declare duplicate datasets.
            initInput1 = pipeBase.connectionTypes.InitInput(doc="Test Init input", name="init_input1",
                                                            storageClass='Dummy')
            initInput2 = pipeBase.connectionTypes.InitInput(doc="Test Init input", name="init_input2",
                                                            storageClass='Dummy')

            initOutput1 = pipeBase.connectionTypes.InitOutput(doc="Test Init output", name="init_output1",
                                                              storageClass='Dummy')
            initOutput2 = pipeBase.connectionTypes.InitOutput(doc="Test Init output", name="init_output2",
                                                              storageClass='Dummy')

            input1 = pipeBase.connectionTypes.Input(doc="test input", name="input1",
                                                    dimensions=self.test_dims,
                                                    storageClass='Dummy')
            input2 = pipeBase.connectionTypes.Input(doc="test input", name="input2",
                                                    dimensions=self.test_dims,
                                                    storageClass='Dummy')

            prereqInputs1 = pipeBase.connectionTypes.PrerequisiteInput(doc="test input", name="pre_input1",
                                                                       dimensions=self.test_dims,
                                                                       storageClass='Dummy')
            prereqInputs2 = pipeBase.connectionTypes.PrerequisiteInput(doc="test input", name="pre_input2",
                                                                       dimensions=self.test_dims,
                                                                       storageClass='Dummy')

            output1 = pipeBase.connectionTypes.Output(doc="test output", name="output1",
                                                      dimensions=self.test_dims,
                                                      storageClass='Dummy')
            output2 = pipeBase.connectionTypes.Output(doc="test output", name="output2",
                                                      dimensions=self.test_dims,
                                                      storageClass='Dummy')

        self.assertEqual(TestConnections.initInputs, frozenset(["initInput1", "initInput2"]))
        self.assertEqual(TestConnections.initOutputs, frozenset(["initOutput1", "initOutput2"]))
        self.assertEqual(TestConnections.inputs, frozenset(["input1", "input2"]))
        self.assertEqual(TestConnections.prerequisiteInputs, frozenset(["prereqInputs1", "prereqInputs2"]))
        self.assertEqual(TestConnections.outputs, frozenset(["output1", "output2"]))

    def buildTestConnections(self):
        """Build a connections instance whose names are overridden by config.

        The config sets the ``template`` value to "fromConfig" and renames
        ``field2`` to "field2FromConfig".

        Returns
        -------
        connections : `pipeBase.PipelineTaskConnections`
            Instance of a locally declared connections class, built from that
            config.
        """
        class TestConnectionsWithTemplate(pipeBase.PipelineTaskConnections, dimensions=self.test_dims,
                                          defaultTemplates={"template": "working"}):
            field = pipeBase.connectionTypes.Input(doc="Test", name="{template}test",
                                                   dimensions=self.test_dims,
                                                   storageClass='Dummy')
            field2 = pipeBase.connectionTypes.Output(doc="Test", name="field2Type",
                                                     dimensions=self.test_dims,
                                                     storageClass='Dummy',
                                                     multiple=True)

            def adjustQuantum(self, datasetRefMap):
                # Deliberately strict so testAdjustQuantum can observe that
                # the override is invoked.
                if len(datasetRefMap.field) < 2:
                    raise ValueError("This connection should have more than one entry")

        class TestConfig(pipeBase.PipelineTaskConfig, pipelineConnections=TestConnectionsWithTemplate):
            pass

        config = TestConfig()
        config.connections.template = "fromConfig"
        config.connections.field2 = "field2FromConfig"

        connections = TestConnectionsWithTemplate(config=config)
        return connections

    def testConnectionsInstantiation(self):
        """Tests that config overrides of templates and names are applied.
        """
        connections = self.buildTestConnections()
        self.assertEqual(connections.field.name, "fromConfigtest")
        self.assertEqual(connections.field2.name, "field2FromConfig")

    def testBuildDatasetRefs(self):
        """Tests mapping a quantum's dataset refs back onto connection names.
        """
        connections = self.buildTestConnections()

        mockQuantum = pipeBase.Struct(inputs={"fromConfigtest": ["a"]},
                                      outputs={"field2FromConfig": ["b", "c"]})

        inputRefs, outputRefs = connections.buildDatasetRefs(mockQuantum)
        # ``field`` is a single-valued connection and is expected as a bare
        # ref; ``field2`` is declared ``multiple=True`` and keeps the list.
        self.assertEqual(inputRefs.field, "a")
        self.assertEqual(outputRefs.field2, ["b", "c"])

    def testAdjustQuantum(self):
        """Tests that the user-defined adjustQuantum override is called.
        """
        connections = self.buildTestConnections()
        mockQuantum = pipeBase.Struct(inputs={"fromConfigtest": ["a"]},
                                      outputs={"field2FromConfig": ["b", "c"]})
        inputRefs, outputRefs = connections.buildDatasetRefs(mockQuantum)
        with self.assertRaises(ValueError):
            connections.adjustQuantum(inputRefs)

    def testDimensionCheck(self):
        """Tests that malformed dimensions raise TypeError, both on a
        connections class and on an individual connection.
        """
        # Declare these on PipelineTaskConnections, whose class creation
        # accepts ``dimensions=``. This makes the TypeError come from
        # dimension validation rather than from an unexpected class keyword.
        with self.assertRaises(TypeError):
            # A bare string (``("a")`` is just ``"a"``, not a tuple).
            class TestConnectionsWithBrokenDimensionsStr(pipeBase.PipelineTaskConnections,
                                                         dimensions="a"):
                pass

        with self.assertRaises(TypeError):
            # A non-iterable value.
            class TestConnectionsWithBrokenDimensionsIter(pipeBase.PipelineTaskConnections,
                                                          dimensions=2):
                pass

        # Use the real ``doc`` keyword. A misspelled keyword would raise
        # TypeError by itself and mask the dimension check under test.
        with self.assertRaises(TypeError):
            pipeBase.connectionTypes.Output(doc="mock doc", dimensions="a", name="output",
                                            storageClass="mock")

        with self.assertRaises(TypeError):
            pipeBase.connectionTypes.Output(doc="mock doc", dimensions=1, name="output",
                                            storageClass="mock")

170 

171 

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Hook `lsst.utils.tests.MemoryTestCase` into this test module.

    All behavior is inherited; nothing is overridden here.
    """

174 

175 

def setup_module(module):
    """Module-level setup hook (pytest convention).

    Initializes ``lsst.utils.tests`` before any test in this module runs.

    Parameters
    ----------
    module : `types.ModuleType`
        The test module being set up; not used.
    """
    lsst.utils.tests.init()

178 

179 

if __name__ == "__main__":
    # Direct execution (not via pytest): initialize lsst.utils.tests first,
    # mirroring setup_module, then hand control to unittest's runner.
    lsst.utils.tests.init()
    unittest.main()