Coverage for python/lsst/ctrl/mpexec/dotTools.py: 10% of 153 statements (coverage.py v7.2.7, created at 2023-08-06 02:30 +0000)

# This file is part of ctrl_mpexec.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Module defining a few methods to generate GraphViz diagrams from pipelines
or quantum graphs.
"""

from __future__ import annotations

__all__ = ["graph2dot", "pipeline2dot"]

# -------------------------------
# Imports of standard modules --
# -------------------------------
import html
import io
import re
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any

# -----------------------------
# Imports for other modules --
# -----------------------------
from lsst.daf.butler import DatasetType, DimensionUniverse
from lsst.pipe.base import Pipeline, connectionTypes, iterConnections

if TYPE_CHECKING:
    from lsst.daf.butler import DatasetRef
    from lsst.pipe.base import QuantumGraph, QuantumNode, TaskDef

# ----------------------------------
# Local non-exported definitions --
# ----------------------------------

# Attributes applied to directed graph objects.
_NODELABELPOINTSIZE = "18"
_ATTRIBS = dict(
    defaultGraph=dict(splines="ortho", nodesep="0.5", ranksep="0.75", pad="0.5"),
    defaultNode=dict(shape="box", fontname="Monospace", fontsize="14", margin="0.2,0.1", penwidth="3"),
    defaultEdge=dict(color="black", arrowsize="1.5", penwidth="1.5"),
    task=dict(style="filled", color="black", fillcolor="#B1F2EF"),
    quantum=dict(style="filled", color="black", fillcolor="#B1F2EF"),
    dsType=dict(style="rounded,filled,bold", color="#00BABC", fillcolor="#F5F5F5"),
    dataset=dict(style="rounded,filled,bold", color="#00BABC", fillcolor="#F5F5F5"),
)


def _renderDefault(type: str, attribs: dict[str, str], file: io.TextIOBase) -> None:
    """Set default attributes for a given type."""
    default_attribs = ", ".join([f'{key}="{val}"' for key, val in attribs.items()])
    print(f"{type} [{default_attribs}];", file=file)


def _renderNode(file: io.TextIOBase, nodeName: str, style: str, labels: list[str]) -> None:
    """Render GV node"""
    label = r"</TD></TR><TR><TD>".join(labels)
    attrib_dict = dict(_ATTRIBS[style], label=label)
    pre = '<<TABLE BORDER="0" CELLPADDING="5"><TR><TD>'
    post = "</TD></TR></TABLE>>"
    attrib = ", ".join(
        [
            f'{key}="{val}"' if key != "label" else f"{key}={pre}{val}{post}"
            for key, val in attrib_dict.items()
        ]
    )
    print(f'"{nodeName}" [{attrib}];', file=file)
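
# For reference, _renderNode() emits each node as a single DOT statement whose
# label is an HTML-like table with one row per label string.  A hypothetical
# call such as _renderNode(file, "task0", "task", ["labelA", "labelB"]) would
# print (wrapped here for readability; the actual output is one line):
#
#   "task0" [style="filled", color="black", fillcolor="#B1F2EF",
#            label=<<TABLE BORDER="0" CELLPADDING="5"><TR><TD>labelA
#            </TD></TR><TR><TD>labelB</TD></TR></TABLE>>];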


def _renderTaskNode(nodeName: str, taskDef: TaskDef, file: io.TextIOBase, idx: Any = None) -> None:
    """Render GV node for a task"""
    labels = [
        f'<B><FONT POINT-SIZE="{_NODELABELPOINTSIZE}">' + html.escape(taskDef.label) + "</FONT></B>",
        html.escape(taskDef.taskName),
    ]
    if idx is not None:
        labels.append(f"<I>index:</I> {idx}")
    if taskDef.connections:
        # don't print collection of str directly to avoid visually noisy quotes
        dimensions_str = ", ".join(sorted(taskDef.connections.dimensions))
        labels.append(f"<I>dimensions:</I> {html.escape(dimensions_str)}")
    _renderNode(file, nodeName, "task", labels)


def _renderQuantumNode(
    nodeName: str, taskDef: TaskDef, quantumNode: QuantumNode, file: io.TextIOBase
) -> None:
    """Render GV node for a quantum"""
    labels = [f"{quantumNode.nodeId}", html.escape(taskDef.label)]
    dataId = quantumNode.quantum.dataId
    assert dataId is not None, "Quantum DataId cannot be None"
    labels.extend(f"{key} = {dataId[key]}" for key in sorted(dataId.keys()))
    _renderNode(file, nodeName, "quantum", labels)


def _renderDSTypeNode(name: str, dimensions: list[str], file: io.TextIOBase) -> None:
    """Render GV node for a dataset type"""
    labels = [f'<B><FONT POINT-SIZE="{_NODELABELPOINTSIZE}">' + html.escape(name) + "</FONT></B>"]
    if dimensions:
        labels.append("<I>dimensions:</I> " + html.escape(", ".join(sorted(dimensions))))
    _renderNode(file, name, "dsType", labels)


def _renderDSNode(nodeName: str, dsRef: DatasetRef, file: io.TextIOBase) -> None:
    """Render GV node for a dataset"""
    labels = [html.escape(dsRef.datasetType.name), f"run: {dsRef.run!r}"]
    labels.extend(f"{key} = {dsRef.dataId[key]}" for key in sorted(dsRef.dataId.keys()))
    _renderNode(file, nodeName, "dataset", labels)


def _renderEdge(fromName: str, toName: str, file: io.TextIOBase, **kwargs: Any) -> None:
    """Render GV edge"""
    if kwargs:
        attrib = ", ".join([f'{key}="{val}"' for key, val in kwargs.items()])
        print(f'"{fromName}" -> "{toName}" [{attrib}];', file=file)
    else:
        print(f'"{fromName}" -> "{toName}";', file=file)
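
# For reference, _renderEdge() emits plain DOT edges.  For example,
# _renderEdge("dsref_0", "task_0_0", file) prints
#   "dsref_0" -> "task_0_0";
# while, for a hypothetical dataset type named "raw",
# _renderEdge("raw", "task0", file, style="dashed") prints
#   "raw" -> "task0" [style="dashed"];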


def _datasetRefId(dsRef: DatasetRef) -> str:
    """Make an identifying string for given ref"""
    dsId = [dsRef.datasetType.name]
    dsId.extend(f"{key} = {dsRef.dataId[key]}" for key in sorted(dsRef.dataId.keys()))
    return ":".join(dsId)


def _makeDSNode(dsRef: DatasetRef, allDatasetRefs: dict[str, str], file: io.TextIOBase) -> str:
    """Make new node for dataset if it does not exist.

    Returns node name.
    """
    dsRefId = _datasetRefId(dsRef)
    nodeName = allDatasetRefs.get(dsRefId)
    if nodeName is None:
        idx = len(allDatasetRefs)
        nodeName = f"dsref_{idx}"
        allDatasetRefs[dsRefId] = nodeName
        _renderDSNode(nodeName, dsRef, file)
    return nodeName


# ------------------------
# Exported definitions --
# ------------------------


def graph2dot(qgraph: QuantumGraph, file: Any) -> None:
    """Convert QuantumGraph into GraphViz digraph.

    This method is mostly for documentation/presentation purposes.

    Parameters
    ----------
    qgraph : `lsst.pipe.base.QuantumGraph`
        QuantumGraph instance.
    file : `str` or file object
        File where the GraphViz graph (DOT language) is written; can be a
        file name or a file object.

    Raises
    ------
    OSError
        Raised when the output file cannot be opened.
    ImportError
        Raised when a task class cannot be imported.
    """
    # open a file if needed
    close = False
    if not hasattr(file, "write"):
        file = open(file, "w")
        close = True

    print("digraph QuantumGraph {", file=file)
    _renderDefault("graph", _ATTRIBS["defaultGraph"], file)
    _renderDefault("node", _ATTRIBS["defaultNode"], file)
    _renderDefault("edge", _ATTRIBS["defaultEdge"], file)

    allDatasetRefs: dict[str, str] = {}
    for taskId, taskDef in enumerate(qgraph.taskGraph):
        quanta = qgraph.getNodesForTask(taskDef)
        for qId, quantumNode in enumerate(quanta):
            # node for a task
            taskNodeName = f"task_{taskId}_{qId}"
            _renderQuantumNode(taskNodeName, taskDef, quantumNode, file)

            # quantum inputs
            for dsRefs in quantumNode.quantum.inputs.values():
                for dsRef in dsRefs:
                    nodeName = _makeDSNode(dsRef, allDatasetRefs, file)
                    _renderEdge(nodeName, taskNodeName, file)

            # quantum outputs
            for dsRefs in quantumNode.quantum.outputs.values():
                for dsRef in dsRefs:
                    nodeName = _makeDSNode(dsRef, allDatasetRefs, file)
                    _renderEdge(taskNodeName, nodeName, file)

    print("}", file=file)
    if close:
        file.close()
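
# Minimal usage sketch (assumes a QuantumGraph ``qgraph`` built elsewhere, e.g.
# by the pipetask machinery); the resulting DOT file can be rendered with the
# standard Graphviz CLI, e.g. ``dot -Tsvg qgraph.dot -o qgraph.svg``:
#
#     graph2dot(qgraph, "qgraph.dot")      # pass a file name ...
#     with open("qgraph.dot", "w") as f:
#         graph2dot(qgraph, f)             # ... or an open file object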


def pipeline2dot(pipeline: Pipeline | Iterable[TaskDef], file: Any) -> None:
    """Convert `~lsst.pipe.base.Pipeline` into GraphViz digraph.

    This method is mostly for documentation/presentation purposes.
    Unlike other methods, it does not validate graph consistency.

    Parameters
    ----------
    pipeline : `lsst.pipe.base.Pipeline` or iterable of `~lsst.pipe.base.TaskDef`
        Pipeline description.
    file : `str` or file object
        File where the GraphViz graph (DOT language) is written; can be a
        file name or a file object.

    Raises
    ------
    OSError
        Raised when the output file cannot be opened.
    ImportError
        Raised when a task class cannot be imported.
    MissingTaskFactoryError
        Raised when a TaskFactory is needed but not provided.
    """
    universe = DimensionUniverse()

    def expand_dimensions(connection: connectionTypes.BaseConnection) -> list[str]:
        """Return expanded list of dimensions, with special skypix treatment.

        Parameters
        ----------
        connection : `lsst.pipe.base.connectionTypes.BaseConnection`
            Connection whose dimensions are to be expanded.

        Returns
        -------
        dimensions : `list` [`str`]
            Expanded list of dimension names.
        """
        dimension_set = set()
        if isinstance(connection, connectionTypes.DimensionedConnection):
            dimension_set = set(connection.dimensions)
        skypix_dim = []
        if "skypix" in dimension_set:
            dimension_set.remove("skypix")
            skypix_dim = ["skypix"]
        dimension_graph = universe.extract(dimension_set)
        return list(dimension_graph.names) + skypix_dim
    # open a file if needed
    close = False
    if not hasattr(file, "write"):
        file = open(file, "w")
        close = True

    print("digraph Pipeline {", file=file)
    _renderDefault("graph", _ATTRIBS["defaultGraph"], file)
    _renderDefault("node", _ATTRIBS["defaultNode"], file)
    _renderDefault("edge", _ATTRIBS["defaultEdge"], file)

    allDatasets: set[str | tuple[str, str]] = set()
    if isinstance(pipeline, Pipeline):
        pipeline = pipeline.toExpandedPipeline()

    # The next two lines are a workaround until DM-29658 at which time metadata
    # connections should start working with the above code
    labelToTaskName = {}
    metadataNodesToLink = set()

    for idx, taskDef in enumerate(sorted(pipeline, key=lambda x: x.label)):
        # node for a task
        taskNodeName = f"task{idx}"

        # next line is workaround until DM-29658
        labelToTaskName[taskDef.label] = taskNodeName

        _renderTaskNode(taskNodeName, taskDef, file, None)

        metadataRePattern = re.compile("^(.*)_metadata$")
        for attr in sorted(iterConnections(taskDef.connections, "inputs"), key=lambda x: x.name):
            if attr.name not in allDatasets:
                dimensions = expand_dimensions(attr)
                _renderDSTypeNode(attr.name, dimensions, file)
                allDatasets.add(attr.name)
            nodeName, component = DatasetType.splitDatasetTypeName(attr.name)
            _renderEdge(attr.name, taskNodeName, file)
            # connect component dataset types to the composite type that
            # produced it
            if component is not None and (nodeName, attr.name) not in allDatasets:
                _renderEdge(nodeName, attr.name, file)
                allDatasets.add((nodeName, attr.name))
                if nodeName not in allDatasets:
                    dimensions = expand_dimensions(attr)
                    _renderDSTypeNode(nodeName, dimensions, file)
            # The next if block is a workaround until DM-29658 at which time
            # metadata connections should start working with the above code
            if (match := metadataRePattern.match(attr.name)) is not None:
                matchTaskLabel = match.group(1)
                metadataNodesToLink.add((matchTaskLabel, attr.name))

        for attr in sorted(iterConnections(taskDef.connections, "prerequisiteInputs"), key=lambda x: x.name):
            if attr.name not in allDatasets:
                dimensions = expand_dimensions(attr)
                _renderDSTypeNode(attr.name, dimensions, file)
                allDatasets.add(attr.name)
            # use dashed line for prerequisite edges to distinguish them
            _renderEdge(attr.name, taskNodeName, file, style="dashed")

        for attr in sorted(iterConnections(taskDef.connections, "outputs"), key=lambda x: x.name):
            if attr.name not in allDatasets:
                dimensions = expand_dimensions(attr)
                _renderDSTypeNode(attr.name, dimensions, file)
                allDatasets.add(attr.name)
            _renderEdge(taskNodeName, attr.name, file)

    # This for loop is a workaround until DM-29658 at which time metadata
    # connections should start working with the above code
    for matchLabel, dsTypeName in metadataNodesToLink:
        # only render an edge to metadata if the label is part of the current
        # graph
        if (result := labelToTaskName.get(matchLabel)) is not None:
            _renderEdge(result, dsTypeName, file)

    print("}", file=file)
    if close:
        file.close()
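
# Minimal usage sketch for pipeline2dot, assuming a pipeline YAML at a
# hypothetical path "pipeline.yaml" and that ``Pipeline.fromFile`` is available
# in this version of lsst.pipe.base (any Pipeline, or iterable of TaskDef,
# built elsewhere works the same way).  Render the output as above, e.g.
# ``dot -Tpng pipeline.dot -o pipeline.png``:
#
#     from lsst.pipe.base import Pipeline
#
#     pipeline = Pipeline.fromFile("pipeline.yaml")
#     pipeline2dot(pipeline, "pipeline.dot")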