Coverage for python/lsst/ctrl/mpexec/dotTools.py: 9%
115 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-05 18:04 -0800
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-05 18:04 -0800
1# This file is part of ctrl_mpexec.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Module defining a few functions to generate GraphViz diagrams from
pipelines or quantum graphs.
"""
26__all__ = ["graph2dot", "pipeline2dot"]
28# -------------------------------
29# Imports of standard modules --
30# -------------------------------
32# -----------------------------
33# Imports for other modules --
34# -----------------------------
35from lsst.daf.butler import DimensionUniverse
36from lsst.pipe.base import iterConnections, Pipeline
38# ----------------------------------
39# Local non-exported definitions --
40# ----------------------------------
# Node styles indexed by node type.
_STYLES = dict(
    task=dict(shape="box", style="filled,bold", fillcolor="gray70"),
    quantum=dict(shape="box", style="filled,bold", fillcolor="gray70"),
    dsType=dict(shape="box", style="rounded,filled", fillcolor="gray90"),
    dataset=dict(shape="box", style="rounded,filled", fillcolor="gray90"),
)


def _renderNode(file, nodeName, style, labels):
    """Render a GraphViz node statement.

    Writes one line of the form ``"<nodeName>" [key="value", ...];`` to
    ``file``.  The attributes are those registered for ``style`` in
    ``_STYLES`` followed by a ``label`` attribute.

    Parameters
    ----------
    file : file object
        Destination stream, passed to `print` as its ``file`` argument.
    nodeName : `str`
        Node identifier.
    style : `str`
        Key into ``_STYLES`` selecting the node attributes.
    labels : iterable of `str`
        Label lines; they are joined with a literal backslash-``n``
        sequence.

    Raises
    ------
    KeyError
        Raised if ``style`` is not a key of ``_STYLES``.
    """

    def quote(text):
        # Everything is emitted inside a double-quoted DOT string, so an
        # embedded double quote must be escaped or it would terminate the
        # string early.  Backslashes are deliberately left alone so that the
        # literal "\n" label separator passes through unchanged.
        return str(text).replace('"', r"\"")

    label = r"\n".join(labels)
    attrib = dict(_STYLES[style], label=label)
    attrib = ", ".join(f'{key}="{quote(val)}"' for key, val in attrib.items())
    print(f'"{quote(nodeName)}" [{attrib}];', file=file)
def _renderTaskNode(nodeName, taskDef, file, idx=None):
    """Render a GraphViz node describing a task.

    The label consists of the task label and task name, optionally followed
    by an ``index: N`` line (when ``idx`` is not `None`) and a
    ``dimensions: ...`` line (when ``taskDef.connections`` is truthy).

    Parameters
    ----------
    nodeName : `str`
        Node identifier.
    taskDef : object
        Task definition; its ``label``, ``taskName`` and ``connections``
        attributes are read.
    file : file object
        Destination stream.
    idx : `int`, optional
        Index of the task, shown in the label when given.
    """
    lines = [taskDef.label, taskDef.taskName]
    if idx is not None:
        lines.append(f"index: {idx}")
    connections = taskDef.connections
    if connections:
        # Join the names by hand rather than formatting the collection, which
        # would add visually noisy quotes around every name.
        lines.append("dimensions: " + ", ".join(connections.dimensions))
    _renderNode(file, nodeName, "task", lines)
def _renderQuantumNode(nodeName, taskDef, quantumNode, file):
    """Render a GraphViz node describing a single quantum.

    The label lists the quantum node ID, the task label and then one
    ``key = value`` line per data ID key, in sorted key order.

    Parameters
    ----------
    nodeName : `str`
        Node identifier.
    taskDef : object
        Task definition; only its ``label`` attribute is read.
    quantumNode : object
        Quantum graph node; its ``nodeId`` and ``quantum.dataId`` are read.
    file : file object
        Destination stream.
    """
    lines = [f"{quantumNode.nodeId}", taskDef.label]
    dataId = quantumNode.quantum.dataId
    for key in sorted(dataId.keys()):
        lines.append(f"{key} = {dataId[key]}")
    _renderNode(file, nodeName, "quantum", lines)
def _renderDSTypeNode(name, dimensions, file):
    """Render a GraphViz node describing a dataset type.

    The dataset type name is used both as the node identifier and as the
    first label line.

    Parameters
    ----------
    name : `str`
        Dataset type name.
    dimensions : iterable of `str`
        Dimension names; when non-empty they are listed on a second label
        line.
    file : file object
        Destination stream.
    """
    lines = [name]
    if dimensions:
        joined = ", ".join(dimensions)
        lines.append("Dimensions: " + joined)
    _renderNode(file, name, "dsType", lines)
def _renderDSNode(nodeName, dsRef, file):
    """Render a GraphViz node describing a dataset.

    The label lists the dataset type name, the ``repr`` of the run and then
    one ``key = value`` line per data ID key, in sorted key order.

    Parameters
    ----------
    nodeName : `str`
        Node identifier.
    dsRef : object
        Dataset reference; its ``datasetType.name``, ``run`` and ``dataId``
        attributes are read.
    file : file object
        Destination stream.
    """
    lines = [dsRef.datasetType.name, f"run: {dsRef.run!r}"]
    for key in sorted(dsRef.dataId.keys()):
        lines.append(f"{key} = {dsRef.dataId[key]}")
    _renderNode(file, nodeName, "dataset", lines)
def _renderEdge(fromName, toName, file, **kwargs):
    """Render a GraphViz edge statement.

    Parameters
    ----------
    fromName : `str`
        Identifier of the node the edge starts at.
    toName : `str`
        Identifier of the node the edge points to.
    file : file object
        Destination stream.
    **kwargs
        Edge attributes; when any are given they are written as a
        ``[key="value", ...]`` list after the edge.
    """
    suffix = ""
    if kwargs:
        pairs = [f'{key}="{val}"' for key, val in kwargs.items()]
        suffix = " [" + ", ".join(pairs) + "]"
    print(f'"{fromName}" -> "{toName}"{suffix};', file=file)
def _datasetRefId(dsRef):
    """Build an identifying string for a dataset reference.

    Parameters
    ----------
    dsRef : object
        Dataset reference; its ``datasetType.name`` and ``dataId`` attributes
        are read.

    Returns
    -------
    refId : `str`
        Dataset type name followed by ``key = value`` items for every data
        ID key in sorted order, all separated by colons.
    """
    parts = [dsRef.datasetType.name]
    for key in sorted(dsRef.dataId.keys()):
        parts.append(f"{key} = {dsRef.dataId[key]}")
    return ":".join(parts)
def _makeDSNode(dsRef, allDatasetRefs, file):
    """Return the node name for a dataset, rendering the node on first use.

    Parameters
    ----------
    dsRef : object
        Dataset reference.
    allDatasetRefs : `dict` [`str`, `str`]
        Mapping from dataset identifying string to node name for every
        dataset rendered so far; updated in place when a new node is made.
    file : file object
        Destination stream.

    Returns
    -------
    nodeName : `str`
        Name of the (possibly newly rendered) node.
    """
    refKey = _datasetRefId(dsRef)
    known = allDatasetRefs.get(refKey)
    if known is not None:
        return known

    # Nodes are numbered in order of first appearance.
    nodeName = f"dsref_{len(allDatasetRefs)}"
    allDatasetRefs[refKey] = nodeName
    _renderDSNode(nodeName, dsRef, file)
    return nodeName
124# ------------------------
125# Exported definitions --
126# ------------------------
def graph2dot(qgraph, file):
    """Convert QuantumGraph into GraphViz digraph.

    This method is mostly for documentation/presentation purposes.

    Parameters
    ----------
    qgraph : `pipe.base.QuantumGraph`
        QuantumGraph instance.
    file : `str` or file object
        File where GraphViz graph (DOT language) is written, can be a file
        name or file object.  An object with a ``write`` attribute is used
        as-is and left open; anything else is passed to `open` for writing
        and closed before this function returns or raises.

    Raises
    ------
    OSError
        Raised when output file cannot be opened.
    ImportError
        Raised when task class cannot be imported (as stated by the original
        documentation; not raised directly by this function).
    """
    # open a file if needed
    close = False
    if not hasattr(file, "write"):
        file = open(file, "w")
        close = True

    # The try/finally guarantees that a file opened here is closed even if
    # rendering fails part way through.
    try:
        print("digraph QuantumGraph {", file=file)

        # Maps dataset identifying string to the name of its rendered node so
        # that each dataset is drawn only once.
        allDatasetRefs = {}
        for taskId, taskDef in enumerate(qgraph.taskGraph):
            quanta = qgraph.getNodesForTask(taskDef)
            for qId, quantumNode in enumerate(quanta):
                # node for a quantum of this task
                taskNodeName = f"task_{taskId}_{qId}"
                _renderQuantumNode(taskNodeName, taskDef, quantumNode, file)

                # quantum inputs: edges point from dataset to quantum
                for dsRefs in quantumNode.quantum.inputs.values():
                    for dsRef in dsRefs:
                        nodeName = _makeDSNode(dsRef, allDatasetRefs, file)
                        _renderEdge(nodeName, taskNodeName, file)

                # quantum outputs: edges point from quantum to dataset
                for dsRefs in quantumNode.quantum.outputs.values():
                    for dsRef in dsRefs:
                        nodeName = _makeDSNode(dsRef, allDatasetRefs, file)
                        _renderEdge(taskNodeName, nodeName, file)

        print("}", file=file)
    finally:
        if close:
            file.close()
def pipeline2dot(pipeline, file):
    """Convert Pipeline into GraphViz digraph.

    This method is mostly for documentation/presentation purposes.
    Unlike other methods this method does not validate graph consistency.

    Parameters
    ----------
    pipeline : `pipe.base.Pipeline`
        Pipeline description.  A `Pipeline` instance is expanded with
        ``toExpandedPipeline()``; any other object is iterated directly.
    file : `str` or file object
        File where GraphViz graph (DOT language) is written, can be a file
        name or file object.  An object with a ``write`` attribute is used
        as-is and left open; anything else is passed to `open` for writing
        and closed before this function returns or raises.

    Raises
    ------
    OSError
        Raised when output file cannot be opened.
    ImportError
        Raised when task class cannot be imported (as stated by the original
        documentation; not raised directly by this function).
    MissingTaskFactoryError
        Listed by the original documentation as raised when TaskFactory is
        needed but not provided.  NOTE(review): nothing in this function
        refers to a task factory; confirm whether this is still applicable.
    """
    universe = DimensionUniverse()

    def expand_dimensions(dimensions):
        """Return expanded list of dimensions, with special skypix treatment.

        ``"skypix"`` is removed before the remaining names are passed to
        ``universe.extract`` and is appended back to the end of the result.
        NOTE(review): presumably "skypix" is a placeholder the universe does
        not know about; confirm.

        Parameters
        ----------
        dimensions : `list` [`str`]

        Returns
        -------
        dimensions : `list` [`str`]
        """
        dimensions = set(dimensions)
        skypix_dim = []
        if "skypix" in dimensions:
            dimensions.remove("skypix")
            skypix_dim = ["skypix"]
        dimensions = universe.extract(dimensions)
        return list(dimensions.names) + skypix_dim

    # (connection type, whether the edge leaves the task, extra edge
    # attributes).  Prerequisite edges are dashed to distinguish them from
    # regular inputs.  The order here is the order of the rendered output.
    edgeSpecs = (
        ("inputs", False, {}),
        ("prerequisiteInputs", False, {"style": "dashed"}),
        ("outputs", True, {}),
    )

    # open a file if needed
    close = False
    if not hasattr(file, "write"):
        file = open(file, "w")
        close = True

    # The try/finally guarantees that a file opened here is closed even if
    # rendering fails part way through.
    try:
        print("digraph Pipeline {", file=file)

        # Names of dataset types that already have a node.
        allDatasets = set()
        if isinstance(pipeline, Pipeline):
            pipeline = pipeline.toExpandedPipeline()
        for idx, taskDef in enumerate(pipeline):
            # node for a task
            taskNodeName = f"task{idx}"
            _renderTaskNode(taskNodeName, taskDef, file, idx)

            for connectionType, isOutput, edgeAttrs in edgeSpecs:
                for attr in iterConnections(taskDef.connections, connectionType):
                    if attr.name not in allDatasets:
                        dimensions = expand_dimensions(attr.dimensions)
                        _renderDSTypeNode(attr.name, dimensions, file)
                        allDatasets.add(attr.name)
                    if isOutput:
                        _renderEdge(taskNodeName, attr.name, file, **edgeAttrs)
                    else:
                        _renderEdge(attr.name, taskNodeName, file, **edgeAttrs)

        print("}", file=file)
    finally:
        if close:
            file.close()