Coverage for tests/test_pipeTools.py: 21%

94 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-03-23 10:43 +0000

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27 

"""Unit tests for the `lsst.pipe.base.pipeTools` pipeline-ordering helpers.
"""

30 

31import unittest 

32 

33import lsst.pipe.base.connectionTypes as cT 

34import lsst.utils.tests 

35from lsst.pipe.base import Pipeline, PipelineTask, PipelineTaskConfig, PipelineTaskConnections, pipeTools 

36 

37 

class ExamplePipelineTaskConnections(PipelineTaskConnections, dimensions=["Visit", "Detector"]):
    """Connections for the example task: two inputs and two outputs.

    The second input and the second output are optional. Each is dropped
    when its dataset type name is left empty in the config.
    """

    input1 = cT.Input(
        name="", dimensions=["Visit", "Detector"], storageClass="example", doc="Input for this task"
    )
    input2 = cT.Input(
        name="", dimensions=["Visit", "Detector"], storageClass="example", doc="Input for this task"
    )
    output1 = cT.Output(
        name="", dimensions=["Visit", "Detector"], storageClass="example", doc="Output for this task"
    )
    output2 = cT.Output(
        name="", dimensions=["Visit", "Detector"], storageClass="example", doc="Output for this task"
    )

    def __init__(self, *, config=None):
        super().__init__(config=config)
        # Each pair is (connection field name, name of the connection set that
        # holds it). Input is checked before output, and the set is looked up
        # only when a removal is actually needed.
        for fieldName, setName in (("input2", "inputs"), ("output2", "outputs")):
            if not getattr(config.connections, fieldName):
                getattr(self, setName).remove(fieldName)

60 

61 

class ExamplePipelineTaskConfig(PipelineTaskConfig, pipelineConnections=ExamplePipelineTaskConnections):
    """Configuration for the example task.

    Its connection fields come from `ExamplePipelineTaskConnections`.
    """

64 

65 

66def _makeConfig(inputName, outputName, pipeline, label): 

67 """Apply config overrides as needed. 

68 

69 Factory method for config instances. 

70 

71 inputName and outputName can be either string or tuple of strings 

72 with two items max. 

73 """ 

74 if isinstance(inputName, tuple): 

75 pipeline.addConfigOverride(label, "connections.input1", inputName[0]) 

76 pipeline.addConfigOverride(label, "connections.input2", inputName[1] if len(inputName) > 1 else "") 

77 else: 

78 pipeline.addConfigOverride(label, "connections.input1", inputName) 

79 

80 if isinstance(outputName, tuple): 

81 pipeline.addConfigOverride(label, "connections.output1", outputName[0]) 

82 pipeline.addConfigOverride(label, "connections.output2", outputName[1] if len(outputName) > 1 else "") 

83 else: 

84 pipeline.addConfigOverride(label, "connections.output1", outputName) 

85 

86 

class ExamplePipelineTask(PipelineTask):
    """Minimal `PipelineTask` subclass used as the default task in these tests.

    It only wires up its config class and default name.
    """

    _DefaultName = "examplePipelineTask"
    ConfigClass = ExamplePipelineTaskConfig

92 

93 

def _makePipeline(tasks):
    """Build a test pipeline and return its expanded task definitions.

    Parameters
    ----------
    tasks : `list` [`tuple`]
        Each tuple has 3 or 4 items:

        - input DatasetType name(s), string or tuple of strings
        - output DatasetType name(s), string or tuple of strings
        - task label, string
        - optional task class; if omitted or `None`,
          ``ExamplePipelineTask`` is used.

    Returns
    -------
    taskDefs : `list`
        The items produced by ``Pipeline.toExpandedPipeline()``, in the order
        it yields them. Each item has a ``label`` attribute.

    Notes
    -----
    Any error raised while expanding the pipeline propagates to the caller.
    """
    pipe = Pipeline("test pipeline")
    for task in tasks:
        inputs, outputs, label = task[:3]
        klass = task[3] if len(task) > 3 else None
        # Honour the documented "can be None" contract by falling back to the
        # default example task rather than passing None to addTask.
        if klass is None:
            klass = ExamplePipelineTask
        pipe.addTask(klass, label)
        _makeConfig(inputs, outputs, pipe, label)
    return list(pipe.toExpandedPipeline())

