Coverage for tests / test_mermaid.py: 23%
111 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-28 08:44 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-28 08:44 +0000
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <https://www.gnu.org/licenses/>.
28"""Simple unit test for Pipeline visualization using Mermaid."""
30import io
31import unittest
32from unittest.mock import MagicMock, patch
34import lsst.pipe.base.connectionTypes as cT
35import lsst.pipe.base.pipeline_graph.visualization as vis
36import lsst.utils.tests
37from lsst.pipe.base import Pipeline, PipelineTask, PipelineTaskConfig, PipelineTaskConnections
38from lsst.pipe.base.mermaid_tools import pipeline2mermaid
MERMAID_AVAILABLE = vis._mermaid.MERMAID_AVAILABLE

if MERMAID_AVAILABLE:
    # Canned SVG and PNG payloads handed back by the mocked responses.
    MOCKED_SVG_CONTENT = b"<svg>Mocked SVG Content</svg>"
    MOCKED_PNG_CONTENT = b"\x89PNG\r\n\x1a\nMocked PNG Content"

    # Keep a handle on the genuine initializer. Once `Mermaid.__init__` is
    # patched, looking it up through the class from inside the replacement
    # would resolve to the replacement itself and recurse forever.
    _originalMermaidInit = vis._mermaid.Mermaid.__init__

    def _mockMermaidInit(self, *args, **kwargs):
        """Initialize a ``Mermaid`` instance, then swap in mocked responses.

        Parameters
        ----------
        self : `object`
            The instance being initialized.
        *args
            Positional arguments forwarded to the original initializer.
        **kwargs
            Keyword arguments forwarded to the original initializer.
        """
        # Let the real initializer do its normal setup first.
        _originalMermaidInit(self, *args, **kwargs)

        # Only these two attributes are replaced; everything else is left as
        # the original initializer set it.
        for attribute, payload in (
            ("svg_response", MOCKED_SVG_CONTENT),
            ("img_response", MOCKED_PNG_CONTENT),
        ):
            setattr(self, attribute, MagicMock(content=payload))
class ExamplePipelineTaskConnections(PipelineTaskConnections, dimensions=("visit", "detector")):
    """Connections with two inputs, two prerequisite inputs and two outputs.

    Every connection is declared with an empty dataset type name; the
    constructor drops the input and output connections whose configured
    name is still empty.

    Parameters
    ----------
    config : `~lsst.pipe.base.PipelineTaskConfig`
        The config to use for this connections class.
    """

    input1 = cT.Input(
        name="",
        dimensions=["visit", "detector"],
        storageClass="ExposureF",
        doc="Input for this task",
    )
    input1Prerequisite = cT.PrerequisiteInput(
        name="",
        dimensions=["visit", "detector"],
        storageClass="ExposureF",
        doc="Prerequisite input for this task",
    )
    input2 = cT.Input(
        name="",
        dimensions=["visit", "detector"],
        storageClass="ExposureF",
        doc="Input for this task",
    )
    input2Prerequisite = cT.PrerequisiteInput(
        name="",
        dimensions=["visit", "detector"],
        storageClass="ExposureF",
        doc="Prerequisite input for this task",
    )
    output1 = cT.Output(
        name="",
        dimensions=["visit", "detector"],
        storageClass="ExposureF",
        doc="Output for this task",
    )
    output2 = cT.Output(
        name="",
        dimensions=["visit", "detector"],
        storageClass="ExposureF",
        doc="Output for this task",
    )

    def __init__(self, *, config=None):
        super().__init__(config=config)
        connectionsConfig = config.connections
        for connectionName in connectionsConfig:
            datasetTypeName = getattr(connectionsConfig, connectionName)

            # Avoid complaints about incorrect storage class for metadata.
            if datasetTypeName.endswith("_metadata"):
                getattr(self, connectionName).__dict__["storageClass"] = "TaskMetadata"

            # A connection that was given a name is in use; keep it.
            if datasetTypeName:
                continue

            # Unnamed inputs and outputs are unused, so drop them.
            if connectionName.startswith("input"):
                if connectionName in self.inputs:
                    self.inputs.remove(connectionName)
                elif connectionName in self.prerequisiteInputs:
                    self.prerequisiteInputs.remove(connectionName)
            elif connectionName.startswith("output") and connectionName in self.outputs:
                self.outputs.remove(connectionName)
