Coverage for python / lsst / pipe / base / pipeline_graph / visualization / _mermaid.py: 15%
147 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-26 08:59 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-26 08:59 +0000
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
27from __future__ import annotations
29__all__ = ("show_mermaid",)
31import html
32import os
33import sys
34from collections.abc import Mapping
35from io import StringIO
36from typing import IO, Any
38from .._nodes import NodeType
39from .._pipeline_graph import PipelineGraph
40from ._formatting import NodeKey, format_dimensions, format_task_class
41from ._options import NodeAttributeOptions
42from ._show import parse_display_args
44try:
45 from mermaid import Mermaid # type: ignore
46 from mermaid.graph import Graph # type: ignore
48 MERMAID_AVAILABLE = True
49except ImportError:
50 MERMAID_AVAILABLE = False
# Configuration constants for label formatting and overflow handling.

# Font size, in pixels, written into every Mermaid ``classDef`` style line.
_LABEL_PX_SIZE = 18
# Line budget passed to `_format_label` for dataset-type nodes; label lines
# beyond it are returned separately as "extras".
_LABEL_MAX_LINES_SOFT = 10
# Total (main + extra) line count above which a dataset-type node's extras
# are moved out into separate overflow nodes.
_LABEL_MAX_LINES_HARD = 15
# Maximum number of extra lines placed in a single overflow node.
_OVERFLOW_MAX_LINES = 20
def show_mermaid(
    pipeline_graph: PipelineGraph,
    stream: IO[Any] = sys.stdout,
    output_format: str = "mmd",
    width: int | None = None,
    height: int | None = None,
    scale: float | None = None,
    **kwargs: Any,
) -> None:
    """Write a Mermaid flowchart representation of the pipeline graph to a
    stream.

    This function converts a given `PipelineGraph` into a Mermaid-based
    flowchart. Nodes represent tasks (and possibly task-init nodes) and dataset
    types, and edges represent connections between them. Dimensions and storage
    classes can be included as additional metadata on nodes. Prerequisite edges
    are rendered as dashed lines.

    Parameters
    ----------
    pipeline_graph : `PipelineGraph`
        The pipeline graph to visualize.
    stream : `typing.IO`, optional
        The output stream where the result is written. Must accept `str` for
        'mmd' output and `bytes` for rendered images. Defaults to the
        `sys.stdout` object that existed when this module was imported.
    output_format : str, optional
        Defines the output format (matched case-insensitively). 'mmd'
        (default) generates a Mermaid definition text file, while 'svg' and
        'png' produce rendered images as binary streams.
    width : int, optional
        The width of the rendered image in pixels.
    height : int, optional
        The height of the rendered image in pixels.
    scale : float, optional
        The scale factor for the rendered image. Must be a float between 1
        and 3, and one of height or width must be provided.
    **kwargs : Any
        Additional arguments passed to `parse_display_args` to control aspects
        such as displaying dimensions, storage classes, or full task class
        names.

    Raises
    ------
    ImportError
        If an image format is requested and `mermaid-py` is not installed.
    ValueError
        If ``output_format`` is not one of 'mmd', 'svg' or 'png'.
    RuntimeError
        If rendering an image fails.

    Notes
    -----
    - The diagram uses a top-down layout (`flowchart TD`).
    - Three Mermaid classes are defined:
        - `task` for normal tasks,
        - `dsType` for dataset-type nodes,
        - `taskInit` for task-init nodes.
    - Edges that represent prerequisite relationships are rendered as dashed
      lines using `linkStyle`.
    - If a node's label is too long, overflow nodes are created to hold extra
      lines.
    """
    # Generate Mermaid source code in-memory.
    mermaid_source = _generate_mermaid_source(pipeline_graph, **kwargs)

    # Match the format case-insensitively: the image renderer already
    # lower-cases 'SVG'/'PNG', so 'MMD' must not fall through to it and be
    # rejected as an unsupported image format.
    if output_format.lower() == "mmd":
        # Write Mermaid source as a string.
        stream.write(mermaid_source)
    else:
        # Render Mermaid source as an image and write to binary stream.
        _render_mermaid_image(mermaid_source, stream, output_format, width=width, height=height, scale=scale)
