Coverage for python / lsst / pipe / base / pipeline_graph / visualization / _mermaid.py: 15%

147 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-26 08:59 +0000

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27from __future__ import annotations 

28 

29__all__ = ("show_mermaid",) 

30 

31import html 

32import os 

33import sys 

34from collections.abc import Mapping 

35from io import StringIO 

36from typing import IO, Any 

37 

38from .._nodes import NodeType 

39from .._pipeline_graph import PipelineGraph 

40from ._formatting import NodeKey, format_dimensions, format_task_class 

41from ._options import NodeAttributeOptions 

42from ._show import parse_display_args 

43 

44try: 

45 from mermaid import Mermaid # type: ignore 

46 from mermaid.graph import Graph # type: ignore 

47 

48 MERMAID_AVAILABLE = True 

49except ImportError: 

50 MERMAID_AVAILABLE = False 

51 

# Configuration constants for label formatting and overflow handling.

# Font size, in pixels, interpolated into every ``classDef`` style emitted for
# node labels.
_LABEL_PX_SIZE = 18
# Line limit passed as ``max_lines`` to ``_format_label`` for dataset-type
# nodes; label lines beyond it are returned separately as "extras".
_LABEL_MAX_LINES_SOFT = 10
# Maximum number of label lines (main lines plus extras) a dataset-type node
# may show; any further extras are moved into separate overflow nodes.
_LABEL_MAX_LINES_HARD = 15
# Maximum number of moved-out extra lines placed in each overflow node.
_OVERFLOW_MAX_LINES = 20

57 

58 

def show_mermaid(
    pipeline_graph: PipelineGraph,
    stream: IO[Any] = sys.stdout,
    output_format: str = "mmd",
    width: int | None = None,
    height: int | None = None,
    scale: float | None = None,
    **kwargs: Any,
) -> None:
    """Write a Mermaid flowchart representation of the pipeline graph to a
    stream.

    This function converts a given `PipelineGraph` into a Mermaid-based
    flowchart. Nodes represent tasks (and possibly task-init nodes) and dataset
    types, and edges represent connections between them. Dimensions and storage
    classes can be included as additional metadata on nodes. Prerequisite edges
    are rendered as dashed lines.

    Parameters
    ----------
    pipeline_graph : `PipelineGraph`
        The pipeline graph to visualize.
    stream : `typing.IO`, optional
        The output stream where the result is written. Defaults to the
        `sys.stdout` object that existed when this module was imported. It
        must be a text stream for 'mmd' output and a binary stream for 'svg'
        and 'png' output.
    output_format : str, optional
        Defines the output format, compared case-insensitively. 'mmd'
        (default) generates a Mermaid definition text file, while 'svg' and
        'png' produce rendered images written to a binary stream.
    width : int, optional
        The width of the rendered image in pixels.
    height : int, optional
        The height of the rendered image in pixels.
    scale : float, optional
        The scale factor for the rendered image. Must be a float between 1
        and 3, and one of height or width must be provided.
    **kwargs : Any
        Additional arguments passed to `parse_display_args` to control aspects
        such as displaying dimensions, storage classes, or full task class
        names.

    Raises
    ------
    ValueError
        If ``output_format`` is not one of 'mmd', 'svg' or 'png'.
    ImportError
        If an image format is requested but `mermaid-py` is not installed.
    RuntimeError
        If rendering an image fails.

    Notes
    -----
    - The diagram uses a top-down layout (`flowchart TD`).
    - Three Mermaid classes are defined:
        - `task` for normal tasks,
        - `dsType` for dataset-type nodes,
        - `taskInit` for task-init nodes.
    - Edges that represent prerequisite relationships are rendered as dashed
      lines using `linkStyle`.
    - If a node's label is too long, overflow nodes are created to hold extra
      lines.
    """
    # Generate Mermaid source code in-memory.
    mermaid_source = _generate_mermaid_source(pipeline_graph, **kwargs)

    # Compare case-insensitively, as `_render_mermaid_image` does for the
    # image formats, so that e.g. "MMD" is not rejected as an unsupported
    # image format.
    if output_format.lower() == "mmd":
        # Write Mermaid source as a string.
        stream.write(mermaid_source)
    else:
        # Render Mermaid source as an image and write to binary stream.
        _render_mermaid_image(mermaid_source, stream, output_format, width=width, height=height, scale=scale)

121 

122 

123def _generate_mermaid_source(pipeline_graph: PipelineGraph, **kwargs: Any) -> str: 

