Coverage for tests / test_mermaid.py: 23%

111 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-17 08:59 +0000

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <https://www.gnu.org/licenses/>. 

27 

28"""Simple unit test for Pipeline visualization using Mermaid.""" 

29 

30import io 

31import unittest 

32from unittest.mock import MagicMock, patch 

33 

34import lsst.pipe.base.connectionTypes as cT 

35import lsst.pipe.base.pipeline_graph.visualization as vis 

36import lsst.utils.tests 

37from lsst.pipe.base import Pipeline, PipelineTask, PipelineTaskConfig, PipelineTaskConnections 

38from lsst.pipe.base.mermaid_tools import pipeline2mermaid 

39 

# Availability flag re-exported from the visualization package; the image
# rendering test below is skipped when it is false.
MERMAID_AVAILABLE = vis._mermaid.MERMAID_AVAILABLE

if MERMAID_AVAILABLE:
    # Canned payloads that stand in for the rendered SVG and PNG images.
    MOCKED_SVG_CONTENT = b"<svg>Mocked SVG Content</svg>"
    MOCKED_PNG_CONTENT = b"\x89PNG\r\n\x1a\nMocked PNG Content"

    # Capture the real initializer up front. Once `Mermaid.__init__` is
    # patched, looking it up again from inside the replacement would find the
    # replacement itself and recurse forever.
    _originalMermaidInit = vis._mermaid.Mermaid.__init__

    def _mockMermaidInit(self, *args, **kwargs):
        """Run the real Mermaid initializer, then install mocked responses.

        Parameters
        ----------
        *args
            Positional arguments forwarded to the original initializer.
        **kwargs
            Keyword arguments forwarded to the original initializer.
        """
        _originalMermaidInit(self, *args, **kwargs)

        # Only the two response attributes are replaced; everything else the
        # original initializer set up is left untouched.
        self.svg_response = MagicMock(content=MOCKED_SVG_CONTENT)
        self.img_response = MagicMock(content=MOCKED_PNG_CONTENT)

59 

60 

class ExamplePipelineTaskConnections(PipelineTaskConnections, dimensions=("visit", "detector")):
    """Connections with two inputs, two prerequisite inputs and two outputs.

    Every connection is declared with an empty dataset type name; the
    initializer drops those whose configured name is still empty.

    Parameters
    ----------
    config : `~lsst.pipe.base.PipelineTaskConfig`
        The config to use for this connections class.
    """

    input1 = cT.Input(
        name="",
        dimensions=["visit", "detector"],
        storageClass="ExposureF",
        doc="Input for this task",
    )
    input1Prerequisite = cT.PrerequisiteInput(
        name="",
        dimensions=["visit", "detector"],
        storageClass="ExposureF",
        doc="Prerequisite input for this task",
    )
    input2 = cT.Input(
        name="",
        dimensions=["visit", "detector"],
        storageClass="ExposureF",
        doc="Input for this task",
    )
    input2Prerequisite = cT.PrerequisiteInput(
        name="",
        dimensions=["visit", "detector"],
        storageClass="ExposureF",
        doc="Prerequisite input for this task",
    )
    output1 = cT.Output(
        name="",
        dimensions=["visit", "detector"],
        storageClass="ExposureF",
        doc="Output for this task",
    )
    output2 = cT.Output(
        name="",
        dimensions=["visit", "detector"],
        storageClass="ExposureF",
        doc="Output for this task",
    )

    def __init__(self, *, config=None):
        super().__init__(config=config)
        for fieldName in config.connections:
            datasetTypeName = getattr(config.connections, fieldName)

            # A connection pointing at a "*_metadata" dataset needs the
            # metadata storage class, otherwise there are complaints about an
            # incorrect storage class.
            if datasetTypeName.endswith("_metadata"):
                getattr(self, fieldName).__dict__["storageClass"] = "TaskMetadata"

            # Connections with a configured name are in use; keep them.
            if datasetTypeName:
                continue

            # Unused connection: remove it from whichever set holds it.
            if fieldName.startswith("input"):
                if fieldName in self.inputs:
                    self.inputs.remove(fieldName)
                elif fieldName in self.prerequisiteInputs:
                    self.prerequisiteInputs.remove(fieldName)
            elif fieldName.startswith("output"):
                if fieldName in self.outputs:
                    self.outputs.remove(fieldName)

