Coverage for tests/test_pipeTools.py: 21%
103 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-07-12 11:14 -0700
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
22"""Simple unit test for Pipeline.
23"""
25import unittest
27import lsst.pipe.base.connectionTypes as cT
28import lsst.utils.tests
29from lsst.pipe.base import Pipeline, PipelineTask, PipelineTaskConfig, PipelineTaskConnections, pipeTools
class ExamplePipelineTaskConnections(PipelineTaskConnections, dimensions=["Visit", "Detector"]):
    """Connections used by `ExamplePipelineTask` in these tests.

    Two inputs and two outputs are declared, all with ``Visit`` and
    ``Detector`` dimensions and the ``example`` storage class. Their dataset
    type names are empty here and are filled in through config overrides
    (see `_makeConfig`). ``input2`` and ``output2`` are optional: each is
    removed from the connection sets when its configured name is empty.
    """

    input1 = cT.Input(
        doc="Input for this task", name="", storageClass="example", dimensions=["Visit", "Detector"]
    )
    input2 = cT.Input(
        doc="Input for this task", name="", storageClass="example", dimensions=["Visit", "Detector"]
    )
    output1 = cT.Output(
        doc="Output for this task", name="", storageClass="example", dimensions=["Visit", "Detector"]
    )
    output2 = cT.Output(
        doc="Output for this task", name="", storageClass="example", dimensions=["Visit", "Detector"]
    )

    def __init__(self, *, config=None):
        super().__init__(config=config)
        # Each optional connection is dropped when its configured name is empty.
        optionalConnections = (("input2", self.inputs), ("output2", self.outputs))
        for connectionName, connectionSet in optionalConnections:
            if not getattr(config.connections, connectionName):
                connectionSet.remove(connectionName)
class ExamplePipelineTaskConfig(PipelineTaskConfig, pipelineConnections=ExamplePipelineTaskConnections):
    """Configuration for `ExamplePipelineTask`, wired to `ExamplePipelineTaskConnections`."""
60def _makeConfig(inputName, outputName, pipeline, label):
61 """Apply config overrides as needed.
63 Factory method for config instances.
65 inputName and outputName can be either string or tuple of strings
66 with two items max.
67 """
68 if isinstance(inputName, tuple):
69 pipeline.addConfigOverride(label, "connections.input1", inputName[0])
70 pipeline.addConfigOverride(label, "connections.input2", inputName[1] if len(inputName) > 1 else "")
71 else:
72 pipeline.addConfigOverride(label, "connections.input1", inputName)
74 if isinstance(outputName, tuple):
75 pipeline.addConfigOverride(label, "connections.output1", outputName[0])
76 pipeline.addConfigOverride(label, "connections.output2", outputName[1] if len(outputName) > 1 else "")
77 else:
78 pipeline.addConfigOverride(label, "connections.output1", outputName)
class ExamplePipelineTask(PipelineTask):
    """Minimal task used only as a node in the test pipelines.

    It defines no processing methods of its own; only its config and
    connections matter to these tests.
    """

    _DefaultName = "examplePipelineTask"
    ConfigClass = ExamplePipelineTaskConfig
def _makePipeline(tasks):
    """Build a test pipeline and return its expanded task definitions.

    Parameters
    ----------
    tasks : `list` [`tuple`]
        Each tuple has 3 or 4 items:

        - input dataset type name(s): `str`, or `tuple` of one or two `str`;
        - output dataset type name(s): `str`, or `tuple` of one or two `str`;
        - task label: `str`;
        - optional task class; if omitted or `None`, `ExamplePipelineTask`
          is used.

    Returns
    -------
    taskDefs : `list`
        The items yielded by `Pipeline.toExpandedPipeline`, in the order it
        yields them.

    Notes
    -----
    Errors found while expanding the pipeline (duplicate outputs and data
    cycles, as exercised by the tests in this module) propagate out of this
    function.
    """
    pipe = Pipeline("test pipeline")
    for task in tasks:
        inputs, outputs, label = task[:3]
        klass = task[3] if len(task) > 3 else None
        # The documented contract allows an explicit None to mean "use the default task".
        if klass is None:
            klass = ExamplePipelineTask
        pipe.addTask(klass, label)
        _makeConfig(inputs, outputs, pipe, label)
    return list(pipe.toExpandedPipeline())
