Coverage for python/lsst/ctrl/mpexec/dotTools.py : 8%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of ctrl_mpexec.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Module defining a few functions that generate GraphViz diagrams from
pipelines or quantum graphs.
"""
26__all__ = ["graph2dot", "pipeline2dot"]
28# -------------------------------
29# Imports of standard modules --
30# -------------------------------
32# -----------------------------
33# Imports for other modules --
34# -----------------------------
35from lsst.daf.butler import DimensionUniverse
36from lsst.pipe.base import iterConnections, Pipeline
38# ----------------------------------
39# Local non-exported definitions --
40# ----------------------------------
42# Node styles indexed by node type.
43_STYLES = dict(
44 task=dict(shape="box", style="filled,bold", fillcolor="gray70"),
45 dsType=dict(shape="box", style="rounded,filled", fillcolor="gray90"),
46 dataset=dict(shape="box", style="rounded,filled", fillcolor="gray90"),
47)
def _renderNode(file, nodeName, style, labels):
    """Render a single GraphViz node statement.

    Parameters
    ----------
    file : file object
        Destination that the node statement is printed to.
    nodeName : `str`
        Identifier of the node inside the graph.
    style : `str`
        Key into ``_STYLES`` selecting the base attributes of the node.
    labels : `list` [`str`]
        Lines of text displayed inside the node.

    Raises
    ------
    KeyError
        Raised if ``style`` is not a key of ``_STYLES``.
    """
    def quote(value):
        # A DOT quoted string ends at the first unescaped double quote, so
        # embedded quotes have to be written as \" to keep the file valid.
        return str(value).replace('"', r'\"')

    # Lines are joined with a literal backslash-n (not a newline character);
    # that two-character sequence is what DOT treats as a line break.
    label = r'\n'.join(quote(line) for line in labels)
    attrib = dict(_STYLES[style], label=label)
    attrib = ", ".join(f'{key}="{val}"' for key, val in attrib.items())
    print(f'"{quote(nodeName)}" [{attrib}];', file=file)
def _renderTaskNode(nodeName, taskDef, file, idx=None):
    """Render GV node for a task.

    Parameters
    ----------
    nodeName : `str`
        Identifier of the node inside the graph.
    taskDef
        Object providing ``taskName`` and ``label`` attributes.
    file : file object
        Destination that the node statement is printed to.
    idx : optional
        When not `None`, shown in the node as an "index: ..." line.
    """
    # Show only the part of the task name after the last dot.
    shortName = taskDef.taskName.rpartition('.')[-1]
    lines = [shortName]
    if idx is not None:
        lines.append(f"index: {idx}")
    if taskDef.label:
        lines.append(f"label: {taskDef.label}")
    _renderNode(file, nodeName, "task", lines)
def _renderDSTypeNode(name, dimensions, file):
    """Render GV node for a dataset type.

    Parameters
    ----------
    name : `str`
        Dataset type name; used both as node identifier and as first line
        of the node text.
    dimensions : iterable of `str`
        Dimension names; when non-empty they are listed on a second line.
    file : file object
        Destination that the node statement is printed to.
    """
    lines = [name]
    if dimensions:
        dimText = ", ".join(dimensions)
        lines.append("Dimensions: " + dimText)
    _renderNode(file, name, "dsType", lines)
def _renderDSNode(nodeName, dsRef, file):
    """Render GV node for a dataset.

    Parameters
    ----------
    nodeName : `str`
        Identifier of the node inside the graph.
    dsRef
        Dataset reference providing ``datasetType.name`` and a ``dataId``
        mapping.
    file : file object
        Destination that the node statement is printed to.
    """
    lines = [dsRef.datasetType.name]
    # One "key = value" line per data ID entry, in sorted key order.
    for key in sorted(dsRef.dataId.keys()):
        lines.append(f"{key} = {dsRef.dataId[key]}")
    _renderNode(file, nodeName, "dataset", lines)
83def _renderEdge(fromName, toName, file, **kwargs):
84 """Render GV edge"""
85 if kwargs:
86 attrib = ", ".join([f'{key}="{val}"' for key, val in kwargs.items()])
87 print(f'"{fromName}" -> "{toName}" [{attrib}];', file=file)
88 else:
89 print(f'"{fromName}" -> "{toName}";', file=file)
92def _datasetRefId(dsRef):
93 """Make an identifying string for given ref"""
94 dsId = [dsRef.datasetType.name]
95 dsId += [f"{key} = {dsRef.dataId[key]}" for key in sorted(dsRef.dataId.keys())]
96 return ":".join(dsId)
def _makeDSNode(dsRef, allDatasetRefs, file):
    """Return the node name for a dataset, rendering the node on first use.

    Parameters
    ----------
    dsRef
        Dataset reference to find or create a node for.
    allDatasetRefs : `dict`
        Maps dataset identifying strings to node names; updated in place
        when a new node is created.
    file : file object
        Destination that a newly created node is printed to.

    Returns
    -------
    nodeName : `str`
        Name of the existing or newly created node.
    """
    refKey = _datasetRefId(dsRef)
    known = allDatasetRefs.get(refKey)
    if known is not None:
        return known

    # New nodes are numbered by how many datasets have been seen so far.
    fresh = f"dsref_{len(allDatasetRefs)}"
    allDatasetRefs[refKey] = fresh
    _renderDSNode(fresh, dsRef, file)
    return fresh
113# ------------------------
114# Exported definitions --
115# ------------------------
def graph2dot(qgraph, file):
    """Convert QuantumGraph into GraphViz digraph.

    This method is mostly for documentation/presentation purposes.

    Parameters
    ----------
    qgraph : `pipe.base.QuantumGraph`
        QuantumGraph instance. It is iterated; every item must provide
        ``taskDef`` and ``quanta`` attributes, and every quantum must provide
        ``predictedInputs`` and ``outputs`` mappings whose values are
        iterables of dataset references.
    file : str or file object
        File where GraphViz graph (DOT language) is written, can be a file
        name or file object. Anything with a ``write`` attribute is treated
        as a file object and is left open on return.

    Raises
    ------
    OSError
        Raised when output file cannot be opened.

    Notes
    -----
    Earlier documentation of this function also listed `ImportError` (task
    class cannot be imported); nothing in this function imports a task
    class directly - presumably it can propagate from ``qgraph``; verify.
    """
    # open a file if needed
    close = not hasattr(file, "write")
    if close:
        file = open(file, "w")

    # try/finally makes sure a file opened here is closed even when
    # rendering fails half way through
    try:
        print("digraph QuantumGraph {", file=file)

        # maps dataset identifying string to the name of its node, so that
        # every dataset is rendered once and shared between tasks
        allDatasetRefs = {}
        for taskId, nodes in enumerate(qgraph):

            taskDef = nodes.taskDef

            for qId, quantum in enumerate(nodes.quanta):

                # node for a task
                taskNodeName = f"task_{taskId}_{qId}"
                _renderTaskNode(taskNodeName, taskDef, file)

                # quantum inputs
                for dsRefs in quantum.predictedInputs.values():
                    for dsRef in dsRefs:
                        nodeName = _makeDSNode(dsRef, allDatasetRefs, file)
                        _renderEdge(nodeName, taskNodeName, file)

                # quantum outputs
                for dsRefs in quantum.outputs.values():
                    for dsRef in dsRefs:
                        nodeName = _makeDSNode(dsRef, allDatasetRefs, file)
                        _renderEdge(taskNodeName, nodeName, file)

        print("}", file=file)
    finally:
        if close:
            file.close()
def pipeline2dot(pipeline, file):
    """Convert Pipeline into GraphViz digraph.

    This method is mostly for documentation/presentation purposes.
    Unlike other methods this method does not validate graph consistency.

    Parameters
    ----------
    pipeline : `pipe.base.Pipeline`
        Pipeline description. A `Pipeline` instance is expanded with
        ``toExpandedPipeline()``; any other object is iterated as is and is
        expected to yield task definitions.
    file : str or file object
        File where GraphViz graph (DOT language) is written, can be a file
        name or file object. Anything with a ``write`` attribute is treated
        as a file object and is left open on return.

    Raises
    ------
    OSError
        Raised when output file cannot be opened.

    Notes
    -----
    Earlier documentation of this function also listed `ImportError` (task
    class cannot be imported) and `MissingTaskFactoryError` (TaskFactory is
    needed but not provided); neither is raised directly by the code here -
    presumably they can propagate from pipeline expansion; verify.
    """
    universe = DimensionUniverse()

    def expand_dimensions(dimensions):
        """Returns expanded list of dimensions, with special skypix treatment.

        The name "skypix" is taken out before the remaining names are passed
        to ``universe.extract`` and is appended to the end of the result.

        Parameters
        ----------
        dimensions : `list` [`str`]

        Returns
        -------
        dimensions : `list` [`str`]
        """
        dimensions = set(dimensions)
        skypix_dim = []
        if "skypix" in dimensions:
            dimensions.remove("skypix")
            skypix_dim = ["skypix"]
        dimensions = universe.extract(dimensions)
        return list(dimensions.names) + skypix_dim

    # Connection kinds in the order they are drawn: attribute name on the
    # connections object, whether the edge points away from the task, and
    # extra edge attributes.
    connectionKinds = (
        ("inputs", False, {}),
        # use dashed line for prerequisite edges to distinguish them
        ("prerequisiteInputs", False, {"style": "dashed"}),
        ("outputs", True, {}),
    )

    # open a file if needed
    close = not hasattr(file, "write")
    if close:
        file = open(file, "w")

    # try/finally makes sure a file opened here is closed even when
    # rendering fails half way through
    try:
        print("digraph Pipeline {", file=file)

        # names of dataset types that already have a node
        allDatasets = set()
        if isinstance(pipeline, Pipeline):
            pipeline = pipeline.toExpandedPipeline()
        for idx, taskDef in enumerate(pipeline):

            # node for a task
            taskNodeName = f"task{idx}"
            _renderTaskNode(taskNodeName, taskDef, file, idx)

            for connType, isOutput, edgeAttrs in connectionKinds:
                for attr in iterConnections(taskDef.connections, connType):
                    if attr.name not in allDatasets:
                        dimensions = expand_dimensions(attr.dimensions)
                        _renderDSTypeNode(attr.name, dimensions, file)
                        allDatasets.add(attr.name)
                    if isOutput:
                        _renderEdge(taskNodeName, attr.name, file, **edgeAttrs)
                    else:
                        _renderEdge(attr.name, taskNodeName, file, **edgeAttrs)

        print("}", file=file)
    finally:
        if close:
            file.close()