Coverage for tests/test_pipeTools.py: 28%
107 statements
« prev ^ index » next coverage.py v6.4.4, created at 2022-08-19 12:18 -0700
« prev ^ index » next coverage.py v6.4.4, created at 2022-08-19 12:18 -0700
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Simple unit tests for pipeTools (ordering and validation of Pipeline tasks).
"""
25import unittest
27import lsst.pipe.base.connectionTypes as cT
28import lsst.utils.tests
29from lsst.pipe.base import Pipeline, PipelineTask, PipelineTaskConfig, PipelineTaskConnections, pipeTools
class ExamplePipelineTaskConnections(PipelineTaskConnections, dimensions=["Visit", "Detector"]):
    """Connections with two inputs and two outputs; the second of each is optional.

    ``input2`` and ``output2`` are removed from this instance's input/output
    connection sets when their configured dataset type name is empty.
    """

    input1 = cT.Input(
        name="", dimensions=["Visit", "Detector"], storageClass="example", doc="Input for this task"
    )
    input2 = cT.Input(
        name="", dimensions=["Visit", "Detector"], storageClass="example", doc="Input for this task"
    )
    output1 = cT.Output(
        name="", dimensions=["Visit", "Detector"], storageClass="example", doc="Output for this task"
    )
    output2 = cT.Output(
        name="", dimensions=["Visit", "Detector"], storageClass="example", doc="Output for this task"
    )

    def __init__(self, *, config=None):
        super().__init__(config=config)
        # An empty configured name marks the optional connection as unused,
        # so drop it from the matching connection set.
        optionalConnections = (("input2", self.inputs), ("output2", self.outputs))
        for connectionName, connectionSet in optionalConnections:
            if not getattr(config.connections, connectionName):
                connectionSet.remove(connectionName)
class ExamplePipelineTaskConfig(PipelineTaskConfig, pipelineConnections=ExamplePipelineTaskConnections):
    """Config for `ExamplePipelineTask`, wired to `ExamplePipelineTaskConnections`."""
58def _makeConfig(inputName, outputName, pipeline, label):
59 """Factory method for config instances
61 inputName and outputName can be either string or tuple of strings
62 with two items max.
63 """
64 if isinstance(inputName, tuple):
65 pipeline.addConfigOverride(label, "connections.input1", inputName[0])
66 pipeline.addConfigOverride(label, "connections.input2", inputName[1] if len(inputName) > 1 else "")
67 else:
68 pipeline.addConfigOverride(label, "connections.input1", inputName)
70 if isinstance(outputName, tuple):
71 pipeline.addConfigOverride(label, "connections.output1", outputName[0])
72 pipeline.addConfigOverride(label, "connections.output2", outputName[1] if len(outputName) > 1 else "")
73 else:
74 pipeline.addConfigOverride(label, "connections.output1", outputName)
class ExamplePipelineTask(PipelineTask):
    """Minimal PipelineTask used to build test pipelines."""

    ConfigClass = ExamplePipelineTaskConfig
def _makePipeline(tasks):
    """Build a test pipeline and return its expanded task definitions.

    Parameters
    ----------
    tasks : `list` [`tuple`]
        Each tuple has 3 or 4 items:

        - input dataset type name(s), `str` or `tuple` of `str`;
        - output dataset type name(s), `str` or `tuple` of `str`;
        - task label, `str`;
        - optional task class; when omitted or `None`,
          ``ExamplePipelineTask`` is used.

    Returns
    -------
    taskDefs : `list`
        The task definitions produced by ``Pipeline.toExpandedPipeline()``,
        materialized as a list. Each item is expected to expose ``label``.
    """
    pipe = Pipeline("test pipeline")
    for inputs, outputs, label, *extra in tasks:
        # The docstring has long promised that None selects the default
        # task class, so honour it instead of handing None to addTask.
        klass = extra[0] if extra and extra[0] is not None else ExamplePipelineTask
        pipe.addTask(klass, label)
        _makeConfig(inputs, outputs, pipe, label)
    return list(pipe.toExpandedPipeline())
class PipelineToolsTestCase(unittest.TestCase):
    """Tests for the `lsst.pipe.base.pipeTools` module."""

    def _assertLabels(self, pipeline, labels):
        """Assert that ``pipeline`` contains exactly ``labels``, in order."""
        self.assertEqual([taskDef.label for taskDef in pipeline], labels)

    def testIsOrdered(self):
        """Tests for pipeTools.isPipelineOrdered method"""
        pipeline = _makePipeline([("A", "B", "task1"), ("B", "C", "task2")])
        self.assertTrue(pipeTools.isPipelineOrdered(pipeline))
        # task2 consumes "B", which task1 produces; putting the consumer
        # first must be reported as not ordered.
        self.assertFalse(pipeTools.isPipelineOrdered(list(reversed(pipeline))))

        pipeline = _makePipeline(
            [("A", ("B", "C"), "task1"), ("B", "D", "task2"), ("C", "E", "task3"), (("D", "E"), "F", "task4")]
        )
        self.assertTrue(pipeTools.isPipelineOrdered(pipeline))

        pipeline = _makePipeline(
            [("A", ("B", "C"), "task1"), ("C", "E", "task2"), ("B", "D", "task3"), (("D", "E"), "F", "task4")]
        )
        self.assertTrue(pipeTools.isPipelineOrdered(pipeline))

    def testIsOrderedExceptions(self):
        """Two tasks producing the same dataset type are rejected."""
        # The error is raised while _makePipeline builds and expands the
        # pipeline, so only that call is wrapped.
        with self.assertRaises(pipeTools.DuplicateOutputError):
            _makePipeline(
                [
                    ("A", "B", "task1"),
                    ("B", "C", "task2"),
                    ("A", "C", "task3"),
                ]
            )

    def testOrderPipeline(self):
        """Tests for pipeTools.orderPipeline method"""
        for tasks in (
            [("A", "B", "task1"), ("B", "C", "task2")],
            [("B", "C", "task2"), ("A", "B", "task1")],
        ):
            with self.subTest(tasks=[task[2] for task in tasks]):
                pipeline = pipeTools.orderPipeline(_makePipeline(tasks))
                self._assertLabels(pipeline, ["task1", "task2"])

        # Diamond-shaped graph: task1 feeds task2 and task3, which both feed
        # task4. Every input permutation must come out in the same order.
        task1 = ("A", ("B", "C"), "task1")
        task2 = ("B", "D", "task2")
        task3 = ("C", "E", "task3")
        task4 = (("D", "E"), "F", "task4")
        for tasks in (
            [task1, task2, task3, task4],
            [task1, task3, task2, task4],
            [task4, task2, task3, task1],
            [task4, task3, task2, task1],
        ):
            with self.subTest(tasks=[task[2] for task in tasks]):
                pipeline = pipeTools.orderPipeline(_makePipeline(tasks))
                self._assertLabels(pipeline, ["task1", "task2", "task3", "task4"])

    def testOrderPipelineExceptions(self):
        """Tests for pipeTools.orderPipeline method exceptions"""
        # Two producers of the same dataset type.
        with self.assertRaises(pipeTools.DuplicateOutputError):
            _makePipeline(
                [
                    ("A", "B", "task1"),
                    ("B", "C", "task2"),
                    ("A", "C", "task3"),
                ]
            )

        # A task consuming its own output is a cycle.
        with self.assertRaises(pipeTools.PipelineDataCycleError):
            _makePipeline([("A", ("A", "B"), "task1")])

        # A cycle spanning several tasks.
        with self.assertRaises(pipeTools.PipelineDataCycleError):
            _makePipeline(
                [("A", "B", "task1"), ("B", "C", "task2"), ("C", "D", "task3"), ("D", "A", "task4")]
            )
class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Run the standard LSST memory/resource-leak checks for this module."""
def setup_module(module):
    """Module-level setup hook (pytest convention): initialize LSST test utilities.

    ``module`` is passed in by the test runner and is not used.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Running the file directly: set up LSST test utilities, then let
    # unittest discover and run the test cases defined above.
    lsst.utils.tests.init()
    unittest.main()