Coverage for tests/test_pipeline.py: 45%

Shortcuts on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

102 statements  

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22"""Simple unit test for Pipeline. 

23""" 

24 

25import tempfile 

26import textwrap 

27import unittest 

28 

29import lsst.pex.config as pexConfig 

30from lsst.pipe.base import (Struct, PipelineTask, PipelineTaskConfig, Pipeline, TaskDef, 

31 PipelineTaskConnections, PipelineDatasetTypes, connectionTypes) 

32from lsst.pipe.base.tests.simpleQGraph import makeSimplePipeline 

33import lsst.utils.tests 

34 

35 

class DummyAddConnections(PipelineTaskConnections, dimensions=()):
    """Connections used by `AddConfig` / `AddTask`.

    Declares one output connection; the connections class itself is
    declared with an empty ``dimensions`` tuple.
    """

    dummyOutput = connectionTypes.Output(
        name="addOutput",
        storageClass="dummyStorage",
        doc="add output",
    )

38 

39 

class DummyMultiplyConnections(PipelineTaskConnections, dimensions=()):
    """Connections used by `MultConfig` / `MultTask`.

    Declares one input connection whose dataset name ("addOutput") matches
    the output declared by `DummyAddConnections`.
    """

    dummyInput = connectionTypes.Input(
        name="addOutput",
        storageClass="dummyStorage",
        doc="add output",
    )

42 

43 

class AddConfig(PipelineTaskConfig, pipelineConnections=DummyAddConnections):
    """Configuration for `AddTask`."""

    # Value that AddTask.run adds to its input.
    addend = pexConfig.Field(dtype=float, default=3.1, doc="amount to add")

46 

47 

class AddTask(PipelineTask):
    """Toy pipeline task that adds ``config.addend`` to its input."""

    ConfigClass = AddConfig
    _DefaultName = "add"

    def run(self, val):
        """Add the configured addend to ``val``.

        The addend is also recorded in the task metadata under ``"add"``.

        Parameters
        ----------
        val : `float`
            Value to offset (presumably numeric; any type supporting ``+``
            with a float works).

        Returns
        -------
        result : `Struct`
            Struct whose ``val`` attribute is ``val + config.addend``.
        """
        addend = self.config.addend
        self.metadata.add("add", addend)
        return Struct(val=val + addend)

57 

58 

class MultConfig(PipelineTaskConfig, pipelineConnections=DummyMultiplyConnections):
    """Configuration for `MultTask`."""

    # Factor that MultTask.run multiplies its input by.
    multiplicand = pexConfig.Field(dtype=float, default=2.5, doc="amount by which to multiply")

61 

62 

class MultTask(PipelineTask):
    """Toy pipeline task that multiplies its input by ``config.multiplicand``."""

    ConfigClass = MultConfig
    _DefaultName = "mult"

    def run(self, val):
        """Multiply ``val`` by the configured multiplicand.

        The multiplicand is also recorded in the task metadata under
        ``"mult"``.

        Parameters
        ----------
        val : `float`
            Value to scale (presumably numeric; any type supporting ``*``
            with a float works).

        Returns
        -------
        result : `Struct`
            Struct whose ``val`` attribute is ``val * config.multiplicand``.
        """
        factor = self.config.multiplicand
        self.metadata.add("mult", factor)
        return Struct(val=val * factor)

72 

73 

