Coverage for tests/test_pipeTools.py: 22%
98 statements
« prev ^ index » next coverage.py v7.4.3, created at 2024-02-28 11:05 +0000
« prev ^ index » next coverage.py v7.4.3, created at 2024-02-28 11:05 +0000
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
28"""Simple unit tests for the pipeTools pipeline-ordering helpers.
29"""
31import unittest
33import lsst.pipe.base.connectionTypes as cT
34import lsst.utils.tests
35from lsst.pipe.base import Pipeline, PipelineTask, PipelineTaskConfig, PipelineTaskConnections, pipeTools
class ExamplePipelineTaskConnections(PipelineTaskConnections, dimensions=["Visit", "Detector"]):
    """Connections used by the example task in these tests.

    Two inputs and two outputs are declared, all with ``Visit``/``Detector``
    dimensions, storage class ``"example"`` and an empty default name. The
    second input and the second output are optional: they are dropped from
    the instance when the config leaves their name empty.
    """

    input1 = cT.Input(
        name="", dimensions=["Visit", "Detector"], storageClass="example", doc="Input for this task"
    )
    input2 = cT.Input(
        name="", dimensions=["Visit", "Detector"], storageClass="example", doc="Input for this task"
    )
    output1 = cT.Output(
        name="", dimensions=["Visit", "Detector"], storageClass="example", doc="Output for this task"
    )
    output2 = cT.Output(
        name="", dimensions=["Visit", "Detector"], storageClass="example", doc="Output for this task"
    )

    def __init__(self, *, config=None):
        super().__init__(config=config)
        # Each optional connection is paired with the attribute holding the
        # name collection it has to be removed from when its configured
        # name is falsy (e.g. the empty-string default).
        for optionalName, collectionAttr in (("input2", "inputs"), ("output2", "outputs")):
            if not getattr(config.connections, optionalName):
                getattr(self, collectionAttr).remove(optionalName)
class ExamplePipelineTaskConfig(PipelineTaskConfig, pipelineConnections=ExamplePipelineTaskConnections):
    """Config for the example.

    Declares no fields of its own; it only binds
    `ExamplePipelineTaskConnections` as the connections class.
    """
def _makeConfig(inputName, outputName, pipeline, label):
    """Register connection-name overrides for one task on a pipeline.

    Parameters
    ----------
    inputName : `str` or `tuple` of `str`
        A plain value is used as the name of ``connections.input1``. For a
        tuple, the first item names ``connections.input1`` and the second
        item names ``connections.input2``; an empty string is used for the
        latter when the tuple has a single item. Items past the second are
        ignored.
    outputName : `str` or `tuple` of `str`
        Same as ``inputName`` but for ``connections.output1`` and
        ``connections.output2``.
    pipeline
        Object whose ``addConfigOverride(label, key, value)`` method
        receives the overrides.
    label : `str`
        Task label forwarded to every ``addConfigOverride`` call.

    Notes
    -----
    Nothing is returned. When a non-tuple name is given, the corresponding
    second connection is not overridden at all. An empty tuple raises
    `IndexError`.
    """
    overrideKeys = (
        (inputName, "connections.input1", "connections.input2"),
        (outputName, "connections.output1", "connections.output2"),
    )
    for names, primaryKey, secondaryKey in overrideKeys:
        if not isinstance(names, tuple):
            pipeline.addConfigOverride(label, primaryKey, names)
            continue
        primary = names[0]
        pipeline.addConfigOverride(label, primaryKey, primary)
        secondary = names[1] if len(names) > 1 else ""
        pipeline.addConfigOverride(label, secondaryKey, secondary)
class ExamplePipelineTask(PipelineTask):
    """Example pipeline task used for testing.

    Defines no methods of its own; it only binds the example config class
    and a default task name.
    """

    # Config class associated with this task.
    ConfigClass = ExamplePipelineTaskConfig
    # Default name for this task.
    _DefaultName = "examplePipelineTask"
def _makePipeline(tasks):
    """Build a test pipeline and return its expanded contents as a list.

    Parameters
    ----------
    tasks : `list` of `tuple`
        Each tuple in the list has 3 or 4 items:

        - input DatasetType name(s), string or tuple of strings
        - output DatasetType name(s), string or tuple of strings
        - task label, string
        - optional task class object; when omitted or `None`,
          `ExamplePipelineTask` is used.

    Returns
    -------
    expanded : `list`
        The items produced by ``Pipeline.toExpandedPipeline()`` for the
        assembled pipeline, collected into a list.

    Notes
    -----
    Any exception raised while adding tasks, applying overrides or
    expanding the pipeline propagates to the caller.
    """
    pipe = Pipeline("test pipeline")
    for task in tasks:
        inputs, outputs, label = task[0], task[1], task[2]
        # The docstring allows the fourth item to be None, so treat an
        # explicit None exactly like a missing item instead of passing it
        # on to addTask.
        klass = task[3] if len(task) > 3 else None
        if klass is None:
            klass = ExamplePipelineTask
        pipe.addTask(klass, label)
        _makeConfig(inputs, outputs, pipe, label)
    return list(pipe.toExpandedPipeline())