120 

121 

class PipelineToolsTestCase(unittest.TestCase):
    """Tests for the `lsst.pipe.base.pipeTools` helpers."""

    def _assertLabels(self, pipeline, expectedLabels):
        """Assert that ``pipeline`` holds tasks with exactly these labels, in order."""
        self.assertEqual([taskDef.label for taskDef in pipeline], list(expectedLabels))

    def testIsOrdered(self):
        """Tests for pipeTools.isPipelineOrdered method."""
        cases = [
            # Simple chain: A -> task1 -> B -> task2 -> C.
            [("A", "B", "task1"), ("B", "C", "task2")],
            # Diamond: task1 feeds two tasks that both feed task4.
            [("A", ("B", "C"), "task1"), ("B", "D", "task2"), ("C", "E", "task3"), (("D", "E"), "F", "task4")],
            # Same diamond with the middle tasks labelled the other way round.
            [("A", ("B", "C"), "task1"), ("C", "E", "task2"), ("B", "D", "task3"), (("D", "E"), "F", "task4")],
        ]
        for tasks in cases:
            with self.subTest(tasks=tasks):
                pipeline = _makePipeline(tasks)
                self.assertTrue(pipeTools.isPipelineOrdered(pipeline))

    def testOrderPipeline(self):
        """Tests for pipeTools.orderPipeline method.

        Each case gives tasks in some input order, together with the task
        labels expected after ordering.
        """
        chainLabels = ["task1", "task2"]
        diamondLabels = ["task1", "task2", "task3", "task4"]
        cases = [
            ([("A", "B", "task1"), ("B", "C", "task2")], chainLabels),
            ([("B", "C", "task2"), ("A", "B", "task1")], chainLabels),
            (
                [("A", ("B", "C"), "task1"), ("B", "D", "task2"), ("C", "E", "task3"), (("D", "E"), "F", "task4")],
                diamondLabels,
            ),
            (
                [("A", ("B", "C"), "task1"), ("C", "E", "task3"), ("B", "D", "task2"), (("D", "E"), "F", "task4")],
                diamondLabels,
            ),
            (
                [(("D", "E"), "F", "task4"), ("B", "D", "task2"), ("C", "E", "task3"), ("A", ("B", "C"), "task1")],
                diamondLabels,
            ),
            (
                [(("D", "E"), "F", "task4"), ("C", "E", "task3"), ("B", "D", "task2"), ("A", ("B", "C"), "task1")],
                diamondLabels,
            ),
        ]
        for tasks, expectedLabels in cases:
            with self.subTest(tasks=tasks):
                pipeline = pipeTools.orderPipeline(_makePipeline(tasks))
                self._assertLabels(pipeline, expectedLabels)

    def testOrderPipelineExceptions(self):
        """Test that cyclic data dependencies raise PipelineDataCycleError.

        The error is expected from ``_makePipeline`` itself, so
        ``pipeTools.orderPipeline`` is never called in this test.
        """
        cycles = [
            # A task that consumes one of its own outputs.
            [("A", ("A", "B"), "task1")],
            # A longer cycle through four tasks: A -> B -> C -> D -> A.
            [("A", "B", "task1"), ("B", "C", "task2"), ("C", "D", "task3"), ("D", "A", "task4")],
        ]
        for tasks in cycles:
            with self.subTest(tasks=tasks), self.assertRaises(pipeTools.PipelineDataCycleError):
                _makePipeline(tasks)

205 

206 

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Leak checks inherited unchanged from `lsst.utils.tests.MemoryTestCase`."""

209 

210 

def setup_module(module):
    """Pytest module-level setup hook.

    Calls ``lsst.utils.tests.init()`` before this module's tests run.

    Parameters
    ----------
    module : `types.ModuleType`
        The module being set up; not used.
    """
    lsst.utils.tests.init()

214 

215 

if __name__ == "__main__":
    # When run as a script, initialize the LSST test utilities first
    # (pytest does this through setup_module), then hand off to unittest.
    lsst.utils.tests.init()
    unittest.main()