123def _generate_mermaid_source(pipeline_graph: PipelineGraph, **kwargs: Any) -> str:
124 """Generate the Mermaid source code from the pipeline graph.
126 Parameters
127 ----------
128 pipeline_graph : `PipelineGraph`
129 The pipeline graph to visualize.
130 **kwargs : Any
131 Additional arguments passed to `parse_display_args` for rendering.
133 Returns
134 -------
135 str
136 The Mermaid source code as a string.
137 """
138 # A buffer to collect Mermaid source code.
139 buffer = StringIO()
141 # Parse display arguments to determine what to show.
142 xgraph, options = parse_display_args(pipeline_graph, **kwargs)
144 # Begin the Mermaid code block.
145 buffer.write("flowchart TD\n")
147 # Define Mermaid classes for node styling.
148 buffer.write(
149 f"classDef task fill:#B1F2EF,color:#000,stroke:#000,stroke-width:3px,"
150 f"font-family:Monospace,font-size:{_LABEL_PX_SIZE}px,text-align:left;\n"
151 )
152 buffer.write(
153 f"classDef dsType fill:#F5F5F5,color:#000,stroke:#00BABC,stroke-width:3px,"
154 f"font-family:Monospace,font-size:{_LABEL_PX_SIZE}px,text-align:left,rx:8,ry:8;\n"
155 )
156 buffer.write(
157 f"classDef taskInit fill:#F4DEFA,color:#000,stroke:#000,stroke-width:3px,"
158 f"font-family:Monospace,font-size:{_LABEL_PX_SIZE}px,text-align:left;\n"
159 )
161 # `overflow_ref` tracks the reference numbers for overflow nodes.
162 overflow_ref = 1
163 overflow_ids = []
165 # Render nodes.
166 for node_key, node_data in xgraph.nodes.items():
167 match node_key.node_type:
168 case NodeType.TASK | NodeType.TASK_INIT:
169 _render_task_node(node_key, node_data, options, buffer)
170 case NodeType.DATASET_TYPE:
171 overflow_ref, node_overflow_ids = _render_dataset_type_node(
172 node_key, node_data, options, buffer, overflow_ref
173 )
174 overflow_ids += node_overflow_ids if node_overflow_ids else []
175 case _:
176 raise AssertionError(f"Unexpected node type: {node_key.node_type}")
178 # Collect edges for adding to the Mermaid code and track which ones are
179 # prerequisite so we can apply dashed styling to them later.
180 edges = []
181 for _, (from_node, to_node, *_rest) in enumerate(xgraph.edges):
182 is_prereq = xgraph.nodes[from_node].get("is_prerequisite", False)
183 edges.append((from_node.node_id, to_node.node_id, is_prereq))
185 # Render all edges.
186 for _, (f, t, p) in enumerate(edges):
187 _render_edge(f, t, p, buffer)
189 # After rendering all edges, apply linkStyle to prerequisite edges to make
190 # them dashed:
192 # First, gather indices of prerequisite edges.
193 prereq_indices = [str(i) for i, (_, _, p) in enumerate(edges) if p]
195 # Then apply dashed styling to all prerequisite edges in one line.
196 if prereq_indices:
197 buffer.write(f"linkStyle {','.join(prereq_indices)} stroke-dasharray:5;\n")
199 # Return Mermaid source as string.
200 return buffer.getvalue()
203def _render_mermaid_image(
204 mermaid_source: str,
205 binary_stream: IO[bytes],
206 output_format: str,
207 width: int | None = None,
208 height: int | None = None,
209 scale: float | None = None,
210) -> None:
211 """Render a Mermaid diagram as an image and write the output to a binary
212 stream.
214 Parameters
215 ----------
216 mermaid_source : str
217 The Mermaid diagram source code.
218 binary_stream : `BytesIO`
219 The binary stream where the output content will be written.
220 output_format : str
221 The desired output format for the image. Supported image formats are
222 'svg' and 'png'.
223 width : int, optional
224 The width of the rendered image in pixels.
225 height : int, optional
226 The height of the rendered image in pixels.
227 scale : float, optional
228 The scale factor for the rendered image. Must be a float between 1 and
229 3, and one of height or width must be provided.
231 Raises
232 ------
233 ImportError
234 If `mermaid-py` is not installed.
235 ValueError
236 If the requested ``output_format`` is not supported.
237 RuntimeError
238 If the rendering process fails.
239 """
240 if output_format.lower() not in {"svg", "png"}:
241 raise ValueError(f"Unsupported format: {output_format}. Use 'svg' or 'png'.")
243 if not MERMAID_AVAILABLE:
244 raise ImportError("The `mermaid-py` package is required for rendering images but is not installed.")
246 # Generate Mermaid graph object.
247 graph = Graph(title="Mermaid Diagram", script=mermaid_source)
248 diagram = Mermaid(graph, width=width, height=height, scale=scale)
250 # Determine the response type based on the output format.
251 if output_format.lower() == "svg":
252 response_type = "svg_response"
253 else:
254 response_type = "img_response"
256 # Select the appropriate output format and write the content to the stream.
257 try:
258 content = getattr(diagram, response_type).content
260 # Check if the response is actually an image.
261 if content.startswith(b"<!DOCTYPE html>") or b"<title>" in content[:200]:
262 error_msg = content.decode(errors="ignore")[:1000]
263 if "524" in error_msg or "timeout" in error_msg.lower():
264 raise RuntimeError(
265 f"Mermaid rendering service (mermaid.ink) timed out while generating {response_type}. "
266 "This may be due to server overload. Try again later or use a local rendering option."
267 )
268 raise RuntimeError(
269 f"Unexpected error from Mermaid API while generating {response_type}. Response:\n{error_msg}"
270 )
272 # Write the content to the binary stream if it's a valid image.
273 binary_stream.write(content)
274 except AttributeError as exc:
275 raise RuntimeError(f"Failed to generate {response_type} content") from exc
def _render_task_node(
    node_key: NodeKey,
    node_data: Mapping[str, Any],
    options: NodeAttributeOptions,
    stream: IO[str],
) -> None:
    """Render a Mermaid node for a task or task-init node.

    Parameters
    ----------
    node_key : NodeKey
        Identifies the node. The node type determines styling and whether
        dimensions apply.
    node_data : Mapping[str, Any]
        Node attributes. 'task_class_name' is read when task classes are
        shown, and 'dimensions' when dimensions are shown.
    options : NodeAttributeOptions
        Rendering options controlling whether to show task classes and
        dimensions.
    stream : `typing.IO` [ `str` ]
        The output stream for Mermaid syntax.
    """
    # Convert node_key into label lines; any overflow lines and the common
    # prefix returned by the formatter are not used for task nodes.
    lines, _, _ = _format_label(str(node_key))
    node_type = node_key.node_type

    # If requested, show the task class name beneath the task label.
    if options.task_classes and node_type in (NodeType.TASK, NodeType.TASK_INIT):
        lines.append(html.escape(format_task_class(options, node_data["task_class_name"])))

    # Show dimensions if requested and if this is not a task-init node.
    if options.dimensions and node_type != NodeType.TASK_INIT:
        # Swap spaces for the &nbsp; entity *after* escaping, so the entity's
        # ampersand survives and the dimension text is kept on one line.
        dims_str = html.escape(format_dimensions(options, node_data["dimensions"])).replace(" ", "&nbsp;")
        lines.append(f"<i>dimensions:</i> {dims_str}")

    # Anything that is not a plain task is styled as a task-init node.
    node_class = "task" if node_type == NodeType.TASK else "taskInit"

    # Emit the node declaration and its class assignment.
    _render_simple_node(node_key.node_id, lines, node_class, stream)
