Coverage for tests/test_dotTools.py: 25%
73 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-08-06 02:30 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-08-06 02:30 +0000
1# This file is part of ctrl_mpexec.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22"""Simple unit test for Pipeline.
23"""
25import io
26import re
27import unittest
29import lsst.pipe.base.connectionTypes as cT
30import lsst.utils.tests
31from lsst.ctrl.mpexec.dotTools import pipeline2dot
32from lsst.pipe.base import Pipeline, PipelineTask, PipelineTaskConfig, PipelineTaskConnections
class ExamplePipelineTaskConnections(PipelineTaskConnections, dimensions=()):
    """Connections class used for testing.

    Declares two inputs and two outputs whose dataset type names all default
    to the empty string. The second input and the second output are removed
    when the instance is constructed unless a config override gave them a
    non-empty name.
    """

    input1 = cT.Input(
        name="", dimensions=["visit", "detector"], storageClass="example", doc="Input for this task"
    )
    input2 = cT.Input(
        name="", dimensions=["visit", "detector"], storageClass="example", doc="Input for this task"
    )
    output1 = cT.Output(
        name="", dimensions=["visit", "detector"], storageClass="example", doc="Output for this task"
    )
    output2 = cT.Output(
        name="", dimensions=["visit", "detector"], storageClass="example", doc="Output for this task"
    )

    def __init__(self, *, config=None):
        super().__init__(config=config)
        # Optional connections left with an empty dataset type name are
        # dropped from the corresponding connection-name collection.
        optionalConnections = (("input2", self.inputs), ("output2", self.outputs))
        for connectionName, container in optionalConnections:
            if not getattr(config.connections, connectionName):
                container.remove(connectionName)
class ExamplePipelineTaskConfig(PipelineTaskConfig, pipelineConnections=ExamplePipelineTaskConnections):
    """Config for the test task.

    It is bound to `ExamplePipelineTaskConnections`, so its
    ``connections.*`` fields hold the dataset type names that tests override.
    """
63def _makeConfig(inputName, outputName, pipeline, label):
64 """Add config overrides.
66 Factory method for config instances.
68 inputName and outputName can be either string or tuple of strings
69 with two items max.
70 """
71 if isinstance(inputName, tuple):
72 pipeline.addConfigOverride(label, "connections.input1", inputName[0])
73 pipeline.addConfigOverride(label, "connections.input2", inputName[1] if len(inputName) > 1 else "")
74 else:
75 pipeline.addConfigOverride(label, "connections.input1", inputName)
77 if isinstance(outputName, tuple):
78 pipeline.addConfigOverride(label, "connections.output1", outputName[0])
79 pipeline.addConfigOverride(label, "connections.output2", outputName[1] if len(outputName) > 1 else "")
80 else:
81 pipeline.addConfigOverride(label, "connections.output1", outputName)
class ExamplePipelineTask(PipelineTask):
    """Minimal pipeline task used for testing.

    It only binds a config class; it is used to build test pipelines.
    """

    ConfigClass = ExamplePipelineTaskConfig
def _makePipeline(tasks):
    """Generate a test pipeline and return its expanded task definitions.

    Parameters
    ----------
    tasks : `list` [`tuple`]
        Each tuple in the list has 3 or 4 items:

        - input DatasetType name(s), string or tuple of strings
        - output DatasetType name(s), string or tuple of strings
        - task label, string
        - optional task class object; when omitted or `None`,
          `ExamplePipelineTask` is used

    Returns
    -------
    taskDefs : `list`
        The items produced by ``Pipeline.toExpandedPipeline()``, collected
        into a list (suitable for passing to `pipeline2dot`).
    """
    pipe = Pipeline("test pipeline")
    for inputs, outputs, label, *rest in tasks:
        # The docstring allows an explicit None for the task class; fall
        # back to the default instead of passing None to addTask.
        klass = rest[0] if rest and rest[0] is not None else ExamplePipelineTask
        pipe.addTask(klass, label)
        _makeConfig(inputs, outputs, pipe, label)
    return list(pipe.toExpandedPipeline())
class DotToolsTestCase(unittest.TestCase):
    """Tests for the dotTools module."""

    def testPipeline2dot(self):
        """Check the GraphViz text produced by `pipeline2dot`."""
        taskSpecs = [
            ("A", ("B", "C"), "task0"),
            ("C", "E", "task1"),
            ("B", "D", "task2"),
            (("D", "E"), "F", "task3"),
            ("D.C", "G", "task4"),
            ("task3_metadata", "H", "task5"),
        ]
        pipeline = _makePipeline(taskSpecs)
        buffer = io.StringIO()
        pipeline2dot(pipeline, buffer)
        dot = buffer.getvalue()

        # Validating the complete output is impractical and fragile, so only
        # the total line count and a few basic properties are checked.
        expectedLineCount = sum(
            (
                3,  # global statements
                10,  # dataset nodes
                6,  # task nodes
                16,  # edges
                2,  # graph header and closing
            )
        )
        dotLines = dot.strip().split("\n")
        self.assertEqual(len(dotLines), expectedLineCount)

        # Every node name, in node statements and on both ends of an edge,
        # must be double-quoted; the graph/node/edge keywords are exempt.
        nodePattern = re.compile(r"^([^ ]+) \[.+\];$")
        edgePattern = re.compile(r"^([^ ]+) *-> *([^ ]+);$")
        keywords = ("graph", "node", "edge")

        def assertQuoted(name):
            self.assertEqual(name[0] + name[-1], '""')

        for line in dotLines:
            nodeMatch = nodePattern.match(line)
            if nodeMatch is not None:
                name = nodeMatch.group(1)
                if name not in keywords:
                    assertQuoted(name)
                continue
            edgeMatch = edgePattern.match(line)
            if edgeMatch is not None:
                for name in edgeMatch.groups():
                    assertQuoted(name)

        # The component dataset "D.C" must be connected to its parent "D".
        self.assertIn('"D" -> "D.C"', dot)

        # Reading a task's metadata dataset as an input must produce an edge
        # from the task that writes it.
        self.assertIn('"task3" -> "task3_metadata"', dot)
class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Check that the tests in this module leave no file handles open."""
def setup_module(module):
    """Initialize LSST test utilities before pytest runs this module.

    ``module`` is passed in by pytest and is not used here.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Direct execution (outside pytest): initialize the LSST test utilities
    # first, then delegate to the standard unittest runner.
    lsst.utils.tests.init()
    unittest.main()