125 

126 

class ExamplePipelineTaskConfig(PipelineTaskConfig, pipelineConnections=ExamplePipelineTaskConnections):
    """Config for the example task, bound to the example connections class."""

129 

130 

131def _makeConfig(inputNames, outputNames, pipeline, label, inputTypes=None): 

132 """Configure pipeline connections by adding config overrides. 

133 

134 Parameters 

135 ---------- 

136 inputNames : `list` or `tuple` of `str` 

137 A list or tuple containing up to two input names. 

138 outputNames : `list` or `tuple` of `str` 

139 A list or tuple containing up to two output names. 

140 pipeline : `~lsst.pipe.base.Pipeline` 

141 The pipeline object where configuration overrides will be applied. 

142 label : `str` 

143 The label associated with the configuration. 

144 inputTypes : `list` or `tuple` of `str`, optional 

145 A list or tuple specifying input types. Elements can be 'Prerequisite' 

146 or empty strings (the default, indicating a regular input). 

147 """ 

148 # Ensure at least two elements, default to empty strings. 

149 inputTypes = (inputTypes or ["", ""])[:2] 

150 

151 for i in (0, 1): 

152 inputName = inputNames[i] if i < len(inputNames) else "" 

153 isPrerequisite = inputTypes[i] == "Prerequisite" if i < len(inputNames) else False 

154 

155 pipeline.addConfigOverride(label, f"connections.input{i + 1}", "" if isPrerequisite else inputName) 

156 pipeline.addConfigOverride( 

157 label, f"connections.input{i + 1}Prerequisite", inputName if isPrerequisite else "" 

158 ) 

159 

160 outputName = outputNames[i] if i < len(outputNames) else "" 

161 pipeline.addConfigOverride(label, f"connections.output{i + 1}", outputName) 

162 

163 

class ExamplePipelineTask(PipelineTask):
    """Minimal pipeline task used by the tests in this module."""

    ConfigClass = ExamplePipelineTaskConfig
    _DefaultName = "examplePipelineTask"

169 

170 

def _makePipeline(tasks, inputTypeOverrides=None):
    """Generate a Pipeline instance and return it along with task definitions.

    Parameters
    ----------
    tasks : `list` of `tuple`
        Each tuple in the list has 3 or 4 items:
        - input DatasetType name(s), string or tuple of strings
        - output DatasetType name(s), string or tuple of strings
        - task label, string
        - optional task class object, can be None
    inputTypeOverrides : `dict`, optional
        Dictionary of input type overrides, keyed by input DatasetType name.
        Inputs without an entry are treated as regular inputs.

    Returns
    -------
    pipe : `~lsst.pipe.base.Pipeline`
        The pipeline instance.
    task_defs : `list` of `~lsst.pipe.base.TaskDef`
        List of task definitions.
    """
    pipe = Pipeline("test pipeline")
    for task in tasks:
        inputs = task[0]
        outputs = task[1]
        label = task[2]
        # The task class is optional: it may be omitted or given as None, and
        # in both cases the example task is used.
        klass = task[3] if len(task) > 3 and task[3] is not None else ExamplePipelineTask
        pipe.addTask(klass, label)
        if not isinstance(inputs, tuple | list):
            inputs = (inputs,)
        if not isinstance(outputs, tuple | list):
            outputs = (outputs,)
        # Always bind `inputTypes`, so that calling without overrides does
        # not pass an unassigned local to `_makeConfig`.
        inputTypes = [inputTypeOverrides.get(x) for x in inputs] if inputTypeOverrides else None
        _makeConfig(inputs, outputs, pipe, label, inputTypes)
    taskDefs = list(pipe.to_graph()._iter_task_defs())
    return pipe, taskDefs

208 

209 