class PipelineToolsTestCase(unittest.TestCase):
    """A test case for pipelineTools.

    Every pipeline is described as a list of ``(inputs, outputs, label)``
    tuples and built with ``_makePipeline``.
    """

    def testIsOrdered(self):
        """Tests for pipeTools.isPipelineOrdered method."""
        cases = (
            # Simple two-task chain.
            [("A", "B", "task1"), ("B", "C", "task2")],
            # Diamond: task1 fans out to two tasks that feed the last one.
            [
                ("A", ("B", "C"), "task1"),
                ("B", "D", "task2"),
                ("C", "E", "task3"),
                (("D", "E"), "F", "task4"),
            ],
            # Same diamond with the two middle branches swapped.
            [
                ("A", ("B", "C"), "task1"),
                ("C", "E", "task2"),
                ("B", "D", "task3"),
                (("D", "E"), "F", "task4"),
            ],
        )
        for tasks in cases:
            with self.subTest(tasks=tasks):
                pipeline = _makePipeline(tasks)
                self.assertTrue(pipeTools.isPipelineOrdered(pipeline))

    def testOrderPipeline(self):
        """Tests for pipeTools.orderPipeline method."""
        twoTasks = ["task1", "task2"]
        fourTasks = ["task1", "task2", "task3", "task4"]
        # Each entry pairs a task description (in some declaration order)
        # with the labels expected after ordering.
        cases = (
            ([("A", "B", "task1"), ("B", "C", "task2")], twoTasks),
            ([("B", "C", "task2"), ("A", "B", "task1")], twoTasks),
            (
                [
                    ("A", ("B", "C"), "task1"),
                    ("B", "D", "task2"),
                    ("C", "E", "task3"),
                    (("D", "E"), "F", "task4"),
                ],
                fourTasks,
            ),
            (
                [
                    ("A", ("B", "C"), "task1"),
                    ("C", "E", "task3"),
                    ("B", "D", "task2"),
                    (("D", "E"), "F", "task4"),
                ],
                fourTasks,
            ),
            (
                [
                    (("D", "E"), "F", "task4"),
                    ("B", "D", "task2"),
                    ("C", "E", "task3"),
                    ("A", ("B", "C"), "task1"),
                ],
                fourTasks,
            ),
            (
                [
                    (("D", "E"), "F", "task4"),
                    ("C", "E", "task3"),
                    ("B", "D", "task2"),
                    ("A", ("B", "C"), "task1"),
                ],
                fourTasks,
            ),
        )
        for tasks, expectedLabels in cases:
            with self.subTest(tasks=tasks):
                pipeline = pipeTools.orderPipeline(_makePipeline(tasks))
                # Comparing the full label list checks both the number of
                # tasks and the position of each one.
                self.assertEqual([taskDef.label for taskDef in pipeline], expectedLabels)

    def testOrderPipelineExceptions(self):
        """Tests for pipeline ordering exceptions.

        ``pipeTools.orderPipeline`` is not called directly here: the error
        is expected to be raised from within ``_makePipeline``.
        """
        # A task that lists one of its own inputs as an output forms a
        # cycle and should raise PipelineDataCycleError.
        with self.assertRaises(pipeTools.PipelineDataCycleError):
            _makePipeline([("A", ("A", "B"), "task1")])

        # Another kind of cycle: a chain of four tasks whose last output
        # is the first task's input.
        with self.assertRaises(pipeTools.PipelineDataCycleError):
            _makePipeline(
                [("A", "B", "task1"), ("B", "C", "task2"), ("C", "D", "task3"), ("D", "A", "task4")]
            )
class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Run file leak tests.

    Adds nothing of its own; all checks are inherited from
    `lsst.utils.tests.MemoryTestCase`.
    """
def setup_module(module):
    """Configure pytest.

    Parameters
    ----------
    module
        Not used by this function.
    """
    lsst.utils.tests.init()
# Direct execution: perform the same initialization as setup_module, then
# hand over to the unittest runner.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()