124 """Generate the Mermaid source code from the pipeline graph. 

125 

126 Parameters 

127 ---------- 

128 pipeline_graph : `PipelineGraph` 

129 The pipeline graph to visualize. 

130 **kwargs : Any 

131 Additional arguments passed to `parse_display_args` for rendering. 

132 

133 Returns 

134 ------- 

135 str 

136 The Mermaid source code as a string. 

137 """ 

138 # A buffer to collect Mermaid source code. 

139 buffer = StringIO() 

140 

141 # Parse display arguments to determine what to show. 

142 xgraph, options = parse_display_args(pipeline_graph, **kwargs) 

143 

144 # Begin the Mermaid code block. 

145 buffer.write("flowchart TD\n") 

146 

147 # Define Mermaid classes for node styling. 

148 buffer.write( 

149 f"classDef task fill:#B1F2EF,color:#000,stroke:#000,stroke-width:3px," 

150 f"font-family:Monospace,font-size:{_LABEL_PX_SIZE}px,text-align:left;\n" 

151 ) 

152 buffer.write( 

153 f"classDef dsType fill:#F5F5F5,color:#000,stroke:#00BABC,stroke-width:3px," 

154 f"font-family:Monospace,font-size:{_LABEL_PX_SIZE}px,text-align:left,rx:8,ry:8;\n" 

155 ) 

156 buffer.write( 

157 f"classDef taskInit fill:#F4DEFA,color:#000,stroke:#000,stroke-width:3px," 

158 f"font-family:Monospace,font-size:{_LABEL_PX_SIZE}px,text-align:left;\n" 

159 ) 

160 

161 # `overflow_ref` tracks the reference numbers for overflow nodes. 

162 overflow_ref = 1 

163 overflow_ids = [] 

164 

165 # Render nodes. 

166 for node_key, node_data in xgraph.nodes.items(): 

167 match node_key.node_type: 

168 case NodeType.TASK | NodeType.TASK_INIT: 

169 _render_task_node(node_key, node_data, options, buffer) 

170 case NodeType.DATASET_TYPE: 

171 overflow_ref, node_overflow_ids = _render_dataset_type_node( 

172 node_key, node_data, options, buffer, overflow_ref 

173 ) 

174 overflow_ids += node_overflow_ids if node_overflow_ids else [] 

175 case _: 

176 raise AssertionError(f"Unexpected node type: {node_key.node_type}") 

177 

178 # Collect edges for adding to the Mermaid code and track which ones are 

179 # prerequisite so we can apply dashed styling to them later. 

180 edges = [] 

181 for _, (from_node, to_node, *_rest) in enumerate(xgraph.edges): 

182 is_prereq = xgraph.nodes[from_node].get("is_prerequisite", False) 

183 edges.append((from_node.node_id, to_node.node_id, is_prereq)) 

184 

185 # Render all edges. 

186 for _, (f, t, p) in enumerate(edges): 

187 _render_edge(f, t, p, buffer) 

188 

189 # After rendering all edges, apply linkStyle to prerequisite edges to make 

190 # them dashed: 

191 

192 # First, gather indices of prerequisite edges. 

193 prereq_indices = [str(i) for i, (_, _, p) in enumerate(edges) if p] 

194 

195 # Then apply dashed styling to all prerequisite edges in one line. 

196 if prereq_indices: 

197 buffer.write(f"linkStyle {','.join(prereq_indices)} stroke-dasharray:5;\n") 

198 

199 # Return Mermaid source as string. 

200 return buffer.getvalue() 

201 

202 

