Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22"""Simple unit test for PipelineTaskConnections. 

23""" 

24 

25import unittest 

26import pytest 

27 

28import lsst.utils.tests 

29import lsst.pipe.base as pipeBase 

30 

31 

class TestConnectionsClass(unittest.TestCase):
    """Unit tests for declaring, configuring and using PipelineTaskConnections."""

    # Dimensions shared by every connections class declared in these tests.
    # A class attribute replaces the former ``__init__`` override; it is still
    # reachable as ``self.test_dims`` from every test method.
    test_dims = ("a", "b")

    def testConnectionsDeclaration(self):
        """Test the declaration of a Connections class."""
        with self.assertRaises(TypeError):
            # Should raise: this Connections class is declared without
            # dimensions.
            class TestConnections(pipeBase.PipelineTaskConnections):
                pass

        with self.assertRaises(TypeError):
            # Should raise: the connection name uses a "{template}"
            # placeholder but no defaultTemplates are supplied.
            class TestConnectionsTemplate(pipeBase.PipelineTaskConnections, dimensions=self.test_dims):
                field = pipeBase.connectionTypes.Input(doc="Test", name="{template}test",
                                                       dimensions=self.test_dims,
                                                       storageClass='Dummy')

        # This declaration should raise no exceptions.
        class TestConnectionsWithDimensions(pipeBase.PipelineTaskConnections, dimensions=self.test_dims):
            pass

        # This should not raise: the "{template}" placeholder has a default.
        class TestConnectionsWithTemplate(pipeBase.PipelineTaskConnections, dimensions=self.test_dims,
                                          defaultTemplates={"template": "working"}):
            field = pipeBase.connectionTypes.Input(doc="Test", name="{template}test",
                                                   dimensions=self.test_dims,
                                                   storageClass='Dummy')

    def testConnectionsOnConnectionsClass(self):
        """Test that connection attribute names are grouped by connection type.

        Each grouping attribute on the class (``initInputs``, ``initOutputs``,
        ``inputs``, ``prerequisiteInputs``, ``outputs``) is expected to be a
        frozenset of the attribute names declared with that connection type.
        """
        # Every connection gets a distinct dataset name, so the test cannot
        # pass only because of accidental name collisions.
        class TestConnections(pipeBase.PipelineTaskConnections, dimensions=self.test_dims):
            initInput1 = pipeBase.connectionTypes.InitInput(doc="Test Init input", name="init_input",
                                                            storageClass='Dummy')
            initInput2 = pipeBase.connectionTypes.InitInput(doc="Test Init input", name="init_input2",
                                                            storageClass='Dummy')

            initOutput1 = pipeBase.connectionTypes.InitOutput(doc="Test Init output", name="init_output1",
                                                              storageClass='Dummy')
            initOutput2 = pipeBase.connectionTypes.InitOutput(doc="Test Init output", name="init_output2",
                                                              storageClass='Dummy')

            input1 = pipeBase.connectionTypes.Input(doc="test input", name="input1",
                                                    dimensions=self.test_dims,
                                                    storageClass='Dummy')
            input2 = pipeBase.connectionTypes.Input(doc="test input", name="input2",
                                                    dimensions=self.test_dims,
                                                    storageClass='Dummy')

            prereqInputs1 = pipeBase.connectionTypes.PrerequisiteInput(doc="test input", name="pre_input1",
                                                                       dimensions=self.test_dims,
                                                                       storageClass='Dummy')
            prereqInputs2 = pipeBase.connectionTypes.PrerequisiteInput(doc="test input", name="pre_input2",
                                                                       dimensions=self.test_dims,
                                                                       storageClass='Dummy')

            output1 = pipeBase.connectionTypes.Output(doc="test output", name="output1",
                                                      dimensions=self.test_dims,
                                                      storageClass='Dummy')
            output2 = pipeBase.connectionTypes.Output(doc="test output", name="output2",
                                                      dimensions=self.test_dims,
                                                      storageClass='Dummy')

        self.assertEqual(TestConnections.initInputs, frozenset(["initInput1", "initInput2"]))
        self.assertEqual(TestConnections.initOutputs, frozenset(["initOutput1", "initOutput2"]))
        self.assertEqual(TestConnections.inputs, frozenset(["input1", "input2"]))
        self.assertEqual(TestConnections.prerequisiteInputs, frozenset(["prereqInputs1", "prereqInputs2"]))
        self.assertEqual(TestConnections.outputs, frozenset(["output1", "output2"]))

    def buildTestConnections(self):
        """Build a templated connections instance configured through a config.

        The config sets the "template" value to ``"fromConfig"`` and renames
        ``field2`` to ``"field2FromConfig"``.

        Returns
        -------
        connections : `pipeBase.PipelineTaskConnections`
            An instance of a locally declared connections class.
        """
        class TestConnectionsWithTemplate(pipeBase.PipelineTaskConnections, dimensions=self.test_dims,
                                          defaultTemplates={"template": "working"}):
            field = pipeBase.connectionTypes.Input(doc="Test", name="{template}test",
                                                   dimensions=self.test_dims,
                                                   storageClass='Dummy')
            field2 = pipeBase.connectionTypes.Output(doc="Test", name="field2Type",
                                                     dimensions=self.test_dims,
                                                     storageClass='Dummy',
                                                     multiple=True)

            def adjustQuantum(self, datasetRefMap):
                # NOTE(review): ``field`` is not declared ``multiple``, so the
                # refs built for it hold a single value rather than a list
                # (testBuildDatasetRefs expects the bare string "a"). With
                # that mock value, len() is the string length, 1, which is
                # what makes testAdjustQuantum raise.
                if len(datasetRefMap.field) < 2:
                    raise ValueError("This connection should have more than one entry")
                # Hand the (unmodified) map back to the caller rather than
                # implicitly returning None — presumably matches the base
                # class contract; confirm against pipe_base.
                return datasetRefMap

        class TestConfig(pipeBase.PipelineTaskConfig, pipelineConnections=TestConnectionsWithTemplate):
            pass

        config = TestConfig()
        config.connections.template = "fromConfig"
        config.connections.field2 = "field2FromConfig"

        return TestConnectionsWithTemplate(config=config)

    def testConnectionsInstantiation(self):
        """Test that config overrides are applied to connection names."""
        connections = self.buildTestConnections()
        self.assertEqual(connections.field.name, "fromConfigtest")
        self.assertEqual(connections.field2.name, "field2FromConfig")

    def testBuildDatasetRefs(self):
        """Test mapping a mock quantum's refs onto connection attributes."""
        connections = self.buildTestConnections()

        mockQuantum = pipeBase.Struct(predictedInputs={"fromConfigtest": ["a"]},
                                      outputs={"field2FromConfig": ["b", "c"]})

        inputRefs, outputRefs = connections.buildDatasetRefs(mockQuantum)
        # ``field`` is a single (non-multiple) input: it unwraps to one ref.
        self.assertEqual(inputRefs.field, "a")
        # ``field2`` is declared multiple=True: it keeps the full list.
        self.assertEqual(outputRefs.field2, ["b", "c"])

    def testAdjustQuantum(self):
        """Test that a custom adjustQuantum can reject a quantum."""
        connections = self.buildTestConnections()
        mockQuantum = pipeBase.Struct(predictedInputs={"fromConfigtest": ["a"]},
                                      outputs={"field2FromConfig": ["b", "c"]})
        inputRefs, outputRefs = connections.buildDatasetRefs(mockQuantum)
        with self.assertRaises(ValueError):
            connections.adjustQuantum(inputRefs)

151 

152 

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Run the checks inherited unchanged from ``lsst.utils.tests.MemoryTestCase``."""

155 

156 

def setup_module(module):
    """Initialize the LSST test utilities once per module (pytest hook).

    Parameters
    ----------
    module : `module`
        The test module being set up; not used by this hook.
    """
    lsst.utils.tests.init()

159 

160 

if __name__ == "__main__":
    # When run as a script (not collected by pytest), setup_module is never
    # invoked, so initialize the LSST test utilities before unittest runs.
    lsst.utils.tests.init()
    unittest.main()