Coverage for tests/test_pipeline.py: 45%
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22"""Simple unit test for Pipeline.
23"""
25import tempfile
26import textwrap
27import unittest
29import lsst.pex.config as pexConfig
30from lsst.pipe.base import (Struct, PipelineTask, PipelineTaskConfig, Pipeline, TaskDef,
31 PipelineTaskConnections, PipelineDatasetTypes, connectionTypes)
32from lsst.pipe.base.tests.simpleQGraph import makeSimplePipeline
33import lsst.utils.tests
class DummyAddConnections(PipelineTaskConnections, dimensions=()):
    """Connections for `AddTask`: a single dimensionless output dataset."""

    dummyOutput = connectionTypes.Output(
        name="addOutput",
        storageClass="dummyStorage",
        doc="add output",
    )
class DummyMultiplyConnections(PipelineTaskConnections, dimensions=()):
    """Connections for `MultTask`: consumes the dataset that `AddTask` writes."""

    dummyInput = connectionTypes.Input(
        name="addOutput",
        storageClass="dummyStorage",
        doc="add output",
    )
class AddConfig(PipelineTaskConfig, pipelineConnections=DummyAddConnections):
    """Configuration for `AddTask`."""

    addend = pexConfig.Field(dtype=float, default=3.1, doc="amount to add")
class AddTask(PipelineTask):
    """Toy `PipelineTask` that adds ``config.addend`` to its input."""

    _DefaultName = "add"
    ConfigClass = AddConfig

    def run(self, val):
        """Add the configured addend to ``val``.

        The addend is also recorded in the task metadata under the key
        ``"add"``.

        Parameters
        ----------
        val
            Value to add to; presumably numeric (it is combined with a
            float config field via ``+``).

        Returns
        -------
        result : `Struct`
            Struct whose only attribute ``val`` is ``val + addend``.
        """
        addend = self.config.addend
        self.metadata.add("add", addend)
        return Struct(val=val + addend)
class MultConfig(PipelineTaskConfig, pipelineConnections=DummyMultiplyConnections):
    """Configuration for `MultTask`."""

    multiplicand = pexConfig.Field(dtype=float, default=2.5, doc="amount by which to multiply")
class MultTask(PipelineTask):
    """Toy `PipelineTask` that multiplies its input by ``config.multiplicand``."""

    _DefaultName = "mult"
    ConfigClass = MultConfig

    def run(self, val):
        """Multiply ``val`` by the configured multiplicand.

        The multiplicand is also recorded in the task metadata under the key
        ``"mult"``.

        Parameters
        ----------
        val
            Value to scale; presumably numeric (it is combined with a float
            config field via ``*``).

        Returns
        -------
        result : `Struct`
            Struct whose only attribute ``val`` is ``val * multiplicand``.
        """
        factor = self.config.multiplicand
        self.metadata.add("mult", factor)
        return Struct(val=val * factor)
