Coverage for python/lsst/ctrl/mpexec/dotTools.py: 10%
156 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-05-03 09:59 +0000
1# This file is part of ctrl_mpexec.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
28"""Module defining few methods to generate GraphViz diagrams from pipelines
29or quantum graphs.
30"""
32from __future__ import annotations
34__all__ = ["graph2dot", "pipeline2dot"]
36# -------------------------------
37# Imports of standard modules --
38# -------------------------------
39import html
40import io
41import re
42import warnings
43from collections.abc import Iterable
44from typing import TYPE_CHECKING, Any
46# -----------------------------
47# Imports for other modules --
48# -----------------------------
49from lsst.daf.butler import DatasetType, DimensionUniverse
50from lsst.pipe.base import Pipeline, connectionTypes, iterConnections
52if TYPE_CHECKING:
53 from lsst.daf.butler import DatasetRef
54 from lsst.pipe.base import QuantumGraph, QuantumNode, TaskDef
56# ----------------------------------
57# Local non-exported definitions --
58# ----------------------------------
60# Attributes applied to directed graph objects.
61_NODELABELPOINTSIZE = "18"
62_ATTRIBS = dict(
63 defaultGraph=dict(splines="ortho", nodesep="0.5", ranksep="0.75", pad="0.5"),
64 defaultNode=dict(shape="box", fontname="Monospace", fontsize="14", margin="0.2,0.1", penwidth="3"),
65 defaultEdge=dict(color="black", arrowsize="1.5", penwidth="1.5"),
66 task=dict(style="filled", color="black", fillcolor="#B1F2EF"),
67 quantum=dict(style="filled", color="black", fillcolor="#B1F2EF"),
68 dsType=dict(style="rounded,filled,bold", color="#00BABC", fillcolor="#F5F5F5"),
69 dataset=dict(style="rounded,filled,bold", color="#00BABC", fillcolor="#F5F5F5"),
70)
73def _renderDefault(type: str, attribs: dict[str, str], file: io.TextIOBase) -> None:
74 """Set default attributes for a given type."""
75 default_attribs = ", ".join([f'{key}="{val}"' for key, val in attribs.items()])
76 print(f"{type} [{default_attribs}];", file=file)
def _renderNode(file: io.TextIOBase, nodeName: str, style: str, labels: list[str]) -> None:
    """Write a single GV node statement.

    The label is rendered as an HTML-like table with one row per entry in
    ``labels``; all other attributes come from the ``style`` preset in
    ``_ATTRIBS``.
    """
    table_open = '<<TABLE BORDER="0" CELLPADDING="5"><TR><TD>'
    table_close = "</TD></TR></TABLE>>"
    label = r"</TD></TR><TR><TD>".join(labels)
    attrib_dict = dict(_ATTRIBS[style], label=label)
    parts = []
    for key, val in attrib_dict.items():
        if key == "label":
            # HTML-like labels must not be quoted.
            parts.append(f"{key}={table_open}{val}{table_close}")
        else:
            parts.append(f'{key}="{val}"')
    print(f'"{nodeName}" [{", ".join(parts)}];', file=file)
def _renderTaskNode(nodeName: str, taskDef: TaskDef, file: io.TextIOBase, idx: Any = None) -> None:
    """Render GV node for a task"""
    labels = [
        f'<B><FONT POINT-SIZE="{_NODELABELPOINTSIZE}">{html.escape(taskDef.label)}</FONT></B>',
        html.escape(taskDef.taskName),
    ]
    if idx is not None:
        labels.append(f"<I>index:</I> {idx}")
    connections = taskDef.connections
    if connections:
        # Join the sorted dimension names ourselves rather than printing the
        # collection directly, which would add visually noisy quotes.
        dims = ", ".join(sorted(connections.dimensions))
        labels.append(f"<I>dimensions:</I> {html.escape(dims)}")
    _renderNode(file, nodeName, "task", labels)
def _renderQuantumNode(
    nodeName: str, taskDef: TaskDef, quantumNode: QuantumNode, file: io.TextIOBase
) -> None:
    """Render GV node for a quantum"""
    dataId = quantumNode.quantum.dataId
    assert dataId is not None, "Quantum DataId cannot be None"
    labels = [f"{quantumNode.nodeId}", html.escape(taskDef.label)]
    for key in sorted(dataId.keys()):
        labels.append(f"{key} = {dataId[key]}")
    _renderNode(file, nodeName, "quantum", labels)
def _renderDSTypeNode(name: str, dimensions: list[str], file: io.TextIOBase) -> None:
    """Render GV node for a dataset type"""
    title = f'<B><FONT POINT-SIZE="{_NODELABELPOINTSIZE}">{html.escape(name)}</FONT></B>'
    labels = [title]
    if dimensions:
        dims = html.escape(", ".join(sorted(dimensions)))
        labels.append(f"<I>dimensions:</I> {dims}")
    _renderNode(file, name, "dsType", labels)