def _render_dataset_type_node(
    node_key: NodeKey,
    node_data: Mapping[str, Any],
    options: NodeAttributeOptions,
    stream: IO[str],
    overflow_ref: int,
) -> tuple[int, list[str]]:
    """Render a Mermaid node for a dataset-type node, handling overflow lines
    if needed.

    Dataset-type nodes can have many lines of label text. If the label exceeds
    the hard line limit, the surplus lines are moved into separate "overflow"
    nodes and the main node gets a line pointing at them.

    Parameters
    ----------
    node_key : NodeKey
        Identifies this dataset-type node.
    node_data : Mapping[str, Any]
        Node attributes. 'dimensions' is read when dimensions are shown, and
        'storage_class_name' when storage classes are shown.
    options : NodeAttributeOptions
        Rendering options controlling whether to show dimensions and storage
        classes.
    stream : `typing.IO` [ `str` ]
        The output stream for Mermaid syntax.
    overflow_ref : int
        The current reference number for overflow nodes. If overflow occurs,
        this is incremented.

    Returns
    -------
    overflow_ref : int
        Possibly incremented overflow reference number.
    overflow_ids : list[str]
        IDs of overflow nodes created, if any.
    """
    # Format the node label, splitting at the soft line limit.
    labels, label_extras, _ = _format_label(str(node_key), _LABEL_MAX_LINES_SOFT)

    overflow_ids: list[str] = []
    if len(labels) + len(label_extras) > _LABEL_MAX_LINES_HARD:
        # Too many lines: keep as many extras as fit under the hard limit and
        # push the remainder into overflow nodes.
        allowed_extras = max(_LABEL_MAX_LINES_HARD - len(labels), 0)
        extras_for_overflow = label_extras[allowed_extras:]
        label_extras = label_extras[:allowed_extras]

        if extras_for_overflow:
            # Introduce an overflow anchor shared by the main node and every
            # overflow node created for it.
            overflow_anchor = f"[{overflow_ref}]"
            labels.append(f"<b>...more details in {overflow_anchor}</b>")
            anchor_line = f"<b>{html.escape(overflow_anchor)}</b>"

            # Create overflow nodes in chunks, each headed by the anchor.
            for start in range(0, len(extras_for_overflow), _OVERFLOW_MAX_LINES):
                overflow_id = f"{node_key.node_id}_overflow_{overflow_ref}_{start}"
                chunk = [anchor_line, *extras_for_overflow[start : start + _OVERFLOW_MAX_LINES]]
                _render_simple_node(overflow_id, chunk, "dsType", stream)
                overflow_ids.append(overflow_id)

            overflow_ref += 1

    # Combine final lines after overflow handling.
    final_lines = labels + label_extras

    # Append dimensions if requested.
    if options.dimensions:
        # Swap spaces for the &nbsp; entity *after* escaping, so the entity's
        # ampersand survives and the dimension text is kept on one line.
        dims_str = html.escape(format_dimensions(options, node_data["dimensions"])).replace(" ", "&nbsp;")
        final_lines.append(f"<i>dimensions:</i> {dims_str}")

    # Append storage class if requested.
    if options.storage_classes:
        final_lines.append(f"<i>storage class:</i> {html.escape(node_data['storage_class_name'])}")

    # Render the main dataset-type node.
    _render_simple_node(node_key.node_id, final_lines, "dsType", stream)

    return overflow_ref, overflow_ids