def _render_mermaid_image(
    mermaid_source: str,
    binary_stream: IO[bytes],
    output_format: str,
    width: int | None = None,
    height: int | None = None,
    scale: float | None = None,
) -> None:
    """Render a Mermaid diagram as an image and write the output to a binary
    stream.

    Parameters
    ----------
    mermaid_source : str
        The Mermaid diagram source code.
    binary_stream : `typing.IO` [ `bytes` ]
        The binary stream where the output content will be written.
    output_format : str
        The desired output format for the image, compared
        case-insensitively. Supported image formats are 'svg' and 'png'.
    width : int, optional
        The width of the rendered image in pixels.
    height : int, optional
        The height of the rendered image in pixels.
    scale : float, optional
        The scale factor for the rendered image. Must be a float between 1 and
        3, and one of height or width must be provided.

    Raises
    ------
    ImportError
        If `mermaid-py` is not installed.
    ValueError
        If the requested ``output_format`` is not supported.
    RuntimeError
        If the rendering process fails, including when the rendering service
        returns an HTML error page or an HTTP error status instead of an
        image.
    """
    normalized_format = output_format.lower()
    if normalized_format not in {"svg", "png"}:
        raise ValueError(f"Unsupported format: {output_format}. Use 'svg' or 'png'.")

    if not MERMAID_AVAILABLE:
        raise ImportError("The `mermaid-py` package is required for rendering images but is not installed.")

    # Generate Mermaid graph object.
    graph = Graph(title="Mermaid Diagram", script=mermaid_source)
    diagram = Mermaid(graph, width=width, height=height, scale=scale)

    # Determine the response attribute based on the output format.
    response_type = "svg_response" if normalized_format == "svg" else "img_response"

    # Keep the ``try`` narrow: only a missing response attribute is a
    # rendering failure. Errors from ``binary_stream.write`` belong to the
    # caller and must not be reported as rendering failures.
    try:
        response = getattr(diagram, response_type)
    except AttributeError as exc:
        raise RuntimeError(f"Failed to generate {response_type} content") from exc

    content = _extract_rendered_content(response, response_type)

    # Write the content to the binary stream now that it is known to be valid.
    binary_stream.write(content)


def _extract_rendered_content(response: Any, response_type: str) -> bytes:
    """Return the image bytes carried by a Mermaid rendering response.

    Parameters
    ----------
    response : Any
        Response object exposing a bytes ``content`` attribute and, optionally,
        an integer ``status_code`` attribute.
    response_type : str
        Name of the response attribute the object was read from; used only in
        error messages.

    Returns
    -------
    bytes
        The response body, which is assumed to be the rendered image.

    Raises
    ------
    RuntimeError
        If the response has no ``content``, if the body looks like an HTML
        page rather than an image, or if ``status_code`` is 400 or higher.
    """
    try:
        content = response.content
    except AttributeError as exc:
        raise RuntimeError(f"Failed to generate {response_type} content") from exc

    # Detect HTML error pages case-insensitively and tolerate leading
    # whitespace; ``<!doctype html>`` is as valid as ``<!DOCTYPE html>``.
    head = content[:200].lstrip().lower()
    if head.startswith(b"<!doctype html") or head.startswith(b"<html") or b"<title>" in head:
        error_msg = content.decode(errors="ignore")[:1000]
        if "524" in error_msg or "timeout" in error_msg.lower():
            raise RuntimeError(
                f"Mermaid rendering service (mermaid.ink) timed out while generating {response_type}. "
                "This may be due to server overload. Try again later or use a local rendering option."
            )
        raise RuntimeError(
            f"Unexpected error from Mermaid API while generating {response_type}. Response:\n{error_msg}"
        )

    # A non-HTML error body (e.g. a plain-text message) must not be written
    # out as if it were image data.
    status = getattr(response, "status_code", None)
    if isinstance(status, int) and status >= 400:
        error_msg = content.decode(errors="ignore")[:1000]
        raise RuntimeError(
            f"Mermaid API returned HTTP {status} while generating {response_type}. Response:\n{error_msg}"
        )

    return content

276 

277 

def _render_task_node(
    node_key: NodeKey,
    node_data: Mapping[str, Any],
    options: NodeAttributeOptions,
    stream: IO[str],
) -> None:
    """Emit the Mermaid declaration and class assignment for a task or
    task-init node.

    Parameters
    ----------
    node_key : NodeKey
        Key of the node being drawn. Its type selects the Mermaid class
        (``task`` or ``taskInit``) and whether dimensions are shown.
    node_data : Mapping[str, Any]
        Node attributes. ``task_class_name`` is read when task classes are
        shown, and ``dimensions`` when dimensions are shown.
    options : NodeAttributeOptions
        Display options; ``task_classes`` and ``dimensions`` are consulted.
    stream : `typing.IO` [ `str` ]
        Destination for the Mermaid statements.
    """
    node_type = node_key.node_type
    is_init = node_type == NodeType.TASK_INIT

    # Only the primary label lines are used. Any overflow lines split off by
    # `_format_label` are discarded for task nodes.
    label_lines = _format_label(str(node_key))[0]

    if options.task_classes and node_type in (NodeType.TASK, NodeType.TASK_INIT):
        task_class = format_task_class(options, node_data["task_class_name"])
        label_lines.append(html.escape(task_class))

    # Task-init nodes never carry a dimensions line.
    if options.dimensions and not is_init:
        dimensions_text = format_dimensions(options, node_data["dimensions"])
        label_lines.append("<i>dimensions:</i>&nbsp;" + html.escape(dimensions_text).replace(" ", "&nbsp;"))

    css_class = "task" if node_type == NodeType.TASK else "taskInit"
    node_id = node_key.node_id
    joined_label = "<br>".join(label_lines)
    print(f'{node_id}["{joined_label}"]', file=stream)
    print(f"class {node_id} {css_class};", file=stream)

