Coverage for tests/test_pipeline.py: 45%
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22"""Simple unit test for Pipeline.
23"""
25import tempfile
26import textwrap
27import unittest
29import lsst.pex.config as pexConfig
30import lsst.utils.tests
31from lsst.pipe.base import (
32 Pipeline,
33 PipelineDatasetTypes,
34 PipelineTask,
35 PipelineTaskConfig,
36 PipelineTaskConnections,
37 Struct,
38 TaskDef,
39 connectionTypes,
40)
41from lsst.pipe.base.tests.simpleQGraph import makeSimplePipeline
class DummyAddConnections(PipelineTaskConnections, dimensions=()):
    """Connections for `AddTask`: a single dimensionless output."""

    dummyOutput = connectionTypes.Output(
        name="addOutput",
        storageClass="dummyStorage",
        doc="add output",
    )
class DummyMultiplyConnections(PipelineTaskConnections, dimensions=()):
    """Connections for `MultTask`: consumes the "addOutput" dataset."""

    dummyInput = connectionTypes.Input(
        name="addOutput",
        storageClass="dummyStorage",
        doc="add output",
    )
class AddConfig(PipelineTaskConfig, pipelineConnections=DummyAddConnections):
    """Configuration for `AddTask`."""

    addend = pexConfig.Field(dtype=float, default=3.1, doc="amount to add")
class AddTask(PipelineTask):
    """Trivial test task that adds ``config.addend`` to its input."""

    ConfigClass = AddConfig
    _DefaultName = "add"

    def run(self, val):
        """Add the configured addend to ``val``.

        The addend is also recorded in the task metadata under the
        key "add".

        Parameters
        ----------
        val : number
            Value to increment (presumably a float; not checked here).

        Returns
        -------
        result : `Struct`
            Struct whose ``val`` attribute holds ``val + config.addend``.
        """
        addend = self.config.addend
        self.metadata.add("add", addend)
        result = Struct(val=val + addend)
        return result
class MultConfig(PipelineTaskConfig, pipelineConnections=DummyMultiplyConnections):
    """Configuration for `MultTask`."""

    multiplicand = pexConfig.Field(
        dtype=float,
        default=2.5,
        doc="amount by which to multiply",
    )
class MultTask(PipelineTask):
    """Trivial test task that multiplies its input by ``config.multiplicand``."""

    ConfigClass = MultConfig
    _DefaultName = "mult"

    def run(self, val):
        """Multiply ``val`` by the configured multiplicand.

        The multiplicand is also recorded in the task metadata under the
        key "mult".

        Parameters
        ----------
        val : number
            Value to scale (presumably a float; not checked here).

        Returns
        -------
        result : `Struct`
            Struct whose ``val`` attribute holds
            ``val * config.multiplicand``.
        """
        factor = self.config.multiplicand
        self.metadata.add("mult", factor)
        result = Struct(val=val * factor)
        return result