def _render_simple_node(node_id: str, lines: list[str], node_class: str, stream: IO[str]) -> None:
    """Write one Mermaid node declaration followed by its class assignment.

    The caller is expected to have already decided which lines belong in this
    node; the lines are only joined and emitted here.

    Parameters
    ----------
    node_id : str
        Mermaid node ID.
    lines : list[str]
        Lines of HTML-formatted text to display in the node; they are joined
        with ``<br>``.
    node_class : str
        Mermaid class name to style the node (e.g., 'dsType', 'task',
        'taskInit').
    stream : `typing.IO` [ `str` ]
        The output stream.
    """
    body = "<br>".join(lines)
    declaration = f'{node_id}["{body}"]'
    class_assignment = f"class {node_id} {node_class};"
    # Two newline-terminated lines: the node itself, then its style class.
    print(declaration, class_assignment, sep="\n", file=stream)
def _render_edge(from_node_id: str, to_node_id: str, is_prerequisite: bool, stream: IO[str]) -> None:
    """Write a single Mermaid edge of the form ``A --> B``.

    Every edge is written with the same solid-arrow syntax. Dashed styling for
    prerequisite edges is not produced here; the caller adds it afterwards
    with a ``linkStyle`` line once all edges have been written.

    Parameters
    ----------
    from_node_id : str
        The ID of the 'from' node in the edge.
    to_node_id : str
        The ID of the 'to' node in the edge.
    is_prerequisite : bool
        Whether this edge is a prerequisite connection. Accepted for the
        caller's convenience; it does not affect what is written.
    stream : `typing.IO` [ `str` ]
        The output stream for Mermaid syntax.
    """
    # The default single-space separator yields exactly "<from> --> <to>".
    print(from_node_id, "-->", to_node_id, file=stream)