class TaskTestCase(unittest.TestCase):
    """Tests for `TaskDef` and for constructing, parameterizing, and
    serializing a `Pipeline`.
    """

    def testTaskDef(self):
        """Tests for TaskDef structure.
        """
        # No explicit task name or label is given; the expected label "add"
        # matches AddTask._DefaultName.
        task1 = TaskDef(taskClass=AddTask, config=AddConfig())
        self.assertIn("Add", task1.taskName)
        self.assertIsInstance(task1.config, AddConfig)
        self.assertIsNotNone(task1.taskClass)
        self.assertEqual(task1.label, "add")

        # Explicit name and label are kept as given; the metadata dataset
        # name is expected to be "<label>_metadata".
        task2 = TaskDef("lsst.pipe.base.tests.Mult", config=MultConfig(), taskClass=MultTask,
                        label="mult_task")
        self.assertEqual(task2.taskName, "lsst.pipe.base.tests.Mult")
        self.assertIsInstance(task2.config, MultConfig)
        self.assertIs(task2.taskClass, MultTask)
        self.assertEqual(task2.label, "mult_task")
        self.assertEqual(task2.metadataDatasetName, "mult_task_metadata")

        # With saveMetadata disabled there should be no metadata dataset name.
        config = MultConfig()
        config.saveMetadata = False
        task3 = TaskDef("lsst.pipe.base.tests.Mult", config, MultTask, "mult_task")
        self.assertIsNone(task3.metadataDatasetName)

    def testEmpty(self):
        """Creating an empty pipeline.
        """
        pipeline = Pipeline("test")
        self.assertEqual(len(pipeline), 0)

    def testInitial(self):
        """Testing constructor with initial data.
        """
        pipeline = Pipeline("test")
        pipeline.addTask(AddTask, "add")
        pipeline.addTask(MultTask, "mult")
        self.assertEqual(len(pipeline), 2)
        expandedPipeline = list(pipeline.toExpandedPipeline())
        self.assertEqual(expandedPipeline[0].taskName, "AddTask")
        self.assertEqual(expandedPipeline[1].taskName, "MultTask")

    def testParameters(self):
        """Test that parameters can be set and used to format config values.
        """
        pipeline_str = textwrap.dedent("""
            description: Test Pipeline
            parameters:
              testValue: 5.7
            tasks:
              add:
                class: test_pipeline.AddTask
                config:
                  addend: parameters.testValue
            """)
        # Verify that parameters are used when expanding a pipeline.
        pipeline = Pipeline.fromString(pipeline_str)
        expandedPipeline = list(pipeline.toExpandedPipeline())
        self.assertEqual(expandedPipeline[0].config.addend, 5.7)

        # Verify that a parameter can be overridden on the "command line".
        pipeline.addConfigOverride("parameters", "testValue", 14.9)
        expandedPipeline = list(pipeline.toExpandedPipeline())
        self.assertEqual(expandedPipeline[0].config.addend, 14.9)

        # Verify that a non-existent parameter can't be overridden.
        with self.assertRaises(ValueError):
            pipeline.addConfigOverride("parameters", "missingValue", 17)

        # Verify that parameters do not support config-file or Python overrides.
        with self.assertRaises(ValueError):
            pipeline.addConfigFile("parameters", "fakeFile")
        with self.assertRaises(ValueError):
            pipeline.addConfigPython("parameters", "fakePythonString")

    def testSerialization(self):
        """Test that a pipeline round-trips through `str` and
        `Pipeline.fromString`, and through `Pipeline.write_to_uri` and
        `Pipeline.from_uri`, with tasks expanded in ``add``, ``mult`` order
        even though ``mult`` is added first.
        """
        pipeline = Pipeline("test")
        pipeline.addTask(MultTask, "mult")
        pipeline.addTask(AddTask, "add")

        dump = str(pipeline)
        load = Pipeline.fromString(dump)
        self.assertEqual(pipeline, load)

        # Verify the keys were sorted after a call to str.
        # NOTE(review): "add" -> "mult" is also the data-flow order here
        # (DummyAddConnections' "addOutput" output is DummyMultiplyConnections'
        # input), so this check may not distinguish key sorting from
        # dependency ordering.
        self.assertEqual([t.label for t in pipeline.toExpandedPipeline()], ['add', 'mult'])

        pipeline = Pipeline("test")
        pipeline.addTask(MultTask, "mult")
        pipeline.addTask(AddTask, "add")

        # Verify that writing out the file sorts it.
        with tempfile.NamedTemporaryFile() as tf:
            pipeline.write_to_uri(tf.name)
            loadedPipeline = Pipeline.from_uri(tf.name)

        self.assertEqual([t.label for t in loadedPipeline.toExpandedPipeline()], ['add', 'mult'])

178 

179 

class PipelineTestCase(unittest.TestCase):
    """Tests for `PipelineDatasetTypes` helpers applied to a simple pipeline.
    """

    def test_initOutputNames(self):
        """Check the names returned by `PipelineDatasetTypes.initOutputNames`
        for the three-task pipeline built by `makeSimplePipeline`.
        """
        # "packages", plus one init-output and one config dataset per task.
        expected = {
            "packages",
            "add_init_output1", "add_init_output2", "add_init_output3",
            "task0_config", "task1_config", "task2_config",
        }
        names = PipelineDatasetTypes.initOutputNames(makeSimplePipeline(3))
        self.assertSetEqual(set(names), expected)

199 

200 

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Pick up the checks defined by `lsst.utils.tests.MemoryTestCase`;
    nothing is added or overridden here.
    """

203 

204 

def setup_module(module):
    """Module-level setup hook (pytest xunit style).

    Calls ``lsst.utils.tests.init()`` once before the tests in this module.

    Parameters
    ----------
    module : `module`
        The module being set up; not used.
    """
    lsst.utils.tests.init()

207 

208 

if __name__ == "__main__":
    # Run directly (not via pytest, where setup_module does this):
    # initialize the LSST test utilities, then hand off to unittest.
    lsst.utils.tests.init()
    unittest.main()