Coverage for python/lsst/ctrl/mpexec/dotTools.py: 9%
140 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-14 09:14 +0000
1# This file is part of ctrl_mpexec.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22"""Module defining few methods to generate GraphViz diagrams from pipelines
23or quantum graphs.
24"""
26from __future__ import annotations
28__all__ = ["graph2dot", "pipeline2dot"]
30# -------------------------------
31# Imports of standard modules --
32# -------------------------------
33import io
34import re
35from collections.abc import Iterable
36from typing import TYPE_CHECKING, Any
38# -----------------------------
39# Imports for other modules --
40# -----------------------------
41from lsst.daf.butler import DatasetType, DimensionUniverse
42from lsst.pipe.base import Pipeline, connectionTypes, iterConnections
44if TYPE_CHECKING:
45 from lsst.daf.butler import DatasetRef
46 from lsst.pipe.base import QuantumGraph, QuantumNode, TaskDef
48# ----------------------------------
49# Local non-exported definitions --
50# ----------------------------------
52# Node styles indexed by node type.
53_STYLES = dict(
54 task=dict(shape="box", style="filled,bold", fillcolor="gray70"),
55 quantum=dict(shape="box", style="filled,bold", fillcolor="gray70"),
56 dsType=dict(shape="box", style="rounded,filled", fillcolor="gray90"),
57 dataset=dict(shape="box", style="rounded,filled", fillcolor="gray90"),
58)
def _renderNode(file: io.TextIOBase, nodeName: str, style: str, labels: list[str]) -> None:
    """Render a single GV node with the given style and label lines."""
    attribs = dict(_STYLES[style])
    attribs["label"] = r"\n".join(labels)
    rendered = ", ".join(f'{key}="{value}"' for key, value in attribs.items())
    print(f'"{nodeName}" [{rendered}];', file=file)
def _renderTaskNode(nodeName: str, taskDef: TaskDef, file: io.TextIOBase, idx: Any = None) -> None:
    """Render a GV node describing a pipeline task."""
    labels = [taskDef.label, taskDef.taskName]
    if idx is not None:
        labels.append(f"index: {idx}")
    connections = taskDef.connections
    if connections:
        # Join sorted names ourselves; printing the str collection directly
        # would produce visually noisy quotes in the label.
        labels.append("dimensions: " + ", ".join(sorted(connections.dimensions)))
    _renderNode(file, nodeName, "task", labels)
def _renderQuantumNode(
    nodeName: str, taskDef: TaskDef, quantumNode: QuantumNode, file: io.TextIOBase
) -> None:
    """Render a GV node for a single quantum of a task."""
    dataId = quantumNode.quantum.dataId
    assert dataId is not None, "Quantum DataId cannot be None"
    labels = [f"{quantumNode.nodeId}", taskDef.label]
    for key in sorted(dataId.keys()):
        labels.append(f"{key} = {dataId[key]}")
    _renderNode(file, nodeName, "quantum", labels)
def _renderDSTypeNode(name: str, dimensions: list[str], file: io.TextIOBase) -> None:
    """Render a GV node for a dataset type."""
    if dimensions:
        labels = [name, "Dimensions: " + ", ".join(sorted(dimensions))]
    else:
        labels = [name]
    _renderNode(file, name, "dsType", labels)
def _renderDSNode(nodeName: str, dsRef: DatasetRef, file: io.TextIOBase) -> None:
    """Render a GV node for an individual dataset ref."""
    dataId = dsRef.dataId
    labels = [dsRef.datasetType.name, f"run: {dsRef.run!r}"]
    for key in sorted(dataId.keys()):
        labels.append(f"{key} = {dataId[key]}")
    _renderNode(file, nodeName, "dataset", labels)
107def _renderEdge(fromName: str, toName: str, file: io.TextIOBase, **kwargs: Any) -> None:
108 """Render GV edge"""
109 if kwargs:
110 attrib = ", ".join([f'{key}="{val}"' for key, val in kwargs.items()])
111 print(f'"{fromName}" -> "{toName}" [{attrib}];', file=file)
112 else:
113 print(f'"{fromName}" -> "{toName}";', file=file)
116def _datasetRefId(dsRef: DatasetRef) -> str:
117 """Make an identifying string for given ref"""
118 dsId = [dsRef.datasetType.name]
119 dsId.extend(f"{key} = {dsRef.dataId[key]}" for key in sorted(dsRef.dataId.keys()))
120 return ":".join(dsId)
def _makeDSNode(dsRef: DatasetRef, allDatasetRefs: dict[str, str], file: io.TextIOBase) -> str:
    """Return the GV node name for a dataset ref, rendering a new node
    first if this ref has not been seen before.

    ``allDatasetRefs`` maps ref identifiers to already-rendered node names
    and is updated in place.
    """
    refId = _datasetRefId(dsRef)
    existing = allDatasetRefs.get(refId)
    if existing is not None:
        return existing
    # Sequential naming based on how many refs we have rendered so far.
    nodeName = f"dsref_{len(allDatasetRefs)}"
    allDatasetRefs[refId] = nodeName
    _renderDSNode(nodeName, dsRef, file)
    return nodeName
138# ------------------------
139# Exported definitions --
140# ------------------------
def graph2dot(qgraph: QuantumGraph, file: Any) -> None:
    """Convert QuantumGraph into GraphViz digraph.

    This method is mostly for documentation/presentation purposes.

    Parameters
    ----------
    qgraph : `pipe.base.QuantumGraph`
        QuantumGraph instance.
    file : str or file object
        File where GraphViz graph (DOT language) is written, can be a file
        name or file object.

    Raises
    ------
    `OSError` is raised when output file cannot be open.
    `ImportError` is raised when task class cannot be imported.
    """
    # Open a file if needed; remember whether we own it so we do not close
    # a file object handed to us by the caller.
    close = False
    if not hasattr(file, "write"):
        file = open(file, "w")
        close = True

    # try/finally ensures a file we opened is closed even if rendering of
    # some node raises part-way through.
    try:
        print("digraph QuantumGraph {", file=file)

        # Maps dataset ref id -> node name, shared across all quanta so each
        # dataset is rendered exactly once.
        allDatasetRefs: dict[str, str] = {}
        for taskId, taskDef in enumerate(qgraph.taskGraph):
            quanta = qgraph.getNodesForTask(taskDef)
            for qId, quantumNode in enumerate(quanta):
                # node for a quantum of this task
                taskNodeName = f"task_{taskId}_{qId}"
                _renderQuantumNode(taskNodeName, taskDef, quantumNode, file)

                # edges from every input dataset into the quantum
                for dsRefs in quantumNode.quantum.inputs.values():
                    for dsRef in dsRefs:
                        nodeName = _makeDSNode(dsRef, allDatasetRefs, file)
                        _renderEdge(nodeName, taskNodeName, file)

                # edges from the quantum to every output dataset
                for dsRefs in quantumNode.quantum.outputs.values():
                    for dsRef in dsRefs:
                        nodeName = _makeDSNode(dsRef, allDatasetRefs, file)
                        _renderEdge(taskNodeName, nodeName, file)

        print("}", file=file)
    finally:
        if close:
            file.close()
def pipeline2dot(pipeline: Pipeline | Iterable[TaskDef], file: Any) -> None:
    """Convert Pipeline into GraphViz digraph.

    This method is mostly for documentation/presentation purposes.
    Unlike other methods this method does not validate graph consistency.

    Parameters
    ----------
    pipeline : `pipe.base.Pipeline`
        Pipeline description.
    file : str or file object
        File where GraphViz graph (DOT language) is written, can be a file name
        or file object.

    Raises
    ------
    `OSError` is raised when output file cannot be open.
    `ImportError` is raised when task class cannot be imported.
    `MissingTaskFactoryError` is raised when TaskFactory is needed but not
    provided.
    """
    universe = DimensionUniverse()

    def expand_dimensions(connection: connectionTypes.BaseConnection) -> list[str]:
        """Return expanded list of dimensions, with special skypix treatment.

        Parameters
        ----------
        connection : `connectionTypes.BaseConnection`
            Connection whose dimensions (if it has any) are expanded.

        Returns
        -------
        dimensions : `list` [`str`]
            Expanded dimension names, with "skypix" appended separately when
            present since it is not a dimension the universe can expand.
        """
        dimension_set = set()
        if isinstance(connection, connectionTypes.DimensionedConnection):
            dimension_set = set(connection.dimensions)
        skypix_dim = []
        if "skypix" in dimension_set:
            dimension_set.remove("skypix")
            skypix_dim = ["skypix"]
        dimension_graph = universe.extract(dimension_set)
        return list(dimension_graph.names) + skypix_dim

    # Open a file if needed; remember whether we own it so we do not close
    # a file object handed to us by the caller.
    close = False
    if not hasattr(file, "write"):
        file = open(file, "w")
        close = True

    # try/finally ensures a file we opened is closed even if rendering fails.
    try:
        print("digraph Pipeline {", file=file)

        allDatasets: set[str | tuple[str, str]] = set()
        if isinstance(pipeline, Pipeline):
            pipeline = pipeline.toExpandedPipeline()

        # The next two lines are a workaround until DM-29658 at which time
        # metadata connections should start working with the above code
        labelToTaskName = {}
        metadataNodesToLink = set()

        # Loop-invariant: compile the metadata-dataset pattern once rather
        # than once per task iteration.
        metadataRePattern = re.compile("^(.*)_metadata$")

        for idx, taskDef in enumerate(sorted(pipeline, key=lambda x: x.label)):
            # node for a task
            taskNodeName = "task{}".format(idx)

            # next line is workaround until DM-29658
            labelToTaskName[taskDef.label] = taskNodeName

            _renderTaskNode(taskNodeName, taskDef, file, None)

            for attr in sorted(iterConnections(taskDef.connections, "inputs"), key=lambda x: x.name):
                if attr.name not in allDatasets:
                    dimensions = expand_dimensions(attr)
                    _renderDSTypeNode(attr.name, dimensions, file)
                    allDatasets.add(attr.name)
                nodeName, component = DatasetType.splitDatasetTypeName(attr.name)
                _renderEdge(attr.name, taskNodeName, file)
                # connect component dataset types to the composite type that
                # produced it
                if component is not None and (nodeName, attr.name) not in allDatasets:
                    _renderEdge(nodeName, attr.name, file)
                    allDatasets.add((nodeName, attr.name))
                    if nodeName not in allDatasets:
                        dimensions = expand_dimensions(attr)
                        _renderDSTypeNode(nodeName, dimensions, file)
                # The next if block is a workaround until DM-29658 at which
                # time metadata connections should start working with the
                # above code
                if (match := metadataRePattern.match(attr.name)) is not None:
                    matchTaskLabel = match.group(1)
                    metadataNodesToLink.add((matchTaskLabel, attr.name))

            for attr in sorted(
                iterConnections(taskDef.connections, "prerequisiteInputs"), key=lambda x: x.name
            ):
                if attr.name not in allDatasets:
                    dimensions = expand_dimensions(attr)
                    _renderDSTypeNode(attr.name, dimensions, file)
                    allDatasets.add(attr.name)
                # use dashed line for prerequisite edges to distinguish them
                _renderEdge(attr.name, taskNodeName, file, style="dashed")

            for attr in sorted(iterConnections(taskDef.connections, "outputs"), key=lambda x: x.name):
                if attr.name not in allDatasets:
                    dimensions = expand_dimensions(attr)
                    _renderDSTypeNode(attr.name, dimensions, file)
                    allDatasets.add(attr.name)
                _renderEdge(taskNodeName, attr.name, file)

        # This for loop is a workaround until DM-29658 at which time metadata
        # connections should start working with the above code
        for matchLabel, dsTypeName in metadataNodesToLink:
            # only render an edge to metadata if the label is part of the
            # current graph
            if (result := labelToTaskName.get(matchLabel)) is not None:
                _renderEdge(result, dsTypeName, file)

        print("}", file=file)
    finally:
        if close:
            file.close()