class PipelineToolsTestCase(unittest.TestCase):
    """Tests for `lsst.pipe.base.pipeTools`."""

    def _assertLabels(self, pipeline, expected):
        """Assert that ``pipeline`` holds exactly the task labels ``expected``, in order."""
        self.assertEqual([task.label for task in pipeline], expected)

    def testIsOrdered(self):
        """Tests for pipeTools.isPipelineOrdered method"""
        # Linear chain A -> B -> C.
        pipeline = _makePipeline([("A", "B", "task1"), ("B", "C", "task2")])
        self.assertTrue(pipeTools.isPipelineOrdered(pipeline))

        # Diamond: task1 feeds task2 (via B) and task3 (via C); both feed task4.
        pipeline = _makePipeline(
            [("A", ("B", "C"), "task1"), ("B", "D", "task2"), ("C", "E", "task3"), (("D", "E"), "F", "task4")]
        )
        self.assertTrue(pipeTools.isPipelineOrdered(pipeline))

        # Same diamond with the inputs of the two middle tasks swapped.
        pipeline = _makePipeline(
            [("A", ("B", "C"), "task1"), ("C", "E", "task2"), ("B", "D", "task3"), (("D", "E"), "F", "task4")]
        )
        self.assertTrue(pipeTools.isPipelineOrdered(pipeline))

    def testIsOrderedExceptions(self):
        """Test that a pipeline with duplicate outputs is rejected.

        The error is raised while `_makePipeline` builds the pipeline, before
        `pipeTools.isPipelineOrdered` could be called.
        """
        # task2 and task3 both produce "C".
        with self.assertRaises(pipeTools.DuplicateOutputError):
            _makePipeline(
                [
                    ("A", "B", "task1"),
                    ("B", "C", "task2"),
                    ("A", "C", "task3"),
                ]
            )

    def testOrderPipeline(self):
        """Tests for pipeTools.orderPipeline method"""
        # Linear chain, given in order and reversed.
        linearCases = (
            [("A", "B", "task1"), ("B", "C", "task2")],
            [("B", "C", "task2"), ("A", "B", "task1")],
        )
        for tasks in linearCases:
            with self.subTest(tasks=tasks):
                ordered = pipeTools.orderPipeline(_makePipeline(tasks))
                self._assertLabels(ordered, ["task1", "task2"])

        # Diamond task1 -> {task2, task3} -> task4, given in several input orders.
        head = ("A", ("B", "C"), "task1")
        left = ("B", "D", "task2")
        right = ("C", "E", "task3")
        tail = (("D", "E"), "F", "task4")
        diamondCases = (
            [head, left, right, tail],
            [head, right, left, tail],
            [tail, left, right, head],
            [tail, right, left, head],
        )
        for tasks in diamondCases:
            with self.subTest(tasks=tasks):
                ordered = pipeTools.orderPipeline(_makePipeline(tasks))
                self._assertLabels(ordered, ["task1", "task2", "task3", "task4"])

    def testOrderPipelineExceptions(self):
        """Test that invalid pipelines are rejected.

        Duplicate outputs and data cycles are detected while `_makePipeline`
        builds the pipeline, so `pipeTools.orderPipeline` is never reached.
        """
        # task2 and task3 both produce "C".
        with self.assertRaises(pipeTools.DuplicateOutputError):
            _makePipeline(
                [
                    ("A", "B", "task1"),
                    ("B", "C", "task2"),
                    ("A", "C", "task3"),
                ]
            )

        # task1 consumes its own output "A".
        with self.assertRaises(pipeTools.PipelineDataCycleError):
            _makePipeline([("A", ("A", "B"), "task1")])

        # Four-task cycle A -> B -> C -> D -> A.
        with self.assertRaises(pipeTools.PipelineDataCycleError):
            _makePipeline(
                [("A", "B", "task1"), ("B", "C", "task2"), ("C", "D", "task3"), ("D", "A", "task4")]
            )
class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Leak checks inherited unchanged from `lsst.utils.tests.MemoryTestCase`."""
def setup_module(module):
    """Initialize `lsst.utils.tests` before this module's tests run under pytest.

    ``module`` is the module object supplied by pytest; it is not used.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Run directly (not via pytest): perform the same lsst.utils.tests
    # initialization that setup_module does, then hand off to unittest.
    lsst.utils.tests.init()
    unittest.main()