def _renderDSNode(nodeName: str, dsRef: DatasetRef, file: io.TextIOBase) -> None:
    """Render GV node for a dataset"""
    dataId = dsRef.dataId
    labels = [html.escape(dsRef.datasetType.name), f"run: {dsRef.run!r}"]
    for key in sorted(dataId.keys()):
        labels.append(f"{key} = {dataId[key]}")
    _renderNode(file, nodeName, "dataset", labels)
135def _renderEdge(fromName: str, toName: str, file: io.TextIOBase, **kwargs: Any) -> None:
136 """Render GV edge"""
137 if kwargs:
138 attrib = ", ".join([f'{key}="{val}"' for key, val in kwargs.items()])
139 print(f'"{fromName}" -> "{toName}" [{attrib}];', file=file)
140 else:
141 print(f'"{fromName}" -> "{toName}";', file=file)
144def _datasetRefId(dsRef: DatasetRef) -> str:
145 """Make an identifying string for given ref"""
146 dsId = [dsRef.datasetType.name]
147 dsId.extend(f"{key} = {dsRef.dataId[key]}" for key in sorted(dsRef.dataId.keys()))
148 return ":".join(dsId)
def _makeDSNode(dsRef: DatasetRef, allDatasetRefs: dict[str, str], file: io.TextIOBase) -> str:
    """Make new node for dataset if it does not exist.

    Returns node name.
    """
    refId = _datasetRefId(dsRef)
    try:
        return allDatasetRefs[refId]
    except KeyError:
        pass
    # Node names are assigned sequentially in discovery order.
    nodeName = f"dsref_{len(allDatasetRefs)}"
    allDatasetRefs[refId] = nodeName
    _renderDSNode(nodeName, dsRef, file)
    return nodeName
166# ------------------------
167# Exported definitions --
168# ------------------------
def graph2dot(qgraph: QuantumGraph, file: Any) -> None:
    """Convert QuantumGraph into GraphViz digraph.

    This method is mostly for documentation/presentation purposes.

    Parameters
    ----------
    qgraph : `lsst.pipe.base.QuantumGraph`
        QuantumGraph instance.
    file : `str` or file object
        File where GraphViz graph (DOT language) is written, can be a file name
        or file object.

    Raises
    ------
    `OSError` is raised when output file cannot be open.
    `ImportError` is raised when task class cannot be imported.
    """
    # open a file if needed
    close = False
    if not hasattr(file, "write"):
        file = open(file, "w")
        close = True

    try:
        print("digraph QuantumGraph {", file=file)
        _renderDefault("graph", _ATTRIBS["defaultGraph"], file)
        _renderDefault("node", _ATTRIBS["defaultNode"], file)
        _renderDefault("edge", _ATTRIBS["defaultEdge"], file)

        # Maps dataset-ref identifier to its GV node name so each dataset is
        # rendered exactly once even when shared by several quanta.
        allDatasetRefs: dict[str, str] = {}
        for taskId, taskDef in enumerate(qgraph.taskGraph):
            quanta = qgraph.getNodesForTask(taskDef)
            for qId, quantumNode in enumerate(quanta):
                # node for a task
                taskNodeName = f"task_{taskId}_{qId}"
                _renderQuantumNode(taskNodeName, taskDef, quantumNode, file)

                # quantum inputs
                for dsRefs in quantumNode.quantum.inputs.values():
                    for dsRef in dsRefs:
                        nodeName = _makeDSNode(dsRef, allDatasetRefs, file)
                        _renderEdge(nodeName, taskNodeName, file)

                # quantum outputs
                for dsRefs in quantumNode.quantum.outputs.values():
                    for dsRef in dsRefs:
                        nodeName = _makeDSNode(dsRef, allDatasetRefs, file)
                        _renderEdge(taskNodeName, nodeName, file)

        print("}", file=file)
    finally:
        # Close the stream only if this function opened it; a caller-supplied
        # file object remains the caller's responsibility. The ``finally``
        # guarantees the file is not leaked if rendering raises.
        if close:
            file.close()