class ExamplePipelineTaskConfig(PipelineTaskConfig, pipelineConnections=ExamplePipelineTaskConnections):
    """Test-only config bound to `ExamplePipelineTaskConnections`.

    It declares no fields of its own.
    """
131def _makeConfig(inputNames, outputNames, pipeline, label, inputTypes=None):
132 """Configure pipeline connections by adding config overrides.
134 Parameters
135 ----------
136 inputNames : `list` or `tuple` of `str`
137 A list or tuple containing up to two input names.
138 outputNames : `list` or `tuple` of `str`
139 A list or tuple containing up to two output names.
140 pipeline : `~lsst.pipe.base.Pipeline`
141 The pipeline object where configuration overrides will be applied.
142 label : `str`
143 The label associated with the configuration.
144 inputTypes : `list` or `tuple` of `str`, optional
145 A list or tuple specifying input types. Elements can be 'Prerequisite'
146 or empty strings (the default, indicating a regular input).
147 """
148 # Ensure at least two elements, default to empty strings.
149 inputTypes = (inputTypes or ["", ""])[:2]
151 for i in (0, 1):
152 inputName = inputNames[i] if i < len(inputNames) else ""
153 isPrerequisite = inputTypes[i] == "Prerequisite" if i < len(inputNames) else False
155 pipeline.addConfigOverride(label, f"connections.input{i + 1}", "" if isPrerequisite else inputName)
156 pipeline.addConfigOverride(
157 label, f"connections.input{i + 1}Prerequisite", inputName if isPrerequisite else ""
158 )
160 outputName = outputNames[i] if i < len(outputNames) else ""
161 pipeline.addConfigOverride(label, f"connections.output{i + 1}", outputName)
class ExamplePipelineTask(PipelineTask):
    """Test-only pipeline task configured by `ExamplePipelineTaskConfig`."""

    _DefaultName = "examplePipelineTask"
    ConfigClass = ExamplePipelineTaskConfig
def _makePipeline(tasks, inputTypeOverrides=None):
    """Generate a Pipeline instance and return it along with task definitions.

    Parameters
    ----------
    tasks : `list` of `tuple`
        Each tuple in the list has 3 or 4 items:

        - input DatasetType name(s), string or tuple of strings
        - output DatasetType name(s), string or tuple of strings
        - task label, string
        - optional task class object, can be None (`ExamplePipelineTask` is
          used when it is missing or None)
    inputTypeOverrides : `dict`, optional
        Dictionary of input type overrides, keyed by input DatasetType name.
        Names that are absent from it are treated as regular inputs.

    Returns
    -------
    pipe : `~lsst.pipe.base.Pipeline`
        The pipeline instance.
    task_defs : `list` of `~lsst.pipe.base.TaskDef`
        List of task definitions.
    """
    pipe = Pipeline("test pipeline")
    for task in tasks:
        inputs, outputs, label = task[0], task[1], task[2]

        # Fall back to the example task both when the fourth item is missing
        # and when it is explicitly None, as the parameter docs promise.
        klass = task[3] if len(task) > 3 and task[3] is not None else ExamplePipelineTask
        pipe.addTask(klass, label)

        # Accept a bare name as shorthand for a one-element tuple.
        if not isinstance(inputs, tuple | list):
            inputs = (inputs,)
        if not isinstance(outputs, tuple | list):
            outputs = (outputs,)

        # Always bind ``inputTypes``: it used to be assigned only when
        # overrides were given, leaving the name unbound for the default call.
        inputTypes = [inputTypeOverrides.get(x) for x in inputs] if inputTypeOverrides else None
        _makeConfig(inputs, outputs, pipe, label, inputTypes)

    taskDefs = list(pipe.to_graph()._iter_task_defs())
    return pipe, taskDefs