class TaskTestCase(unittest.TestCase):
    """Tests for `TaskDef` and for building, expanding, and serializing a
    `Pipeline` from the `AddTask` / `MultTask` fixtures in this module.
    """

    def testTaskDef(self):
        """Tests for TaskDef structure"""
        # No explicit name or label is given here; the expected label
        # matches AddTask._DefaultName.
        task1 = TaskDef(taskClass=AddTask, config=AddConfig())
        self.assertIn("Add", task1.taskName)
        self.assertIsInstance(task1.config, AddConfig)
        self.assertIsNotNone(task1.taskClass)
        self.assertEqual(task1.label, "add")

        task2 = TaskDef(
            "lsst.pipe.base.tests.Mult", config=MultConfig(), taskClass=MultTask, label="mult_task"
        )
        self.assertEqual(task2.taskName, "lsst.pipe.base.tests.Mult")
        self.assertIsInstance(task2.config, MultConfig)
        self.assertIs(task2.taskClass, MultTask)
        self.assertEqual(task2.label, "mult_task")
        self.assertEqual(task2.metadataDatasetName, "mult_task_metadata")

        # With metadata saving disabled, no metadata dataset name is expected.
        config = MultConfig()
        config.saveMetadata = False
        task3 = TaskDef("lsst.pipe.base.tests.Mult", config, MultTask, "mult_task")
        self.assertIsNone(task3.metadataDatasetName)

    def testEmpty(self):
        """Test that a newly created pipeline contains no tasks."""
        pipeline = Pipeline("test")
        self.assertEqual(len(pipeline), 0)

    def testInitial(self):
        """Test adding tasks to a pipeline and expanding it."""
        pipeline = Pipeline("test")
        pipeline.addTask(AddTask, "add")
        pipeline.addTask(MultTask, "mult")
        self.assertEqual(len(pipeline), 2)
        expandedPipeline = list(pipeline.toExpandedPipeline())
        self.assertEqual(expandedPipeline[0].taskName, "AddTask")
        self.assertEqual(expandedPipeline[1].taskName, "MultTask")

    def testParameters(self):
        """Test that parameters can be set and used to format"""
        pipeline_str = textwrap.dedent(
            """
            description: Test Pipeline
            parameters:
              testValue: 5.7
            tasks:
              add:
                class: test_pipeline.AddTask
                config:
                  addend: parameters.testValue
            """
        )
        # Verify that parameters are used when expanding a pipeline.
        pipeline = Pipeline.fromString(pipeline_str)
        expandedPipeline = list(pipeline.toExpandedPipeline())
        self.assertEqual(expandedPipeline[0].config.addend, 5.7)

        # Verify that a parameter can be overridden on the "command line".
        pipeline.addConfigOverride("parameters", "testValue", 14.9)
        expandedPipeline = list(pipeline.toExpandedPipeline())
        self.assertEqual(expandedPipeline[0].config.addend, 14.9)

        # Verify that a non-existent parameter can't be overridden.
        with self.assertRaises(ValueError):
            pipeline.addConfigOverride("parameters", "missingValue", 17)

        # Verify that parameters do not support file or Python overrides.
        with self.assertRaises(ValueError):
            pipeline.addConfigFile("parameters", "fakeFile")
        with self.assertRaises(ValueError):
            pipeline.addConfigPython("parameters", "fakePythonString")

    def testSerialization(self):
        """Test string and file round-trips of a pipeline, and that task
        labels come back in sorted order.
        """
        pipeline = Pipeline("test")
        pipeline.addTask(MultTask, "mult")
        pipeline.addTask(AddTask, "add")

        dump = str(pipeline)
        load = Pipeline.fromString(dump)
        self.assertEqual(pipeline, load)

        # Verify the keys were sorted after a call to str.
        # NOTE(review): MultTask consumes AddTask's "addOutput", so
        # toExpandedPipeline may put "add" first by data dependency alone.
        # If so, this check cannot detect a missing sort. Confirm.
        self.assertEqual([t.label for t in pipeline.toExpandedPipeline()], ["add", "mult"])

        # Build a fresh, unsorted pipeline. The one above may already
        # have been sorted by str().
        pipeline = Pipeline("test")
        pipeline.addTask(MultTask, "mult")
        pipeline.addTask(AddTask, "add")

        # Verify that writing out the file sorts it.
        with tempfile.NamedTemporaryFile() as tf:
            pipeline.write_to_uri(tf.name)
            loadedPipeline = Pipeline.from_uri(tf.name)

        self.assertEqual([t.label for t in loadedPipeline.toExpandedPipeline()], ["add", "mult"])
class PipelineTestCase(unittest.TestCase):
    """Tests for helper classes related to `Pipeline`."""

    def test_initOutputNames(self):
        """Test for PipelineDatasetTypes.initOutputNames method."""
        simplePipeline = makeSimplePipeline(3)
        actualNames = set(PipelineDatasetTypes.initOutputNames(simplePipeline))

        # One shared "packages" dataset, then one init-output dataset and one
        # config dataset for each of the three tasks.
        expectedNames = {"packages"}
        expectedNames.update(("add_init_output1", "add_init_output2", "add_init_output3"))
        expectedNames.update(("task0_config", "task1_config", "task2_config"))

        self.assertEqual(actualNames, expectedNames)
class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Hook this module into `lsst.utils.tests.MemoryTestCase`.

    No tests are added here; everything is inherited (presumably the
    standard LSST resource-leak checks; see that class).
    """
def setup_module(module):
    """Pytest-style module setup hook.

    Calls ``lsst.utils.tests.init()`` before any test in this module runs.
    The ``module`` argument is required by the hook signature and is unused.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # When run directly rather than under pytest, perform the same LSST
    # test initialization that setup_module does, then run unittest.
    lsst.utils.tests.init()
    unittest.main()