326 

327 

def _render_dataset_type_node(
    node_key: NodeKey,
    node_data: Mapping[str, Any],
    options: NodeAttributeOptions,
    stream: IO[str],
    overflow_ref: int,
) -> tuple[int, list[str]]:
    """Render a Mermaid node for a dataset-type node, handling overflow lines
    if needed.

    Dataset-type nodes can have many lines of label text. If the main label
    lines plus extras exceed ``_LABEL_MAX_LINES_HARD``, the extras beyond that
    limit go into separate "overflow" nodes of at most
    ``_OVERFLOW_MAX_LINES`` lines each. The main node's name list then ends
    with a "...more details in [N]" pointer to them. Overflow nodes are
    written to ``stream`` before the main node.

    Parameters
    ----------
    node_key : NodeKey
        Identifies this dataset-type node.
    node_data : Mapping[str, Any]
        Node attributes. ``dimensions`` is read when dimensions are shown, and
        ``storage_class_name`` when storage classes are shown.
    options : NodeAttributeOptions
        Rendering options controlling whether to show dimensions and storage
        classes.
    stream : `typing.IO` [ `str` ]
        The output stream for Mermaid syntax.
    overflow_ref : int
        The current reference number for overflow nodes. If overflow occurs,
        this is incremented.

    Returns
    -------
    overflow_ref : int
        Possibly incremented overflow reference number.
    overflow_ids : list[str]
        IDs of overflow nodes created, if any.
    """
    # Format the node label, respecting the soft line limit.
    labels, label_extras, _ = _format_label(str(node_key), _LABEL_MAX_LINES_SOFT)

    overflow_ids: list[str] = []
    # Pointer line to the overflow nodes. It must follow every name line kept
    # in the main node, so it is appended only after they are all in place.
    overflow_pointer: str | None = None

    if len(labels) + len(label_extras) > _LABEL_MAX_LINES_HARD:
        # Keep as many extras as still fit under the hard limit.
        allowed_extras = max(_LABEL_MAX_LINES_HARD - len(labels), 0)
        extras_for_overflow = label_extras[allowed_extras:]
        label_extras = label_extras[:allowed_extras]

        if extras_for_overflow:
            overflow_anchor = f"[{overflow_ref}]"
            overflow_pointer = f"<b>...more details in {overflow_anchor}</b>"
            anchor_header = f"<b>{html.escape(overflow_anchor)}</b>"

            # Create overflow nodes in chunks, each headed by the anchor.
            for start in range(0, len(extras_for_overflow), _OVERFLOW_MAX_LINES):
                overflow_id = f"{node_key.node_id}_overflow_{overflow_ref}_{start}"
                chunk = [anchor_header, *extras_for_overflow[start : start + _OVERFLOW_MAX_LINES]]
                _render_simple_node(overflow_id, chunk, "dsType", stream)
                overflow_ids.append(overflow_id)

            overflow_ref += 1

    # Combine the retained name lines, then point at the overflow (if any).
    final_lines = labels + label_extras
    if overflow_pointer is not None:
        final_lines.append(overflow_pointer)

    # Append dimensions if requested.
    if options.dimensions:
        dims_str = html.escape(format_dimensions(options, node_data["dimensions"])).replace(" ", "&nbsp;")
        final_lines.append(f"<i>dimensions:</i>&nbsp;{dims_str}")

    # Append storage class if requested.
    if options.storage_classes:
        final_lines.append(f"<i>storage&nbsp;class:</i>&nbsp;{html.escape(node_data['storage_class_name'])}")

    # Render the main dataset-type node.
    _render_simple_node(node_key.node_id, final_lines, "dsType", stream)

    return overflow_ref, overflow_ids

407 

408 

def _render_simple_node(node_id: str, lines: list[str], node_class: str, stream: IO[str]) -> None:
    """Write a Mermaid node declaration followed by its class assignment.

    The node label is ``lines`` joined with ``<br>``. This helper is shared by
    the main dataset-type nodes and their overflow nodes.

    Parameters
    ----------
    node_id : str
        Mermaid node ID.
    lines : list[str]
        HTML-formatted label lines, rendered top to bottom.
    node_class : str
        Name of a previously declared Mermaid class (e.g. 'dsType', 'task',
        'taskInit') to attach to the node.
    stream : `typing.IO` [ `str` ]
        Destination for the two Mermaid statements.
    """
    joined = "<br>".join(lines)
    declaration = f'{node_id}["{joined}"]'
    class_statement = f"class {node_id} {node_class};"
    print(declaration, class_statement, sep="\n", file=stream)