def pipeline2dot(pipeline: Pipeline | Iterable[TaskDef], file: Any) -> None:
    """Convert `~lsst.pipe.base.Pipeline` into GraphViz digraph.

    This method is mostly for documentation/presentation purposes.
    Unlike other methods this method does not validate graph consistency.

    Parameters
    ----------
    pipeline : `lsst.pipe.base.Pipeline`
        Pipeline description.
    file : `str` or file object
        File where GraphViz graph (DOT language) is written, can be a file name
        or file object.

    Raises
    ------
    `OSError` is raised when output file cannot be open.
    `ImportError` is raised when task class cannot be imported.
    `MissingTaskFactoryError` is raised when TaskFactory is needed but not
    provided.
    """
    universe = DimensionUniverse()

    def expand_dimensions(connection: connectionTypes.BaseConnection) -> list[str]:
        """Return expanded list of dimensions, with special skypix treatment.

        Parameters
        ----------
        connection : `lsst.pipe.base.connectionTypes.BaseConnection`
            Connection to examine.

        Returns
        -------
        dimensions : `list` [`str`]
            Expanded list of dimensions.
        """
        dimension_set = set()
        if isinstance(connection, connectionTypes.DimensionedConnection):
            dimension_set = set(connection.dimensions)
        skypix_dim = []
        # "skypix" is a placeholder, not a real dimension; keep it aside so
        # universe.conform does not reject it, then append it back.
        if "skypix" in dimension_set:
            dimension_set.remove("skypix")
            skypix_dim = ["skypix"]
        dimensions = universe.conform(dimension_set)
        return list(dimensions.names) + skypix_dim

    # open a file if needed
    close = False
    if not hasattr(file, "write"):
        file = open(file, "w")
        close = True

    try:
        print("digraph Pipeline {", file=file)
        _renderDefault("graph", _ATTRIBS["defaultGraph"], file)
        _renderDefault("node", _ATTRIBS["defaultNode"], file)
        _renderDefault("edge", _ATTRIBS["defaultEdge"], file)

        # Tracks which dataset-type nodes (str) and composite->component
        # edges (tuple) have been emitted already.
        allDatasets: set[str | tuple[str, str]] = set()
        if isinstance(pipeline, Pipeline):
            # TODO: DM-40639 will rewrite this code and finish off the
            # deprecation of toExpandedPipeline.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", category=FutureWarning)
                pipeline = pipeline.toExpandedPipeline()

        # The next two lines are a workaround until DM-29658 at which time
        # metadata connections should start working with the above code
        labelToTaskName = {}
        metadataNodesToLink = set()

        # Loop-invariant: compile the metadata-connection pattern once
        # instead of once per task.
        metadataRePattern = re.compile("^(.*)_metadata$")

        for idx, taskDef in enumerate(sorted(pipeline, key=lambda x: x.label)):
            # node for a task
            taskNodeName = f"task{idx}"

            # next line is workaround until DM-29658
            labelToTaskName[taskDef.label] = taskNodeName

            _renderTaskNode(taskNodeName, taskDef, file, None)

            for attr in sorted(iterConnections(taskDef.connections, "inputs"), key=lambda x: x.name):
                if attr.name not in allDatasets:
                    dimensions = expand_dimensions(attr)
                    _renderDSTypeNode(attr.name, dimensions, file)
                    allDatasets.add(attr.name)
                nodeName, component = DatasetType.splitDatasetTypeName(attr.name)
                _renderEdge(attr.name, taskNodeName, file)
                # connect component dataset types to the composite type that
                # produced it
                if component is not None and (nodeName, attr.name) not in allDatasets:
                    _renderEdge(nodeName, attr.name, file)
                    allDatasets.add((nodeName, attr.name))
                    if nodeName not in allDatasets:
                        dimensions = expand_dimensions(attr)
                        _renderDSTypeNode(nodeName, dimensions, file)
                        # Record the composite node so a second component of
                        # the same dataset type does not re-render it.
                        allDatasets.add(nodeName)
                # The next if block is a workaround until DM-29658 at which
                # time metadata connections should start working with the
                # above code
                if (match := metadataRePattern.match(attr.name)) is not None:
                    matchTaskLabel = match.group(1)
                    metadataNodesToLink.add((matchTaskLabel, attr.name))

            for attr in sorted(
                iterConnections(taskDef.connections, "prerequisiteInputs"), key=lambda x: x.name
            ):
                if attr.name not in allDatasets:
                    dimensions = expand_dimensions(attr)
                    _renderDSTypeNode(attr.name, dimensions, file)
                    allDatasets.add(attr.name)
                # use dashed line for prerequisite edges to distinguish them
                _renderEdge(attr.name, taskNodeName, file, style="dashed")

            for attr in sorted(iterConnections(taskDef.connections, "outputs"), key=lambda x: x.name):
                if attr.name not in allDatasets:
                    dimensions = expand_dimensions(attr)
                    _renderDSTypeNode(attr.name, dimensions, file)
                    allDatasets.add(attr.name)
                _renderEdge(taskNodeName, attr.name, file)

        # This for loop is a workaround until DM-29658 at which time metadata
        # connections should start working with the above code
        for matchLabel, dsTypeName in metadataNodesToLink:
            # only render an edge to metadata if the label is part of the
            # current graph
            if (result := labelToTaskName.get(matchLabel)) is not None:
                _renderEdge(result, dsTypeName, file)

        print("}", file=file)
    finally:
        # Close the stream only if this function opened it; the ``finally``
        # guarantees the file is not leaked if rendering raises.
        if close:
            file.close()