Coverage for tests/test_pipeTools.py: 22%

98 statements  

« prev     ^ index     » next       coverage.py v7.3.0, created at 2023-08-31 09:39 +0000

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

"""Simple unit tests for the `lsst.pipe.base.pipeTools` helpers
(``isPipelineOrdered``, ``orderPipeline`` and data-cycle detection).
"""

24 

25import unittest 

26 

27import lsst.pipe.base.connectionTypes as cT 

28import lsst.utils.tests 

29from lsst.pipe.base import Pipeline, PipelineTask, PipelineTaskConfig, PipelineTaskConnections, pipeTools 

30 

31 

class ExamplePipelineTaskConnections(PipelineTaskConnections, dimensions=["Visit", "Detector"]):
    """Connections class for the example.

    Declares two inputs and two outputs, all with empty default dataset type
    names; the tests supply the real names through config overrides of
    ``connections.input1``/``input2``/``output1``/``output2``. ``input2`` and
    ``output2`` are optional: they are removed from the connection sets when
    their configured name is empty.
    """

    input1 = cT.Input(
        name="", dimensions=["Visit", "Detector"], storageClass="example", doc="Input for this task"
    )
    input2 = cT.Input(
        name="", dimensions=["Visit", "Detector"], storageClass="example", doc="Input for this task"
    )
    output1 = cT.Output(
        name="", dimensions=["Visit", "Detector"], storageClass="example", doc="Output for this task"
    )
    output2 = cT.Output(
        name="", dimensions=["Visit", "Detector"], storageClass="example", doc="Output for this task"
    )

    def __init__(self, *, config=None):
        super().__init__(config=config)
        # NOTE(review): despite the ``None`` default, ``config`` is
        # dereferenced unconditionally below, so callers must pass a config.
        # An empty configured name means the optional second connection is
        # unused; drop it so only the first input/output remains.
        if not config.connections.input2:
            self.inputs.remove("input2")
        if not config.connections.output2:
            self.outputs.remove("output2")

54 

55 

class ExamplePipelineTaskConfig(PipelineTaskConfig, pipelineConnections=ExamplePipelineTaskConnections):
    """Config for the example.

    Binds `ExamplePipelineTaskConnections` through ``pipelineConnections``;
    the tests set ``connections.input1``/``input2``/``output1``/``output2``
    via pipeline config overrides.
    """

58 

59 

60def _makeConfig(inputName, outputName, pipeline, label): 

61 """Apply config overrides as needed. 

62 

63 Factory method for config instances. 

64 

65 inputName and outputName can be either string or tuple of strings 

66 with two items max. 

67 """ 

68 if isinstance(inputName, tuple): 

69 pipeline.addConfigOverride(label, "connections.input1", inputName[0]) 

70 pipeline.addConfigOverride(label, "connections.input2", inputName[1] if len(inputName) > 1 else "") 

71 else: 

72 pipeline.addConfigOverride(label, "connections.input1", inputName) 

73 

74 if isinstance(outputName, tuple): 

75 pipeline.addConfigOverride(label, "connections.output1", outputName[0]) 

76 pipeline.addConfigOverride(label, "connections.output2", outputName[1] if len(outputName) > 1 else "") 

77 else: 

78 pipeline.addConfigOverride(label, "connections.output1", outputName) 

79 

80 

class ExamplePipelineTask(PipelineTask):
    """Minimal task for the pipeTools tests.

    It declares only its default name and config class (no ``run`` method);
    the tests use it solely to populate pipelines.
    """

    _DefaultName = "examplePipelineTask"
    ConfigClass = ExamplePipelineTaskConfig

86 

87 

def _makePipeline(tasks):
    """Build a Pipeline from a compact task description and expand it.

    Parameters
    ----------
    tasks : `list` [`tuple`]
        Each tuple in the list has 3 or 4 items:

        - input DatasetType name(s), string or tuple of strings
        - output DatasetType name(s), string or tuple of strings
        - task label, string
        - optional task class object; when omitted or `None`,
          `ExamplePipelineTask` is used

    Returns
    -------
    taskDefs : `list`
        The task definitions yielded by ``Pipeline.toExpandedPipeline()``,
        materialized as a list.
    """
    pipe = Pipeline("test pipeline")
    for inputs, outputs, label, *rest in tasks:
        # Honor the documented "can be None" contract instead of passing
        # None through to addTask.
        klass = rest[0] if rest and rest[0] is not None else ExamplePipelineTask
        pipe.addTask(klass, label)
        _makeConfig(inputs, outputs, pipe, label)
    return list(pipe.toExpandedPipeline())