class MermaidTestCase(unittest.TestCase):
    """Check Mermaid output produced for a small example pipeline."""

    def setUp(self):
        # Each entry is (input name(s), output name(s), task label).
        taskSpecs = [
            ("A", ("B", "C"), "task0"),
            ("B", "E", "task1"),
            ("B", ("D", "F"), "task2"),
            ("D", "G", "task3"),
            (("C", "F"), "H", "task4"),
            ("task4_metadata", "I", "task5"),
        ]
        self.pipeline, self.pipelineTaskDefs = _makePipeline(
            taskSpecs,
            inputTypeOverrides={"A": "Prerequisite"},
        )

        # Graph form of the same pipeline, used by the `show_mermaid` tests.
        self.graph = self.pipeline.to_graph(visualization_only=True)

    def validateMermaidSource(self, file):
        """Run basic sanity checks on generated Mermaid source.

        Parameters
        ----------
        file : `~io.StringIO`
            The in-memory file-like object containing the Mermaid source.
        """
        # Validating the complete output is impractical, so only a handful
        # of properties are checked, and even those are not terribly stable.
        source = file.getvalue()

        expectedLineCounts = {
            "classDefs": 2,
            "tasks": 6,
            "taskClasses": 6,
            "datasets": 10,
            "datasetClasses": 10,
            "edges": 16,
            "linkStyles": 2,  # For the default and pre-requisite edges
            "extra": 1,  # 'flowchart' opening line
        }
        self.assertEqual(len(source.strip().split("\n")), sum(expectedLineCounts.values()))

        # The source must open with a top-down flowchart declaration.
        self.assertTrue(source.startswith("flowchart TD"))

        # In order: two task/dataset edges, a dataset class assignment, and
        # the edge created for task metadata.
        for fragment in (
            "B:0 --> task1:2",
            "task2:2 --> F:0",
            "class E:0 dsType;",
            "task4:2 --> task4_metadata:0",
        ):
            self.assertIn(fragment, source)

    def test_pipeline2mermaid(self):
        """Validate Mermaid syntax generated by `pipeline2mermaid`."""
        # Both the pipeline and its task definitions are accepted as input.
        for pipelineLike in (self.pipeline, self.pipelineTaskDefs):
            buffer = io.StringIO()
            pipeline2mermaid(pipelineLike, buffer, expand_dimensions=False)
            self.validateMermaidSource(buffer)

    def test_show_mermaid_source(self):
        """Test generating Mermaid source using `show_mermaid`."""
        # With no output format given, Mermaid source (mmd) is written to
        # the text stream.
        buffer = io.StringIO()
        vis.show_mermaid(self.graph, buffer, dataset_types=True, task_classes="full")
        self.validateMermaidSource(buffer)

    @unittest.skipUnless(MERMAID_AVAILABLE, "Skipping image rendering tests since `mermaid-py` is missing.")
    def test_show_mermaid_image(self):
        """Test the output of the `show_mermaid` method in png and svg formats.

        For generating image outputs, the `show_mermaid` method invokes a
        remote rendering service (`mermaid.ink`), which can lead to timeouts
        at times. To bypass this instability, Mermaid's `__init__` method is
        patched so that its `svg_response` and `img_response` attributes hold
        predictable, mocked contents instead.
        """
        expectedByFormat = {"svg": MOCKED_SVG_CONTENT, "png": MOCKED_PNG_CONTENT}
        with patch.object(vis._mermaid.Mermaid, "__init__", new=_mockMermaidInit):
            for outputFormat, expectedBytes in expectedByFormat.items():
                buffer = io.BytesIO()
                vis.show_mermaid(
                    self.graph,
                    buffer,
                    output_format=outputFormat,
                    dataset_types=True,
                    task_classes="full",
                )
                # `getvalue` returns the whole buffer regardless of the
                # current stream position.
                self.assertEqual(buffer.getvalue(), expectedBytes)
class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Run the generic file handle leak check from `lsst.utils.tests`."""
def setup_module(module):
    """Module-level setup hook for pytest.

    Calls `lsst.utils.tests.init`; the ``module`` argument itself is not
    used.

    Parameters
    ----------
    module : `~types.ModuleType`
        Module to set up.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Direct execution: perform the same initialization `setup_module` does
    # under pytest, then hand control to the unittest runner.
    lsst.utils.tests.init()
    unittest.main()