Coverage for python/lsst/ctrl/mpexec/dotTools.py: 8%
133 statements
« prev ^ index » next coverage.py v6.4, created at 2022-06-01 12:18 +0000
« prev ^ index » next coverage.py v6.4, created at 2022-06-01 12:18 +0000
1# This file is part of ctrl_mpexec.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Helpers that turn pipelines and quantum graphs into GraphViz diagrams.

Both public functions write a directed graph in the DOT language to a file
name or an already-open file object.
"""

__all__ = ["graph2dot", "pipeline2dot"]
28# -------------------------------
29# Imports of standard modules --
30# -------------------------------
31import re
33# -----------------------------
34# Imports for other modules --
35# -----------------------------
36from lsst.daf.butler import DatasetType, DimensionUniverse
37from lsst.pipe.base import Pipeline, iterConnections
39# ----------------------------------
40# Local non-exported definitions --
41# ----------------------------------
# Base GraphViz node attributes keyed by the node type name passed to
# _renderNode. Quanta reuse the task look; concrete datasets reuse the
# dataset-type look. Attribute order here is the order they are written.
_STYLES = {
    "task": {"shape": "box", "style": "filled,bold", "fillcolor": "gray70"},
    "quantum": {"shape": "box", "style": "filled,bold", "fillcolor": "gray70"},
    "dsType": {"shape": "box", "style": "rounded,filled", "fillcolor": "gray90"},
    "dataset": {"shape": "box", "style": "rounded,filled", "fillcolor": "gray90"},
}
def _renderNode(file, nodeName, style, labels):
    """Render a GraphViz (DOT) node statement.

    Parameters
    ----------
    file : file object
        Stream the statement is written to with ``print``.
    nodeName : `str`
        DOT identifier of the node, written inside double quotes as-is
        (assumed to be a quote-free identifier -- confirm if callers change).
    style : `str`
        Key into ``_STYLES`` selecting the base attributes of the node.
    labels : iterable of `str`
        Lines of the node label, shown one per line.

    Notes
    -----
    Attribute values are escaped so that embedded double quotes (for
    example in a data ID value, or in the ``repr`` of a run name that
    contains an apostrophe, which Python renders with double quotes)
    cannot terminate the quoted DOT string early.
    """

    def quoted(text):
        # In DOT quoted strings the only escaped character is the double
        # quote (written as \"), so that is all that needs escaping here.
        return '"' + str(text).replace('"', '\\"') + '"'

    # The two characters backslash + "n" are the DOT label escape for a
    # line break; this is a raw string, not a Python newline.
    label = r"\n".join(labels)
    attrib = dict(_STYLES[style], label=label)
    attribStr = ", ".join(f"{key}={quoted(val)}" for key, val in attrib.items())
    print(f'"{nodeName}" [{attribStr}];', file=file)
def _renderTaskNode(nodeName, taskDef, file, idx=None):
    """Write the GraphViz node describing one pipeline task.

    The label lists the task label and task name, then ``index: <idx>``
    when ``idx`` is given, then the task's sorted dimensions when the task
    has connections.
    """
    indexLines = [] if idx is None else [f"index: {idx}"]
    labels = [taskDef.label, taskDef.taskName, *indexLines]
    connections = taskDef.connections
    if connections:
        # Join the names ourselves rather than printing the collection,
        # whose repr would wrap every dimension name in quotes.
        labels.append("dimensions: " + ", ".join(sorted(connections.dimensions)))
    _renderNode(file, nodeName, "task", labels)
def _renderQuantumNode(nodeName, taskDef, quantumNode, file):
    """Write the GraphViz node describing a single quantum.

    The label shows the quantum node ID, the task label, and one
    ``key = value`` line per data ID key in sorted key order.
    """
    header = [f"{quantumNode.nodeId}", taskDef.label]
    dataId = quantumNode.quantum.dataId
    dataIdLines = [f"{key} = {dataId[key]}" for key in sorted(dataId.keys())]
    _renderNode(file, nodeName, "quantum", header + dataIdLines)
def _renderDSTypeNode(name, dimensions, file):
    """Write the GraphViz node for a dataset type.

    The dataset type name serves as both the node name and the first label
    line; a second line lists the sorted dimensions when there are any.
    """
    dimensionLines = ["Dimensions: " + ", ".join(sorted(dimensions))] if dimensions else []
    _renderNode(file, name, "dsType", [name, *dimensionLines])
def _renderDSNode(nodeName, dsRef, file):
    """Write the GraphViz node for a concrete dataset reference.

    The label shows the dataset type name, the ``repr`` of the run, and
    one ``key = value`` line per data ID key in sorted key order.
    """
    header = [dsRef.datasetType.name, f"run: {dsRef.run!r}"]
    dataId = dsRef.dataId
    dataIdLines = [f"{key} = {dataId[key]}" for key in sorted(dataId.keys())]
    _renderNode(file, nodeName, "dataset", header + dataIdLines)
def _renderEdge(fromName, toName, file, **kwargs):
    """Write a directed GraphViz edge from ``fromName`` to ``toName``.

    Keyword arguments become double-quoted edge attributes in the order
    given (e.g. ``style="dashed"``). Without any, no attribute list is
    written.
    """
    edge = f'"{fromName}" -> "{toName}"'
    if not kwargs:
        print(f"{edge};", file=file)
        return
    attribStr = ", ".join(f'{key}="{val}"' for key, val in kwargs.items())
    print(f"{edge} [{attribStr}];", file=file)
def _datasetRefId(dsRef):
    """Build a string key identifying ``dsRef`` by type name and data ID.

    The key is the dataset type name followed by one ``key = value`` entry
    per data ID key in sorted key order, all joined with ``":"``. The run
    is not part of the key.
    """
    dataId = dsRef.dataId
    dataIdParts = [f"{key} = {dataId[key]}" for key in sorted(dataId.keys())]
    return ":".join([dsRef.datasetType.name, *dataIdParts])
def _makeDSNode(dsRef, allDatasetRefs, file):
    """Return the DOT node name for ``dsRef``, rendering the node if new.

    ``allDatasetRefs`` maps the identity string produced by
    `_datasetRefId` to the node name already assigned, and is updated in
    place. A new node is named ``dsref_<n>``, where ``<n>`` is the number
    of entries in the mapping before it is added.
    """
    refKey = _datasetRefId(dsRef)
    if refKey in allDatasetRefs:
        return allDatasetRefs[refKey]
    newName = f"dsref_{len(allDatasetRefs)}"
    allDatasetRefs[refKey] = newName
    _renderDSNode(newName, dsRef, file)
    return newName
126# ------------------------
127# Exported definitions --
128# ------------------------
def graph2dot(qgraph, file):
    """Convert QuantumGraph into GraphViz digraph.

    This method is mostly for documentation/presentation purposes.

    Parameters
    ----------
    qgraph : `pipe.base.QuantumGraph`
        QuantumGraph instance.
    file : `str` or file object
        File where GraphViz graph (DOT language) is written, can be a file name
        or file object. A file object passed in is left open; a file opened
        here from a name is always closed, even when rendering fails.

    Raises
    ------
    OSError
        Raised when output file cannot be opened.
    ImportError
        Raised when task class cannot be imported.
    """
    # open a file if needed
    close = False
    if not hasattr(file, "write"):
        file = open(file, "w")
        close = True

    # try/finally so a file we opened is closed (and flushed) even if any
    # rendering step below raises.
    try:
        print("digraph QuantumGraph {", file=file)

        # Maps dataset identity string -> DOT node name so each dataset is
        # rendered once and shared between its producer and consumers.
        allDatasetRefs = {}
        for taskId, taskDef in enumerate(qgraph.taskGraph):
            quanta = qgraph.getNodesForTask(taskDef)
            for qId, quantumNode in enumerate(quanta):
                # node for a quantum; the name encodes task and quantum index
                taskNodeName = f"task_{taskId}_{qId}"
                _renderQuantumNode(taskNodeName, taskDef, quantumNode, file)

                # quantum inputs: edges dataset -> quantum
                for dsRefs in quantumNode.quantum.inputs.values():
                    for dsRef in dsRefs:
                        nodeName = _makeDSNode(dsRef, allDatasetRefs, file)
                        _renderEdge(nodeName, taskNodeName, file)

                # quantum outputs: edges quantum -> dataset
                for dsRefs in quantumNode.quantum.outputs.values():
                    for dsRef in dsRefs:
                        nodeName = _makeDSNode(dsRef, allDatasetRefs, file)
                        _renderEdge(taskNodeName, nodeName, file)

        print("}", file=file)
    finally:
        if close:
            file.close()
def pipeline2dot(pipeline, file):
    """Convert Pipeline into GraphViz digraph.

    This method is mostly for documentation/presentation purposes.
    Unlike other methods this method does not validate graph consistency.

    Parameters
    ----------
    pipeline : `pipe.base.Pipeline`
        Pipeline description.
    file : `str` or file object
        File where GraphViz graph (DOT language) is written, can be a file name
        or file object. A file object passed in is left open; a file opened
        here from a name is always closed, even when rendering fails.

    Raises
    ------
    OSError
        Raised when output file cannot be opened.
    ImportError
        Raised when task class cannot be imported.
    MissingTaskFactoryError
        Raised when TaskFactory is needed but not provided.
    """
    universe = DimensionUniverse()

    def expand_dimensions(dimensions):
        """Returns expanded list of dimensions, with special skypix treatment.

        ``"skypix"`` is removed before expansion through the dimension
        universe (presumably because it is not a concrete dimension there --
        confirm) and is appended unchanged at the end of the result.

        Parameters
        ----------
        dimensions : `list` [`str`]

        Returns
        -------
        dimensions : `list` [`str`]
        """
        dimensions = set(dimensions)
        skypix_dim = []
        if "skypix" in dimensions:
            dimensions.remove("skypix")
            skypix_dim = ["skypix"]
        dimensions = universe.extract(dimensions)
        return list(dimensions.names) + skypix_dim

    # Matches "<task label>_metadata" input names. Compiled once instead of
    # once per task. Part of the workaround until DM-29658.
    metadataRePattern = re.compile("^(.*)_metadata$")

    # open a file if needed
    close = False
    if not hasattr(file, "write"):
        file = open(file, "w")
        close = True

    # try/finally so a file we opened is closed (and flushed) even if
    # pipeline expansion or rendering raises.
    try:
        print("digraph Pipeline {", file=file)

        # Names of dataset type nodes already rendered, plus
        # (composite, component) tuples for composite -> component edges
        # already drawn.
        allDatasets = set()
        if isinstance(pipeline, Pipeline):
            pipeline = pipeline.toExpandedPipeline()

        # The next two lines are a workaround until DM-29658 at which time
        # metadata connections should start working with the above code
        labelToTaskName = {}
        metadataNodesToLink = set()

        for idx, taskDef in enumerate(sorted(pipeline, key=lambda x: x.label)):

            # node for a task
            taskNodeName = f"task{idx}"

            # next line is workaround until DM-29658
            labelToTaskName[taskDef.label] = taskNodeName

            _renderTaskNode(taskNodeName, taskDef, file, None)

            for attr in sorted(iterConnections(taskDef.connections, "inputs"), key=lambda x: x.name):
                if attr.name not in allDatasets:
                    dimensions = expand_dimensions(attr.dimensions)
                    _renderDSTypeNode(attr.name, dimensions, file)
                    allDatasets.add(attr.name)
                nodeName, component = DatasetType.splitDatasetTypeName(attr.name)
                _renderEdge(attr.name, taskNodeName, file)
                # connect component dataset types to the composite type that
                # produced it
                if component is not None and (nodeName, attr.name) not in allDatasets:
                    _renderEdge(nodeName, attr.name, file)
                    allDatasets.add((nodeName, attr.name))
                    if nodeName not in allDatasets:
                        dimensions = expand_dimensions(attr.dimensions)
                        _renderDSTypeNode(nodeName, dimensions, file)
                        # Record the composite so it is not rendered again
                        # for other components or when it appears later as
                        # an input/output in its own right.
                        allDatasets.add(nodeName)
                # The next if block is a workaround until DM-29658 at which time
                # metadata connections should start working with the above code
                if (match := metadataRePattern.match(attr.name)) is not None:
                    matchTaskLabel = match.group(1)
                    metadataNodesToLink.add((matchTaskLabel, attr.name))

            for attr in sorted(iterConnections(taskDef.connections, "prerequisiteInputs"), key=lambda x: x.name):
                if attr.name not in allDatasets:
                    dimensions = expand_dimensions(attr.dimensions)
                    _renderDSTypeNode(attr.name, dimensions, file)
                    allDatasets.add(attr.name)
                # use dashed line for prerequisite edges to distinguish them
                _renderEdge(attr.name, taskNodeName, file, style="dashed")

            for attr in sorted(iterConnections(taskDef.connections, "outputs"), key=lambda x: x.name):
                if attr.name not in allDatasets:
                    dimensions = expand_dimensions(attr.dimensions)
                    _renderDSTypeNode(attr.name, dimensions, file)
                    allDatasets.add(attr.name)
                _renderEdge(taskNodeName, attr.name, file)

        # This for loop is a workaround until DM-29658 at which time metadata
        # connections should start working with the above code. Iterate in
        # sorted order: plain set iteration order depends on string hash
        # randomization, which would make the DOT output differ between runs.
        for matchLabel, dsTypeName in sorted(metadataNodesToLink):
            # only render an edge to metadata if the label is part of the
            # current graph
            if (result := labelToTaskName.get(matchLabel)) is not None:
                _renderEdge(result, dsTypeName, file)

        print("}", file=file)
    finally:
        if close:
            file.close()