113 

114 

class PipelineToolsTestCase(unittest.TestCase):
    """A test case for pipelineTools."""

    def _assertLabels(self, pipeline, expected):
        """Assert that the tasks in ``pipeline`` carry exactly the labels in
        ``expected``, in that order.
        """
        self.assertEqual([taskDef.label for taskDef in pipeline], list(expected))

    def testIsOrdered(self):
        """Tests for pipeTools.isPipelineOrdered method."""
        fanOut = ("A", ("B", "C"), "task1")
        fanIn = (("D", "E"), "F", "task4")
        specs = [
            [("A", "B", "task1"), ("B", "C", "task2")],
            [fanOut, ("B", "D", "task2"), ("C", "E", "task3"), fanIn],
            [fanOut, ("C", "E", "task2"), ("B", "D", "task3"), fanIn],
        ]
        for spec in specs:
            with self.subTest(spec=spec):
                pipeline = _makePipeline(spec)
                self.assertTrue(pipeTools.isPipelineOrdered(pipeline))
                # Each spec is a producer/consumer chain, so reversing it puts
                # a consumer ahead of its producer. Without this negative check
                # the test could not catch an always-True implementation.
                self.assertFalse(pipeTools.isPipelineOrdered(list(reversed(pipeline))))

    def testOrderPipeline(self):
        """Tests for pipeTools.orderPipeline method."""
        chainHead = ("A", "B", "task1")
        chainTail = ("B", "C", "task2")
        fanOut = ("A", ("B", "C"), "task1")
        left = ("B", "D", "task2")
        right = ("C", "E", "task3")
        fanIn = (("D", "E"), "F", "task4")
        twoTasks = ["task1", "task2"]
        fourTasks = ["task1", "task2", "task3", "task4"]
        # Whatever order the tasks are declared in, orderPipeline must return
        # the same ordering.
        cases = [
            ([chainHead, chainTail], twoTasks),
            ([chainTail, chainHead], twoTasks),
            ([fanOut, left, right, fanIn], fourTasks),
            ([fanOut, right, left, fanIn], fourTasks),
            ([fanIn, left, right, fanOut], fourTasks),
            ([fanIn, right, left, fanOut], fourTasks),
        ]
        for spec, expected in cases:
            with self.subTest(declared=[task[2] for task in spec]):
                pipeline = pipeTools.orderPipeline(_makePipeline(spec))
                self._assertLabels(pipeline, expected)

    def testOrderPipelineExceptions(self):
        """Tests for pipeTools.orderPipeline method exceptions."""
        # A cycle in the dataset graph must raise PipelineDataCycleError; it is
        # already reported while _makePipeline expands the pipeline.
        with self.assertRaises(pipeTools.PipelineDataCycleError):
            # task1 consumes its own output "A".
            _makePipeline([("A", ("A", "B"), "task1")])

        # A longer cycle A -> B -> C -> D -> A spread over four tasks.
        with self.assertRaises(pipeTools.PipelineDataCycleError):
            _makePipeline(
                [("A", "B", "task1"), ("B", "C", "task2"), ("C", "D", "task3"), ("D", "A", "task4")]
            )

204 

205 

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Run file leak tests.

    All checks are inherited from `lsst.utils.tests.MemoryTestCase`; this
    subclass adds none of its own.
    """

208 

209 

def setup_module(module):
    """Configure pytest.

    pytest module-level setup hook; initializes the LSST test utilities via
    ``lsst.utils.tests.init()``. The ``module`` argument is unused.
    """
    lsst.utils.tests.init()

213 

214 

if __name__ == "__main__":
    # Allow running this file directly; under pytest, setup_module performs
    # the same lsst.utils.tests.init() call.
    lsst.utils.tests.init()
    unittest.main()