Coverage for tests/test_pipeTools.py: 22%
98 statements
« prev ^ index » next coverage.py v7.3.0, created at 2023-08-23 10:31 +0000
« prev ^ index » next coverage.py v7.3.0, created at 2023-08-23 10:31 +0000
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Unit tests for the `lsst.pipe.base.pipeTools` pipeline-ordering utilities.
"""
25import unittest
27import lsst.pipe.base.connectionTypes as cT
28import lsst.utils.tests
29from lsst.pipe.base import Pipeline, PipelineTask, PipelineTaskConfig, PipelineTaskConnections, pipeTools
class ExamplePipelineTaskConnections(PipelineTaskConnections, dimensions=["Visit", "Detector"]):
    """Connections for the example task used by these tests.

    The task always has ``input1`` and ``output1``.  The second input and
    output (``input2``/``output2``) are optional: when the configured dataset
    type name for one of them is empty, that connection is removed in
    ``__init__``.  This lets a test task read or write either one or two
    dataset types.
    """

    input1 = cT.Input(
        name="", dimensions=["Visit", "Detector"], storageClass="example", doc="Input for this task"
    )
    input2 = cT.Input(
        name="", dimensions=["Visit", "Detector"], storageClass="example", doc="Input for this task"
    )
    output1 = cT.Output(
        name="", dimensions=["Visit", "Detector"], storageClass="example", doc="Output for this task"
    )
    output2 = cT.Output(
        name="", dimensions=["Visit", "Detector"], storageClass="example", doc="Output for this task"
    )

    def __init__(self, *, config=None):
        super().__init__(config=config)
        # Each optional connection paired with the collection it lives in.
        # A falsy (e.g. empty) configured name marks the connection unused.
        optionalConnections = (("input2", self.inputs), ("output2", self.outputs))
        for connectionName, registered in optionalConnections:
            if not getattr(config.connections, connectionName):
                registered.remove(connectionName)
class ExamplePipelineTaskConfig(PipelineTaskConfig, pipelineConnections=ExamplePipelineTaskConnections):
    """Configuration class for the example test task.

    Its ``connections`` sub-config is tied to
    `ExamplePipelineTaskConnections`; the tests set dataset type names
    through it with pipeline config overrides.
    """
def _makeConfig(inputName, outputName, pipeline, label):
    """Apply config overrides as needed.

    Factory helper for config instances: records on ``pipeline`` the dataset
    type names that the task labelled ``label`` should read and write.

    Parameters
    ----------
    inputName : `str` or `tuple` [`str`]
        Input dataset type name(s).  A single string sets
        ``connections.input1`` only.  A tuple (or list) of one or two strings
        sets ``connections.input1`` and ``connections.input2``.  A missing
        second item is written as an empty string.
    outputName : `str` or `tuple` [`str`]
        Output dataset type name(s), with the same rules applied to
        ``connections.output1``/``connections.output2``.
    pipeline : `Pipeline`
        Object whose ``addConfigOverride(label, key, value)`` method receives
        the overrides.
    label : `str`
        Label of the task being configured.

    Raises
    ------
    ValueError
        Raised if a tuple/list of names does not have exactly one or two
        items.
    """
    # Order of overrides is input1, input2, output1, output2.
    _overrideConnectionNames(pipeline, label, "input", inputName)
    _overrideConnectionNames(pipeline, label, "output", outputName)


def _overrideConnectionNames(pipeline, label, kind, names):
    """Add ``connections.<kind>1`` (and ``<kind>2``) overrides for ``label``."""
    if not isinstance(names, (tuple, list)):
        pipeline.addConfigOverride(label, f"connections.{kind}1", names)
        return
    if not 1 <= len(names) <= 2:
        # Only two connection slots exist per direction; silently dropping
        # extra names would build a different pipeline than the caller asked.
        raise ValueError(
            f"Expected one or two {kind} names for task {label!r}, got {len(names)}: {names!r}"
        )
    pipeline.addConfigOverride(label, f"connections.{kind}1", names[0])
    # An empty second name marks the optional connection as unused.
    pipeline.addConfigOverride(label, f"connections.{kind}2", names[1] if len(names) > 1 else "")
class ExamplePipelineTask(PipelineTask):
    """Minimal task used only to assemble test pipelines; defines no run logic."""

    _DefaultName = "examplePipelineTask"
    ConfigClass = ExamplePipelineTaskConfig
def _makePipeline(tasks):
    """Build a test pipeline and return its expanded task definitions.

    Parameters
    ----------
    tasks : `list` [`tuple`]
        Each tuple in the list has 3 or 4 items:

        - input DatasetType name(s), string or tuple of strings;
        - output DatasetType name(s), string or tuple of strings;
        - task label, string;
        - optional task class object; when omitted or `None`,
          `ExamplePipelineTask` is used.

    Returns
    -------
    taskDefs : `list`
        The tasks yielded by ``Pipeline.toExpandedPipeline()``, collected
        into a list.

    Notes
    -----
    Any error raised while expanding the pipeline propagates to the caller.
    The tests in this module rely on that to detect data cycles
    (`pipeTools.PipelineDataCycleError`).
    """
    pipe = Pipeline("test pipeline")
    for task in tasks:
        inputs, outputs, label = task[:3]
        klass = task[3] if len(task) > 3 else None
        if klass is None:
            # Documented as allowed: None means "use the default test task".
            klass = ExamplePipelineTask
        pipe.addTask(klass, label)
        _makeConfig(inputs, outputs, pipe, label)
    return list(pipe.toExpandedPipeline())
class PipelineToolsTestCase(unittest.TestCase):
    """Tests for the `pipeTools` pipeline-ordering utilities."""

    def _assertLabels(self, pipeline, labels):
        """Assert that the tasks in ``pipeline`` have exactly ``labels``, in order."""
        self.assertEqual([taskDef.label for taskDef in pipeline], labels)

    def testIsOrdered(self):
        """Tests for pipeTools.isPipelineOrdered method"""
        for tasks in (
            [("A", "B", "task1"), ("B", "C", "task2")],
            [("A", ("B", "C"), "task1"), ("B", "D", "task2"), ("C", "E", "task3"), (("D", "E"), "F", "task4")],
            [("A", ("B", "C"), "task1"), ("C", "E", "task2"), ("B", "D", "task3"), (("D", "E"), "F", "task4")],
        ):
            with self.subTest(labels=[task[2] for task in tasks]):
                pipeline = _makePipeline(tasks)
                self.assertTrue(pipeTools.isPipelineOrdered(pipeline))
                # Reversing an ordered pipeline puts every producer after its
                # consumers, so it must not be reported as ordered.
                self.assertFalse(pipeTools.isPipelineOrdered(list(reversed(pipeline))))

    def testOrderPipeline(self):
        """Tests for pipeTools.orderPipeline method"""
        chainOrder = ["task1", "task2"]
        for tasks in (
            [("A", "B", "task1"), ("B", "C", "task2")],
            [("B", "C", "task2"), ("A", "B", "task1")],
        ):
            with self.subTest(labels=[task[2] for task in tasks]):
                pipeline = _makePipeline(tasks)
                self._assertLabels(pipeTools.orderPipeline(pipeline), chainOrder)
                # Whatever order _makePipeline returned, its reverse has the
                # consumer first, so orderPipeline must actually reorder it.
                self._assertLabels(pipeTools.orderPipeline(list(reversed(pipeline))), chainOrder)

        task1 = ("A", ("B", "C"), "task1")
        task2 = ("B", "D", "task2")
        task3 = ("C", "E", "task3")
        task4 = (("D", "E"), "F", "task4")
        diamondOrder = ["task1", "task2", "task3", "task4"]
        for tasks in (
            [task1, task2, task3, task4],
            [task1, task3, task2, task4],
            [task4, task2, task3, task1],
            [task4, task3, task2, task1],
        ):
            with self.subTest(labels=[task[2] for task in tasks]):
                self._assertLabels(pipeTools.orderPipeline(_makePipeline(tasks)), diamondOrder)

    def testOrderPipelineExceptions(self):
        """Test that data cycles raise PipelineDataCycleError while
        `_makePipeline` expands the pipeline.
        """
        # A task that consumes its own output forms a cycle.
        with self.assertRaises(pipeTools.PipelineDataCycleError):
            _makePipeline([("A", ("A", "B"), "task1")])

        # A longer cycle through several tasks: A -> B -> C -> D -> A.
        with self.assertRaises(pipeTools.PipelineDataCycleError):
            _makePipeline(
                [("A", "B", "task1"), ("B", "C", "task2"), ("C", "D", "task3"), ("D", "A", "task4")]
            )
class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Check that the tests in this module leave no leaked files behind."""
def setup_module(module):
    """Configure pytest.

    Module-level set-up hook for pytest's xunit-style fixtures.  It performs
    the same ``lsst.utils.tests.init()`` call as the ``__main__`` entry point
    at the bottom of this file.  The ``module`` argument is accepted for the
    hook signature but not used.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Allow running this file directly: initialise the LSST test utilities
    # (the same call setup_module makes under pytest), then run unittest.
    lsst.utils.tests.init()
    unittest.main()