class MermaidTestCase(unittest.TestCase):
    """Tests of the Mermaid output produced for an example pipeline."""

    def setUp(self):
        # Six example tasks; dataset type "A" is marked as a prerequisite
        # input and task5 takes the metadata dataset of task4 as its input.
        taskSpecs = [
            ("A", ("B", "C"), "task0"),
            ("B", "E", "task1"),
            ("B", ("D", "F"), "task2"),
            ("D", "G", "task3"),
            (("C", "F"), "H", "task4"),
            ("task4_metadata", "I", "task5"),
        ]
        self.pipeline, self.pipelineTaskDefs = _makePipeline(
            taskSpecs,
            inputTypeOverrides={"A": "Prerequisite"},
        )

        # Graph form of the same pipeline, used by the `show_mermaid` tests.
        self.graph = self.pipeline.to_graph(visualization_only=True)

    def validateMermaidSource(self, file):
        """Check a few basic properties of generated Mermaid source.

        Parameters
        ----------
        file : `~io.StringIO`
            The in-memory file-like object containing the Mermaid source.
        """
        # Validating the complete output is hard, so only some basic things
        # are checked, and even those are not terribly stable.
        source = file.getvalue()

        # Expected number of output lines, broken down by kind.
        expectedLineCounts = {
            "classDefs": 2,
            "tasks": 6,
            "taskClasses": 6,
            "datasets": 10,
            "datasetClasses": 10,
            "edges": 16,
            "linkStyles": 2,  # For the default and pre-requisite edges
            "extra": 1,  # 'flowchart' opening line
        }
        self.assertEqual(len(source.strip().split("\n")), sum(expectedLineCounts.values()))

        # The source must open with a top-down flowchart declaration.
        self.assertTrue(source.startswith("flowchart TD"))

        # Fragments that must appear: two task/dataset edges, a dataset class
        # assignment, and the edge to the metadata dataset.
        expectedFragments = (
            "B:0 --> task1:2",
            "task2:2 --> F:0",
            "class E:0 dsType;",
            "task4:2 --> task4_metadata:0",
        )
        for fragment in expectedFragments:
            self.assertIn(fragment, source)

    def test_pipeline2mermaid(self):
        """Validate Mermaid syntax generated by `pipeline2mermaid`."""
        # Both the pipeline itself and its task definitions are tried as the
        # source.
        for source in (self.pipeline, self.pipelineTaskDefs):
            stream = io.StringIO()
            pipeline2mermaid(source, stream, expand_dimensions=False)
            self.validateMermaidSource(stream)

    def test_show_mermaid_source(self):
        """Test generating Mermaid source using `show_mermaid`."""
        # Mermaid source (mmd format) is written to a text stream.
        stream = io.StringIO()
        vis.show_mermaid(self.graph, stream, dataset_types=True, task_classes="full")
        self.validateMermaidSource(stream)

    @unittest.skipUnless(MERMAID_AVAILABLE, "Skipping image rendering tests since `mermaid-py` is missing.")
    def test_show_mermaid_image(self):
        """Test the output of the `show_mermaid` method in png and svg formats.

        For generating image outputs, the `show_mermaid` method invokes a
        remote rendering service (`mermaid.ink`), which can lead to timeouts at
        times. To bypass this instability, Mermaid's `__init__` method is
        patched to override its `svg_response` and `img_response` attributes,
        so that predictable, mocked contents are returned instead.
        """
        expectedByFormat = [("svg", MOCKED_SVG_CONTENT), ("png", MOCKED_PNG_CONTENT)]

        # Image formats are written to binary streams with mocked responses.
        with patch.object(vis._mermaid.Mermaid, "__init__", new=_mockMermaidInit):
            for fmt, expected in expectedByFormat:
                stream = io.BytesIO()
                vis.show_mermaid(
                    self.graph, stream, output_format=fmt, dataset_types=True, task_classes="full"
                )
                # Compare the full buffer contents with the mocked payload.
                self.assertEqual(stream.getvalue(), expected)

308 

309 

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Check for leaked file handles using the generic LSST test case."""

312 

313 

def setup_module(module):
    """Initialize the LSST test utilities when run under pytest.

    Parameters
    ----------
    module : `~types.ModuleType`
        Module to set up.
    """
    lsst.utils.tests.init()

323 

324 

if __name__ == "__main__":
    # Direct execution: initialize the LSST test utilities, then hand over to
    # the unittest runner.
    lsst.utils.tests.init()
    unittest.main()