430 

431 

def _render_edge(from_node_id: str, to_node_id: str, is_prerequisite: bool, stream: IO[str]) -> None:
    """Write a single Mermaid link statement, ``<from> --> <to>``.

    Parameters
    ----------
    from_node_id : str
        ID of the node the link starts at.
    to_node_id : str
        ID of the node the link points to.
    is_prerequisite : bool
        Whether the link is a prerequisite connection. It does not affect
        what this function writes; dashed styling for prerequisite links is
        added separately with a ``linkStyle`` statement.
    stream : `typing.IO` [ `str` ]
        Destination for the Mermaid statement.
    """
    print(from_node_id, "-->", to_node_id, file=stream)

454 

455 

def _format_label(
    label: str,
    max_lines: int = 10,
    min_common_prefix_len: int = 1000,
) -> tuple[list[str], list[str], str]:
    """Turn a raw label into bold, HTML-escaped lines for a Mermaid node.

    The label is split and optionally prefix-factored by `_parse_label`. Every
    resulting line is HTML-escaped and wrapped in ``<b>...</b>``. When a
    common prefix was factored out:
    - it becomes a bold ``<prefix>:`` header line placed first,
    - the remaining lines are indented with two non-breaking spaces.

    Parameters
    ----------
    label : str
        The raw label text, often derived from a NodeKey.
    max_lines : int, optional
        Maximum lines before overflow is triggered.
    min_common_prefix_len : int, optional
        Minimum length for considering a common prefix extraction.

    Returns
    -------
    labels : list[str]
        Main label lines as HTML-formatted text, starting with the prefix
        header when one exists.
    label_extras : list[str]
        Formatted overflow lines if the label is too long.
    common_prefix : str
        The formatted (bold, escaped) prefix header, or an empty string.
    """
    raw_lines, raw_extras, raw_prefix = _parse_label(label, max_lines, min_common_prefix_len)

    if raw_prefix:
        prefix_html = f"<b>{html.escape(raw_prefix)}:</b>"
        indent = "&nbsp;&nbsp;"
        header_lines = [prefix_html]
    else:
        prefix_html = ""
        indent = ""
        header_lines = []

    def emphasize(items: list[str]) -> list[str]:
        return [f"<b>{indent}{html.escape(item)}</b>" for item in items]

    return header_lines + emphasize(raw_lines), emphasize(raw_extras), prefix_html

502 

503 

def _parse_label(
    label: str,
    max_lines: int,
    min_common_prefix_len: int,
) -> tuple[list[str], list[str], str]:
    """Split and process label text for overflow and common prefix extraction.

    The label is split on ``", "``. If there are more than three parts and
    their shared prefix is longer than ``min_common_prefix_len``, the prefix
    is cut back to (and including) its last underscore and stripped from
    every part. The parts are then divided into primary lines and overflow
    extras, counting the prefix header (if any) as one of the ``max_lines``.

    Parameters
    ----------
    label : str
        The raw label text.
    max_lines : int
        Maximum number of lines before overflow. Values smaller than the
        number of header lines are treated as zero available lines.
    min_common_prefix_len : int
        Minimum length for a common prefix to be considered.

    Returns
    -------
    labels : list[str]
        The primary label lines.
    label_extras : list[str]
        Any overflow lines that exceed max_lines.
    common_prefix : str
        The extracted common prefix, or an empty string.
    """
    labels = label.split(", ")
    common_prefix = ""

    # Factor out a long shared prefix, but only at the nearest underscore so
    # that whole name components are removed.
    if len(labels) > 3:
        raw_prefix = os.path.commonprefix(labels)
        if len(raw_prefix) > min_common_prefix_len:
            final_underscore_index = raw_prefix.rfind("_")
            if final_underscore_index > 0:
                common_prefix = raw_prefix[: final_underscore_index + 1]
                labels = [element[len(common_prefix) :] for element in labels]

    # The prefix header (if any) uses up one of the available lines. Clamp at
    # zero so a tiny or negative ``max_lines`` never yields a negative slice
    # index, which would keep most labels and push only the tail to extras.
    budget = max(max_lines - bool(common_prefix), 0)
    if len(labels) > budget:
        return labels[:budget], labels[budget:], common_prefix
    return labels, [], common_prefix