Coverage for python/lsst/ctrl/mpexec/dotTools.py: 10%
142 statements
« prev ^ index » next coverage.py v7.2.2, created at 2023-03-23 23:17 +0000
1# This file is part of ctrl_mpexec.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22"""Module defining few methods to generate GraphViz diagrams from pipelines
23or quantum graphs.
24"""
26from __future__ import annotations
28__all__ = ["graph2dot", "pipeline2dot"]
30# -------------------------------
31# Imports of standard modules --
32# -------------------------------
33import io
34import re
35from typing import TYPE_CHECKING, Any, Iterable, Union
37# -----------------------------
38# Imports for other modules --
39# -----------------------------
40from lsst.daf.butler import DatasetType, DimensionUniverse
41from lsst.pipe.base import Pipeline, connectionTypes, iterConnections
43if TYPE_CHECKING:
44 from lsst.daf.butler import DatasetRef
45 from lsst.pipe.base import QuantumGraph, QuantumNode, TaskDef
47# ----------------------------------
48# Local non-exported definitions --
49# ----------------------------------
51# Node styles indexed by node type.
52_STYLES = dict(
53 task=dict(shape="box", style="filled,bold", fillcolor="gray70"),
54 quantum=dict(shape="box", style="filled,bold", fillcolor="gray70"),
55 dsType=dict(shape="box", style="rounded,filled", fillcolor="gray90"),
56 dataset=dict(shape="box", style="rounded,filled", fillcolor="gray90"),
57)
60def _renderNode(file: io.TextIOBase, nodeName: str, style: str, labels: list[str]) -> None:
61 """Render GV node"""
62 label = r"\n".join(labels)
63 attrib_dict = dict(_STYLES[style], label=label)
64 attrib = ", ".join([f'{key}="{val}"' for key, val in attrib_dict.items()])
65 print(f'"{nodeName}" [{attrib}];', file=file)
def _renderTaskNode(nodeName: str, taskDef: TaskDef, file: io.TextIOBase, idx: Any = None) -> None:
    """Write a GraphViz node describing one task definition.

    The label shows the task label and class name, optionally an index, and
    the task's dimensions when connections are defined.
    """
    labels = [taskDef.label, taskDef.taskName]
    if idx is not None:
        labels.append(f"index: {idx}")
    connections = taskDef.connections
    if connections:
        # Join the dimension names ourselves instead of printing the
        # collection directly, to avoid visually noisy quotes in the label.
        labels.append("dimensions: " + ", ".join(sorted(connections.dimensions)))
    _renderNode(file, nodeName, "task", labels)
def _renderQuantumNode(
    nodeName: str, taskDef: TaskDef, quantumNode: QuantumNode, file: io.TextIOBase
) -> None:
    """Write a GraphViz node for a single quantum of a task.

    The label contains the quantum's node ID, the task label, and every
    key/value pair of the quantum's data ID in sorted key order.
    """
    dataId = quantumNode.quantum.dataId
    assert dataId is not None, "Quantum DataId cannot be None"
    labels = [f"{quantumNode.nodeId}", taskDef.label]
    for key in sorted(dataId.keys()):
        labels.append(f"{key} = {dataId[key]}")
    _renderNode(file, nodeName, "quantum", labels)
def _renderDSTypeNode(name: str, dimensions: list[str], file: io.TextIOBase) -> None:
    """Write a GraphViz node for a dataset type.

    The node is keyed by the dataset type name; dimensions, when present,
    appear on a second label line in sorted order.
    """
    if dimensions:
        labels = [name, "Dimensions: " + ", ".join(sorted(dimensions))]
    else:
        labels = [name]
    _renderNode(file, name, "dsType", labels)
def _renderDSNode(nodeName: str, dsRef: DatasetRef, file: io.TextIOBase) -> None:
    """Write a GraphViz node for one concrete dataset reference.

    The label shows the dataset type name, the run, and the full data ID in
    sorted key order.
    """
    dataId = dsRef.dataId
    labels = [dsRef.datasetType.name, f"run: {dsRef.run!r}"]
    labels += [f"{key} = {dataId[key]}" for key in sorted(dataId.keys())]
    _renderNode(file, nodeName, "dataset", labels)
106def _renderEdge(fromName: str, toName: str, file: io.TextIOBase, **kwargs: Any) -> None:
107 """Render GV edge"""
108 if kwargs:
109 attrib = ", ".join([f'{key}="{val}"' for key, val in kwargs.items()])
110 print(f'"{fromName}" -> "{toName}" [{attrib}];', file=file)
111 else:
112 print(f'"{fromName}" -> "{toName}";', file=file)
115def _datasetRefId(dsRef: DatasetRef) -> str:
116 """Make an identifying string for given ref"""
117 dsId = [dsRef.datasetType.name]
118 dsId.extend(f"{key} = {dsRef.dataId[key]}" for key in sorted(dsRef.dataId.keys()))
119 return ":".join(dsId)
def _makeDSNode(dsRef: DatasetRef, allDatasetRefs: dict[str, str], file: io.TextIOBase) -> str:
    """Return the GV node name for ``dsRef``, creating the node on first use.

    ``allDatasetRefs`` maps dataset-ref identifiers to already-rendered node
    names; a new ``dsref_N`` node is rendered only for unseen refs.

    Returns node name.
    """
    refId = _datasetRefId(dsRef)
    try:
        return allDatasetRefs[refId]
    except KeyError:
        pass
    nodeName = f"dsref_{len(allDatasetRefs)}"
    allDatasetRefs[refId] = nodeName
    _renderDSNode(nodeName, dsRef, file)
    return nodeName
137# ------------------------
138# Exported definitions --
139# ------------------------
def graph2dot(qgraph: QuantumGraph, file: Any) -> None:
    """Convert QuantumGraph into GraphViz digraph.

    This method is mostly for documentation/presentation purposes.

    Parameters
    ----------
    qgraph: `pipe.base.QuantumGraph`
        QuantumGraph instance.
    file : str or file object
        File where GraphViz graph (DOT language) is written, can be a file name
        or file object.

    Raises
    ------
    `OSError` is raised when output file cannot be open.
    `ImportError` is raised when task class cannot be imported.
    """
    # open a file if needed
    close = False
    if not hasattr(file, "write"):
        file = open(file, "w")
        close = True

    try:
        print("digraph QuantumGraph {", file=file)

        # Maps dataset-ref identifier -> rendered node name, shared across all
        # quanta so each dataset is rendered exactly once.
        allDatasetRefs: dict[str, str] = {}
        for taskId, taskDef in enumerate(qgraph.taskGraph):
            quanta = qgraph.getNodesForTask(taskDef)
            for qId, quantumNode in enumerate(quanta):
                # node for a task
                taskNodeName = "task_{}_{}".format(taskId, qId)
                _renderQuantumNode(taskNodeName, taskDef, quantumNode, file)

                # quantum inputs: dataset -> quantum edges
                for dsRefs in quantumNode.quantum.inputs.values():
                    for dsRef in dsRefs:
                        nodeName = _makeDSNode(dsRef, allDatasetRefs, file)
                        _renderEdge(nodeName, taskNodeName, file)

                # quantum outputs: quantum -> dataset edges
                for dsRefs in quantumNode.quantum.outputs.values():
                    for dsRef in dsRefs:
                        nodeName = _makeDSNode(dsRef, allDatasetRefs, file)
                        _renderEdge(taskNodeName, nodeName, file)

        print("}", file=file)
    finally:
        # Close only the file we opened ourselves, and do so even if rendering
        # raised, so a caller-supplied filename never leaks an open handle.
        if close:
            file.close()
def pipeline2dot(pipeline: Union[Pipeline, Iterable[TaskDef]], file: Any) -> None:
    """Convert Pipeline into GraphViz digraph.

    This method is mostly for documentation/presentation purposes.
    Unlike other methods this method does not validate graph consistency.

    Parameters
    ----------
    pipeline : `pipe.base.Pipeline`
        Pipeline description.
    file : str or file object
        File where GraphViz graph (DOT language) is written, can be a file name
        or file object.

    Raises
    ------
    `OSError` is raised when output file cannot be open.
    `ImportError` is raised when task class cannot be imported.
    `MissingTaskFactoryError` is raised when TaskFactory is needed but not
    provided.
    """
    universe = DimensionUniverse()

    def expand_dimensions(connection: connectionTypes.BaseConnection) -> list[str]:
        """Return expanded list of dimensions, with special skypix treatment.

        Parameters
        ----------
        connection : `connectionTypes.BaseConnection`
            Connection whose dimensions (if it has any) are expanded.

        Returns
        -------
        dimensions : `list` [`str`]
        """
        dimension_set = set()
        if isinstance(connection, connectionTypes.DimensionedConnection):
            dimension_set = set(connection.dimensions)
        # "skypix" is a placeholder, not a real dimension known to the
        # universe, so pull it out before expansion and append it back after.
        skypix_dim = []
        if "skypix" in dimension_set:
            dimension_set.remove("skypix")
            skypix_dim = ["skypix"]
        dimension_graph = universe.extract(dimension_set)
        return list(dimension_graph.names) + skypix_dim

    # open a file if needed
    close = False
    if not hasattr(file, "write"):
        file = open(file, "w")
        close = True

    try:
        print("digraph Pipeline {", file=file)

        allDatasets: set[Union[str, tuple[str, str]]] = set()
        if isinstance(pipeline, Pipeline):
            pipeline = pipeline.toExpandedPipeline()

        # The next two lines are a workaround until DM-29658 at which time
        # metadata connections should start working with the above code
        labelToTaskName = {}
        metadataNodesToLink = set()

        # Compiled once here; the pattern is invariant across tasks.
        metadataRePattern = re.compile("^(.*)_metadata$")

        for idx, taskDef in enumerate(sorted(pipeline, key=lambda x: x.label)):
            # node for a task
            taskNodeName = "task{}".format(idx)

            # next line is workaround until DM-29658
            labelToTaskName[taskDef.label] = taskNodeName

            _renderTaskNode(taskNodeName, taskDef, file, None)

            for attr in sorted(iterConnections(taskDef.connections, "inputs"), key=lambda x: x.name):
                if attr.name not in allDatasets:
                    dimensions = expand_dimensions(attr)
                    _renderDSTypeNode(attr.name, dimensions, file)
                    allDatasets.add(attr.name)
                nodeName, component = DatasetType.splitDatasetTypeName(attr.name)
                _renderEdge(attr.name, taskNodeName, file)
                # connect component dataset types to the composite type that
                # produced it
                if component is not None and (nodeName, attr.name) not in allDatasets:
                    _renderEdge(nodeName, attr.name, file)
                    allDatasets.add((nodeName, attr.name))
                    if nodeName not in allDatasets:
                        # NOTE(review): the composite node reuses the
                        # component connection's dimensions — confirm this is
                        # intended for the parent dataset type.
                        dimensions = expand_dimensions(attr)
                        _renderDSTypeNode(nodeName, dimensions, file)
                # The next if block is a workaround until DM-29658 at which
                # time metadata connections should start working with the
                # above code
                if (match := metadataRePattern.match(attr.name)) is not None:
                    matchTaskLabel = match.group(1)
                    metadataNodesToLink.add((matchTaskLabel, attr.name))

            for attr in sorted(
                iterConnections(taskDef.connections, "prerequisiteInputs"), key=lambda x: x.name
            ):
                if attr.name not in allDatasets:
                    dimensions = expand_dimensions(attr)
                    _renderDSTypeNode(attr.name, dimensions, file)
                    allDatasets.add(attr.name)
                # use dashed line for prerequisite edges to distinguish them
                _renderEdge(attr.name, taskNodeName, file, style="dashed")

            for attr in sorted(iterConnections(taskDef.connections, "outputs"), key=lambda x: x.name):
                if attr.name not in allDatasets:
                    dimensions = expand_dimensions(attr)
                    _renderDSTypeNode(attr.name, dimensions, file)
                    allDatasets.add(attr.name)
                _renderEdge(taskNodeName, attr.name, file)

        # This for loop is a workaround until DM-29658 at which time metadata
        # connections should start working with the above code
        for matchLabel, dsTypeName in metadataNodesToLink:
            # only render an edge to metadata if the label is part of the
            # current graph
            if (result := labelToTaskName.get(matchLabel)) is not None:
                _renderEdge(result, dsTypeName, file)

        print("}", file=file)
    finally:
        # Close only the file we opened ourselves, and do so even if rendering
        # raised, so a caller-supplied filename never leaks an open handle.
        if close:
            file.close()