class TaskTestCase(unittest.TestCase):
    """Tests for `TaskDef` and for building, parameterizing and serializing
    a `Pipeline`.
    """

    def testTaskDef(self):
        """Tests for TaskDef structure.
        """
        # Only taskClass and config are given; the name and label are filled
        # in by TaskDef. The label is expected to match AddTask._DefaultName.
        task1 = TaskDef(taskClass=AddTask, config=AddConfig())
        self.assertIn("Add", task1.taskName)
        self.assertIsInstance(task1.config, AddConfig)
        # Check identity, as is done for task2 below, rather than only
        # checking that some class was stored.
        self.assertIs(task1.taskClass, AddTask)
        self.assertEqual(task1.label, "add")

        # Explicit task name and label.
        task2 = TaskDef("lsst.pipe.base.tests.Mult", config=MultConfig(), taskClass=MultTask,
                        label="mult_task")
        self.assertEqual(task2.taskName, "lsst.pipe.base.tests.Mult")
        self.assertIsInstance(task2.config, MultConfig)
        self.assertIs(task2.taskClass, MultTask)
        self.assertEqual(task2.label, "mult_task")
        self.assertEqual(task2.metadataDatasetName, "mult_task_metadata")

        # With saveMetadata disabled, no metadata dataset name is reported.
        config = MultConfig()
        config.saveMetadata = False
        task3 = TaskDef("lsst.pipe.base.tests.Mult", config, MultTask, "mult_task")
        self.assertIsNone(task3.metadataDatasetName)

    def testEmpty(self):
        """Creating empty pipeline
        """
        pipeline = Pipeline("test")
        self.assertEqual(len(pipeline), 0)

    def testInitial(self):
        """Test adding tasks to a pipeline and expanding it.
        """
        pipeline = Pipeline("test")
        pipeline.addTask(AddTask, "add")
        pipeline.addTask(MultTask, "mult")
        self.assertEqual(len(pipeline), 2)
        expandedPipeline = list(pipeline.toExpandedPipeline())
        self.assertEqual(expandedPipeline[0].taskName, "AddTask")
        self.assertEqual(expandedPipeline[1].taskName, "MultTask")

    def testParameters(self):
        """Test that parameters can be set and used to format task configs,
        and that parameters accept only value overrides.
        """
        pipeline_str = textwrap.dedent("""
            description: Test Pipeline
            parameters:
              testValue: 5.7
            tasks:
              add:
                class: test_pipeline.AddTask
                config:
                  addend: parameters.testValue
            """)
        # verify that parameters are used in expanding a pipeline
        pipeline = Pipeline.fromString(pipeline_str)
        expandedPipeline = list(pipeline.toExpandedPipeline())
        self.assertEqual(expandedPipeline[0].config.addend, 5.7)

        # verify that a parameter can be overridden on the "command line"
        pipeline.addConfigOverride("parameters", "testValue", 14.9)
        expandedPipeline = list(pipeline.toExpandedPipeline())
        self.assertEqual(expandedPipeline[0].config.addend, 14.9)

        # verify that a non-existent parameter can't be overridden
        with self.assertRaises(ValueError):
            pipeline.addConfigOverride("parameters", "missingValue", 17)

        # verify that parameters do not support file or python overrides
        with self.assertRaises(ValueError):
            pipeline.addConfigFile("parameters", "fakeFile")
        with self.assertRaises(ValueError):
            pipeline.addConfigPython("parameters", "fakePythonString")

    def testSerialization(self):
        """Test that a pipeline round-trips through its string form and
        through a file, and that task labels come back sorted.
        """
        pipeline = Pipeline("test")
        pipeline.addTask(MultTask, "mult")
        pipeline.addTask(AddTask, "add")

        dump = str(pipeline)
        load = Pipeline.fromString(dump)
        self.assertEqual(pipeline, load)

        # verify the task keys were sorted after a call to str
        # (tasks were added in the order mult, add)
        self.assertEqual([t.label for t in pipeline.toExpandedPipeline()], ['add', 'mult'])

        pipeline = Pipeline("test")
        pipeline.addTask(MultTask, "mult")
        pipeline.addTask(AddTask, "add")

        # verify that writing out the file sorts it
        with tempfile.NamedTemporaryFile() as tf:
            pipeline.write_to_uri(tf.name)
            loadedPipeline = Pipeline.from_uri(tf.name)
            self.assertEqual([t.label for t in loadedPipeline.toExpandedPipeline()], ['add', 'mult'])
class PipelineTestCase(unittest.TestCase):
    """Tests for Pipeline-related helpers such as `PipelineDatasetTypes`.
    """

    def test_initOutputNames(self):
        """Check the dataset names reported by
        `PipelineDatasetTypes.initOutputNames` for a three-task simple
        pipeline.
        """
        simplePipeline = makeSimplePipeline(3)
        actualNames = set(PipelineDatasetTypes.initOutputNames(simplePipeline))
        # One shared "packages" dataset, plus one init-output and one config
        # dataset per task.
        expectedNames = {"packages"}
        expectedNames |= {"add_init_output1", "add_init_output2", "add_init_output3"}
        expectedNames |= {"task0_config", "task1_config", "task2_config"}
        self.assertEqual(actualNames, expectedNames)
class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Run the standard LSST memory/leak checks for this test module."""
def setup_module(module):
    """Module-level setup hook (pytest xunit style).

    Initializes the LSST test utilities once, before the tests in this
    module run.

    Parameters
    ----------
    module : `types.ModuleType`
        The module being set up; not used.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Direct execution bypasses setup_module, so initialize the LSST test
    # utilities here before handing control to unittest.
    lsst.utils.tests.init()
    unittest.main()