Coverage for python/lsst/ctrl/mpexec/dotTools.py: 10%
142 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-02-09 11:21 +0000
1# This file is part of ctrl_mpexec.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22"""Module defining few methods to generate GraphViz diagrams from pipelines
23or quantum graphs.
24"""
26from __future__ import annotations
28__all__ = ["graph2dot", "pipeline2dot"]
30# -------------------------------
31# Imports of standard modules --
32# -------------------------------
33import io
34import re
35from typing import TYPE_CHECKING, Any, Iterable, Union
37# -----------------------------
38# Imports for other modules --
39# -----------------------------
40from lsst.daf.butler import DatasetType, DimensionUniverse
41from lsst.pipe.base import Pipeline, connectionTypes, iterConnections
43if TYPE_CHECKING:
44 from lsst.daf.butler import DatasetRef
45 from lsst.pipe.base import QuantumGraph, QuantumNode, TaskDef
47# ----------------------------------
48# Local non-exported definitions --
49# ----------------------------------
51# Node styles indexed by node type.
52_STYLES = dict(
53 task=dict(shape="box", style="filled,bold", fillcolor="gray70"),
54 quantum=dict(shape="box", style="filled,bold", fillcolor="gray70"),
55 dsType=dict(shape="box", style="rounded,filled", fillcolor="gray90"),
56 dataset=dict(shape="box", style="rounded,filled", fillcolor="gray90"),
57)
def _renderNode(file: io.TextIOBase, nodeName: str, style: str, labels: list[str]) -> None:
    """Write a single GraphViz node declaration to ``file``.

    The ``labels`` are joined with a literal backslash-n sequence so that
    GraphViz renders them as separate lines inside the node box.
    """
    attributes = dict(_STYLES[style])
    attributes["label"] = r"\n".join(labels)
    rendered = ", ".join(f'{name}="{value}"' for name, value in attributes.items())
    print(f'"{nodeName}" [{rendered}];', file=file)
def _renderTaskNode(nodeName: str, taskDef: TaskDef, file: io.TextIOBase, idx: Any = None) -> None:
    """Render GV node for a task"""
    labels = [taskDef.label, taskDef.taskName]
    if idx is not None:
        labels += [f"index: {idx}"]
    connections = taskDef.connections
    if connections:
        # join the dimension names ourselves: printing a collection of str
        # directly would produce visually noisy quotes
        labels += ["dimensions: " + ", ".join(sorted(connections.dimensions))]
    _renderNode(file, nodeName, "task", labels)
def _renderQuantumNode(
    nodeName: str, taskDef: TaskDef, quantumNode: QuantumNode, file: io.TextIOBase
) -> None:
    """Render GV node for a quantum"""
    dataId = quantumNode.quantum.dataId
    assert dataId is not None, "Quantum DataId cannot be None"
    labels = [f"{quantumNode.nodeId}", taskDef.label]
    for key in sorted(dataId.keys()):
        labels.append(f"{key} = {dataId[key]}")
    _renderNode(file, nodeName, "quantum", labels)
def _renderDSTypeNode(name: str, dimensions: list[str], file: io.TextIOBase) -> None:
    """Render GV node for a dataset type"""
    labels = [name]
    if dimensions:
        labels.append("Dimensions: " + ", ".join(sorted(dimensions)))
    _renderNode(file, name, "dsType", labels)
def _renderDSNode(nodeName: str, dsRef: DatasetRef, file: io.TextIOBase) -> None:
    """Render GV node for a dataset"""
    dataId = dsRef.dataId
    labels = [dsRef.datasetType.name, f"run: {dsRef.run!r}"]
    labels += [f"{key} = {dataId[key]}" for key in sorted(dataId.keys())]
    _renderNode(file, nodeName, "dataset", labels)
106def _renderEdge(fromName: str, toName: str, file: io.TextIOBase, **kwargs: Any) -> None:
107 """Render GV edge"""
108 if kwargs:
109 attrib = ", ".join([f'{key}="{val}"' for key, val in kwargs.items()])
110 print(f'"{fromName}" -> "{toName}" [{attrib}];', file=file)
111 else:
112 print(f'"{fromName}" -> "{toName}";', file=file)
115def _datasetRefId(dsRef: DatasetRef) -> str:
116 """Make an identifying string for given ref"""
117 dsId = [dsRef.datasetType.name]
118 dsId.extend(f"{key} = {dsRef.dataId[key]}" for key in sorted(dsRef.dataId.keys()))
119 return ":".join(dsId)
def _makeDSNode(dsRef: DatasetRef, allDatasetRefs: dict[str, str], file: io.TextIOBase) -> str:
    """Make new node for dataset if it does not exist.

    Returns node name.
    """
    refId = _datasetRefId(dsRef)
    if refId in allDatasetRefs:
        return allDatasetRefs[refId]
    # first time we see this ref: assign the next sequential node name
    nodeName = f"dsref_{len(allDatasetRefs)}"
    allDatasetRefs[refId] = nodeName
    _renderDSNode(nodeName, dsRef, file)
    return nodeName
137# ------------------------
138# Exported definitions --
139# ------------------------
def graph2dot(qgraph: QuantumGraph, file: Any) -> None:
    """Convert QuantumGraph into GraphViz digraph.

    This method is mostly for documentation/presentation purposes.

    Parameters
    ----------
    qgraph : `pipe.base.QuantumGraph`
        QuantumGraph instance.
    file : str or file object
        File where GraphViz graph (DOT language) is written, can be a file
        name or file object.

    Raises
    ------
    OSError
        Raised when output file cannot be opened.
    ImportError
        Raised when task class cannot be imported.
    """
    # open a file if needed; remember whether we own it so only files we
    # opened ourselves are closed
    close = False
    if not hasattr(file, "write"):
        file = open(file, "w")
        close = True

    # try/finally guarantees a file we opened is closed even if rendering
    # raises (previously it leaked on error)
    try:
        print("digraph QuantumGraph {", file=file)

        # maps dataset ref id -> GV node name, so each dataset is rendered once
        allDatasetRefs: dict[str, str] = {}
        for taskId, taskDef in enumerate(qgraph.taskGraph):
            quanta = qgraph.getNodesForTask(taskDef)
            for qId, quantumNode in enumerate(quanta):
                # node for a task
                taskNodeName = "task_{}_{}".format(taskId, qId)
                _renderQuantumNode(taskNodeName, taskDef, quantumNode, file)

                # quantum inputs: dataset -> quantum edges
                for dsRefs in quantumNode.quantum.inputs.values():
                    for dsRef in dsRefs:
                        nodeName = _makeDSNode(dsRef, allDatasetRefs, file)
                        _renderEdge(nodeName, taskNodeName, file)

                # quantum outputs: quantum -> dataset edges
                for dsRefs in quantumNode.quantum.outputs.values():
                    for dsRef in dsRefs:
                        nodeName = _makeDSNode(dsRef, allDatasetRefs, file)
                        _renderEdge(taskNodeName, nodeName, file)

        print("}", file=file)
    finally:
        if close:
            file.close()
def pipeline2dot(pipeline: Union[Pipeline, Iterable[TaskDef]], file: Any) -> None:
    """Convert Pipeline into GraphViz digraph.

    This method is mostly for documentation/presentation purposes.
    Unlike other methods this method does not validate graph consistency.

    Parameters
    ----------
    pipeline : `pipe.base.Pipeline`
        Pipeline description.
    file : str or file object
        File where GraphViz graph (DOT language) is written, can be a file
        name or file object.

    Raises
    ------
    OSError
        Raised when output file cannot be opened.
    ImportError
        Raised when task class cannot be imported.
    MissingTaskFactoryError
        Raised when TaskFactory is needed but not provided.
    """
    universe = DimensionUniverse()

    def expand_dimensions(connection: connectionTypes.BaseConnection) -> list[str]:
        """Return expanded list of dimensions, with special skypix treatment.

        Parameters
        ----------
        connection : `connectionTypes.BaseConnection`
            Connection whose dimensions (if any) should be expanded.

        Returns
        -------
        dimensions : `list` [`str`]
            Expanded dimension names; the pseudo-dimension ``skypix`` is
            appended verbatim because the universe cannot expand it.
        """
        dimension_set = set()
        if isinstance(connection, connectionTypes.DimensionedConnection):
            dimension_set = set(connection.dimensions)
        skypix_dim = []
        if "skypix" in dimension_set:
            # set it aside before expansion and re-append afterwards
            dimension_set.remove("skypix")
            skypix_dim = ["skypix"]
        dimension_graph = universe.extract(dimension_set)
        return list(dimension_graph.names) + skypix_dim

    # open a file if needed; remember whether we own it so only files we
    # opened ourselves are closed
    close = False
    if not hasattr(file, "write"):
        file = open(file, "w")
        close = True

    # try/finally guarantees a file we opened is closed even if rendering
    # raises (previously it leaked on error)
    try:
        print("digraph Pipeline {", file=file)

        # dataset type names (and composite/component name pairs) already
        # rendered, so each node/edge is emitted once
        allDatasets: set[Union[str, tuple[str, str]]] = set()
        if isinstance(pipeline, Pipeline):
            pipeline = pipeline.toExpandedPipeline()

        # The next two lines are a workaround until DM-29658 at which time
        # metadata connections should start working with the above code
        labelToTaskName = {}
        metadataNodesToLink = set()

        # matches dataset type names like "<label>_metadata"; compiled once
        # here since it is invariant across tasks (was rebuilt per task)
        metadataRePattern = re.compile("^(.*)_metadata$")

        for idx, taskDef in enumerate(sorted(pipeline, key=lambda x: x.label)):
            # node for a task
            taskNodeName = "task{}".format(idx)

            # next line is workaround until DM-29658
            labelToTaskName[taskDef.label] = taskNodeName

            _renderTaskNode(taskNodeName, taskDef, file, None)

            for attr in sorted(iterConnections(taskDef.connections, "inputs"), key=lambda x: x.name):
                if attr.name not in allDatasets:
                    dimensions = expand_dimensions(attr)
                    _renderDSTypeNode(attr.name, dimensions, file)
                    allDatasets.add(attr.name)
                nodeName, component = DatasetType.splitDatasetTypeName(attr.name)
                _renderEdge(attr.name, taskNodeName, file)
                # connect component dataset types to the composite type that
                # produced it
                if component is not None and (nodeName, attr.name) not in allDatasets:
                    _renderEdge(nodeName, attr.name, file)
                    allDatasets.add((nodeName, attr.name))
                    if nodeName not in allDatasets:
                        dimensions = expand_dimensions(attr)
                        _renderDSTypeNode(nodeName, dimensions, file)
                # The next if block is a workaround until DM-29658 at which
                # time metadata connections should start working with the
                # above code
                if (match := metadataRePattern.match(attr.name)) is not None:
                    matchTaskLabel = match.group(1)
                    metadataNodesToLink.add((matchTaskLabel, attr.name))

            for attr in sorted(iterConnections(taskDef.connections, "prerequisiteInputs"), key=lambda x: x.name):
                if attr.name not in allDatasets:
                    dimensions = expand_dimensions(attr)
                    _renderDSTypeNode(attr.name, dimensions, file)
                    allDatasets.add(attr.name)
                # use dashed line for prerequisite edges to distinguish them
                _renderEdge(attr.name, taskNodeName, file, style="dashed")

            for attr in sorted(iterConnections(taskDef.connections, "outputs"), key=lambda x: x.name):
                if attr.name not in allDatasets:
                    dimensions = expand_dimensions(attr)
                    _renderDSTypeNode(attr.name, dimensions, file)
                    allDatasets.add(attr.name)
                _renderEdge(taskNodeName, attr.name, file)

        # This for loop is a workaround until DM-29658 at which time metadata
        # connections should start working with the above code
        for matchLabel, dsTypeName in metadataNodesToLink:
            # only render an edge to metadata if the label is part of the
            # current graph
            if (result := labelToTaskName.get(matchLabel)) is not None:
                _renderEdge(result, dsTypeName, file)

        print("}", file=file)
    finally:
        if close:
            file.close()