Coverage for tests/test_pipeTools.py: 22%

98 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-11-17 10:52 +0000

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27 

28"""Simple unit test for Pipeline. 

29""" 

30 

31import unittest 

32 

33import lsst.pipe.base.connectionTypes as cT 

34import lsst.utils.tests 

35from lsst.pipe.base import Pipeline, PipelineTask, PipelineTaskConfig, PipelineTaskConnections, pipeTools 

36 

37 

class ExamplePipelineTaskConnections(PipelineTaskConnections, dimensions=["Visit", "Detector"]):
    """Connections of the example task: two inputs and two outputs.

    ``input2`` and ``output2`` are optional. Each one is removed in
    ``__init__`` when the config gives it a falsy (e.g. empty) name.
    """

    input1 = cT.Input(
        name="",
        dimensions=["Visit", "Detector"],
        storageClass="example",
        doc="Input for this task",
    )
    input2 = cT.Input(
        name="",
        dimensions=["Visit", "Detector"],
        storageClass="example",
        doc="Input for this task",
    )
    output1 = cT.Output(
        name="",
        dimensions=["Visit", "Detector"],
        storageClass="example",
        doc="Output for this task",
    )
    output2 = cT.Output(
        name="",
        dimensions=["Visit", "Detector"],
        storageClass="example",
        doc="Output for this task",
    )

    def __init__(self, *, config=None):
        super().__init__(config=config)
        # NOTE: despite the ``None`` default, ``config`` is dereferenced
        # unconditionally below, so a real config object is required.
        names = config.connections
        if not names.input2:
            self.inputs.remove("input2")
        if not names.output2:
            self.outputs.remove("output2")

60 

61 

class ExamplePipelineTaskConfig(PipelineTaskConfig, pipelineConnections=ExamplePipelineTaskConnections):
    """Config class for the example task.

    Declares no fields of its own; it only binds
    `ExamplePipelineTaskConnections` as its connections class.
    """

64 

65 

def _makeConfig(inputName, outputName, pipeline, label):
    """Register connection-name overrides for one task on ``pipeline``.

    Parameters
    ----------
    inputName, outputName : `str` or `tuple` of `str`
        Dataset type name(s). A plain value overrides only the first
        connection (``input1`` / ``output1``). A tuple overrides both:
        its first item goes to the first connection and its second item,
        or ``""`` when the tuple has a single item, goes to the second.
    pipeline
        Object providing ``addConfigOverride(label, key, value)``.
    label : `str`
        Label of the task whose config is overridden.
    """
    # Inputs are handled before outputs, matching the order in which the
    # overrides are registered on the pipeline.
    targets = (
        ("connections.input1", "connections.input2", inputName),
        ("connections.output1", "connections.output2", outputName),
    )
    for firstKey, secondKey, names in targets:
        if not isinstance(names, tuple):
            pipeline.addConfigOverride(label, firstKey, names)
            continue
        first = names[0]
        second = names[1] if len(names) > 1 else ""
        pipeline.addConfigOverride(label, firstKey, first)
        pipeline.addConfigOverride(label, secondKey, second)

85 

86 

class ExamplePipelineTask(PipelineTask):
    """Minimal `PipelineTask` subclass used as the default task in these tests."""

    # Config class bound to this task.
    ConfigClass = ExamplePipelineTaskConfig
    # Default name of the task.
    _DefaultName = "examplePipelineTask"

92 

93 

def _makePipeline(tasks):
    """Build a test pipeline and return its expanded contents as a list.

    Parameters
    ----------
    tasks : list of tuples
        Each tuple in the list has 3 or 4 items:
        - input DatasetType name(s), string or tuple of strings
        - output DatasetType name(s), string or tuple of strings
        - task label, string
        - optional task class object, can be None; when missing or None
          `ExamplePipelineTask` is used

    Returns
    -------
    list
        The items produced by ``Pipeline.toExpandedPipeline()`` for the
        assembled pipeline, materialized as a list.

    Notes
    -----
    Any exception raised while expanding the pipeline propagates to the
    caller; the tests in this file rely on that to detect
    ``pipeTools.PipelineDataCycleError`` for cyclic inputs.
    """
    pipe = Pipeline("test pipeline")
    for inputs, outputs, label, *rest in tasks:
        # The task class is optional and documented as "can be None":
        # fall back to the example task in both cases instead of handing
        # None to addTask.
        klass = rest[0] if rest and rest[0] is not None else ExamplePipelineTask
        pipe.addTask(klass, label)
        _makeConfig(inputs, outputs, pipe, label)
    return list(pipe.toExpandedPipeline())