def _format_label(
    label: str,
    max_lines: int = 10,
    min_common_prefix_len: int = 1000,
) -> tuple[list[str], list[str], str]:
    """Parse and format a label into multiple lines with optional overflow
    handling.

    This function attempts to cleanly format long labels by:
    - Splitting the label by ", ".
    - Identifying a common prefix to factor out if sufficiently long.
    - Limiting the number of lines to 'max_lines', storing extras for potential
      overflow.

    Parameters
    ----------
    label : str
        The raw label text, often derived from a NodeKey.
    max_lines : int, optional
        Maximum lines before overflow is triggered.
    min_common_prefix_len : int, optional
        Minimum length for considering a common prefix extraction.

    Returns
    -------
    labels : list[str]
        Main label lines as HTML-formatted text. When a common prefix was
        factored out, its header line comes first and the remaining lines are
        indented beneath it.
    label_extras : list[str]
        Overflow lines if the label is too long, formatted like ``labels``.
    common_prefix : str
        The HTML-formatted header line for the extracted common prefix, or an
        empty string if no prefix was extracted.
    """
    parsed_labels, parsed_label_extras, common_prefix = _parse_label(label, max_lines, min_common_prefix_len)

    # Indent entries beneath the prefix header. Non-breaking-space entities
    # are used because a leading plain space is collapsed when the label is
    # rendered as HTML, which would make the indent invisible.
    indent = "&nbsp;&nbsp;" if common_prefix else ""
    labels = [f"<b>{indent}{html.escape(el)}</b>" for el in parsed_labels]
    label_extras = [f"<b>{indent}{html.escape(el)}</b>" for el in parsed_label_extras]

    # If there's a common prefix, present it bolded as the first line.
    if common_prefix:
        common_prefix = f"<b>{html.escape(common_prefix)}:</b>"
        labels.insert(0, common_prefix)

    return labels, label_extras, common_prefix
def _parse_label(
    label: str,
    max_lines: int,
    min_common_prefix_len: int,
) -> tuple[list[str], list[str], str]:
    """Split and process label text for overflow and common prefix extraction.

    Parameters
    ----------
    label : str
        The raw label text; it is split on ", ".
    max_lines : int
        Maximum number of lines before overflow. A factored-out common prefix
        counts as one line. Values of zero or less leave no room for primary
        lines, so every element is returned as overflow.
    min_common_prefix_len : int
        A common prefix is only considered when it is strictly longer than
        this.

    Returns
    -------
    labels : list[str]
        The primary label lines, with the common prefix removed if one was
        extracted.
    label_extras : list[str]
        Any overflow lines that exceed max_lines.
    common_prefix : str
        The extracted common prefix (ending in an underscore), or an empty
        string if none was extracted.
    """
    labels = label.split(", ")
    common_prefix = ""

    # If there's a long common prefix for more than three labels, factor it
    # out at the nearest underscore. A candidate with no underscore, or only
    # a leading one, is not used.
    candidate = os.path.commonprefix(labels)
    if len(labels) > 3 and len(candidate) > min_common_prefix_len:
        final_underscore_index = candidate.rfind("_")
        if final_underscore_index > 0:
            common_prefix = candidate[: final_underscore_index + 1]
            labels = [element[len(common_prefix) :] for element in labels]

    # The prefix header, when present, uses up one of the available lines.
    prefix_lines = 1 if common_prefix else 0

    # Handle overflow if needed.
    if len(labels) + prefix_lines > max_lines:
        # Clamp at zero: a negative bound would be read as an index from the
        # end of the list and split the elements in the wrong place.
        keep = max(max_lines - prefix_lines, 0)
        return labels[:keep], labels[keep:], common_prefix

    return labels, [], common_prefix