119 

120 

class PipelineToolsTestCase(unittest.TestCase):
    """Tests for the pipeline ordering helpers in ``pipeTools``."""

    def setUp(self):
        pass

    def tearDown(self):
        pass

    def _checkLabels(self, pipeline, expected):
        """Assert that ``pipeline`` holds exactly the ``expected`` task labels, in order."""
        self.assertEqual(len(pipeline), len(expected))
        for position, label in enumerate(expected):
            self.assertEqual(pipeline[position].label, label)

    def testIsOrdered(self):
        """Tests for pipeTools.isPipelineOrdered method"""
        layouts = (
            [("A", "B", "task1"), ("B", "C", "task2")],
            [
                ("A", ("B", "C"), "task1"),
                ("B", "D", "task2"),
                ("C", "E", "task3"),
                (("D", "E"), "F", "task4"),
            ],
            [
                ("A", ("B", "C"), "task1"),
                ("C", "E", "task2"),
                ("B", "D", "task3"),
                (("D", "E"), "F", "task4"),
            ],
        )
        for layout in layouts:
            pipeline = _makePipeline(layout)
            self.assertTrue(pipeTools.isPipelineOrdered(pipeline))

    def testOrderPipeline(self):
        """Tests for pipeTools.orderPipeline method"""
        twoTasks = ["task1", "task2"]
        fourTasks = ["task1", "task2", "task3", "task4"]
        # Each case pairs a task layout with the label order expected from
        # orderPipeline.
        cases = (
            ([("A", "B", "task1"), ("B", "C", "task2")], twoTasks),
            ([("B", "C", "task2"), ("A", "B", "task1")], twoTasks),
            (
                [
                    ("A", ("B", "C"), "task1"),
                    ("B", "D", "task2"),
                    ("C", "E", "task3"),
                    (("D", "E"), "F", "task4"),
                ],
                fourTasks,
            ),
            (
                [
                    ("A", ("B", "C"), "task1"),
                    ("C", "E", "task3"),
                    ("B", "D", "task2"),
                    (("D", "E"), "F", "task4"),
                ],
                fourTasks,
            ),
            (
                [
                    (("D", "E"), "F", "task4"),
                    ("B", "D", "task2"),
                    ("C", "E", "task3"),
                    ("A", ("B", "C"), "task1"),
                ],
                fourTasks,
            ),
            (
                [
                    (("D", "E"), "F", "task4"),
                    ("C", "E", "task3"),
                    ("B", "D", "task2"),
                    ("A", ("B", "C"), "task1"),
                ],
                fourTasks,
            ),
        )
        for layout, expected in cases:
            ordered = pipeTools.orderPipeline(_makePipeline(layout))
            self._checkLabels(ordered, expected)

    def testOrderPipelineExceptions(self):
        """Tests for pipeTools.orderPipeline method exceptions."""
        # A task that consumes its own output forms a cycle; building the
        # pipeline must raise PipelineDataCycleError.
        with self.assertRaises(pipeTools.PipelineDataCycleError):
            _makePipeline([("A", ("A", "B"), "task1")])

        # A cycle spanning several tasks (A -> B -> C -> D -> A) must be
        # reported the same way.
        with self.assertRaises(pipeTools.PipelineDataCycleError):
            _makePipeline(
                [("A", "B", "task1"), ("B", "C", "task2"), ("C", "D", "task3"), ("D", "A", "task4")]
            )

210 

211 

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Run the file leak tests inherited from `lsst.utils.tests.MemoryTestCase`."""

214 

215 

def setup_module(module):
    """Configure pytest by calling ``lsst.utils.tests.init()``.

    The ``module`` argument is accepted but not used.
    """
    lsst.utils.tests.init()

219 

220 

if __name__ == "__main__":
    # Direct execution: perform the same initialization that setup_module
    # does under pytest, then hand over to the unittest runner.
    lsst.utils.tests.init()
    unittest.main()