Coverage for python/lsst/ctrl/mpexec/dotTools.py: 10%

156 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-08 02:55 -0700

1# This file is part of ctrl_mpexec. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27 

"""Module defining a few methods to generate GraphViz diagrams from pipelines
or quantum graphs.
"""

31 

32from __future__ import annotations 

33 

34__all__ = ["graph2dot", "pipeline2dot"] 

35 

36# ------------------------------- 

37# Imports of standard modules -- 

38# ------------------------------- 

39import html 

40import io 

41import re 

42import warnings 

43from collections.abc import Iterable 

44from typing import TYPE_CHECKING, Any 

45 

46# ----------------------------- 

47# Imports for other modules -- 

48# ----------------------------- 

49from lsst.daf.butler import DatasetType, DimensionUniverse 

50from lsst.pipe.base import Pipeline, connectionTypes, iterConnections 

51 

52if TYPE_CHECKING: 

53 from lsst.daf.butler import DatasetRef 

54 from lsst.pipe.base import QuantumGraph, QuantumNode, TaskDef 

55 

56# ---------------------------------- 

57# Local non-exported definitions -- 

58# ---------------------------------- 

59 

# Attributes applied to directed graph objects.
# Point size used for the bold title row of node labels.
_NODELABELPOINTSIZE = "18"
# Default GraphViz attribute sets, keyed by the node/edge category names
# used throughout this module.
_ATTRIBS = {
    "defaultGraph": {"splines": "ortho", "nodesep": "0.5", "ranksep": "0.75", "pad": "0.5"},
    "defaultNode": {
        "shape": "box",
        "fontname": "Monospace",
        "fontsize": "14",
        "margin": "0.2,0.1",
        "penwidth": "3",
    },
    "defaultEdge": {"color": "black", "arrowsize": "1.5", "penwidth": "1.5"},
    "task": {"style": "filled", "color": "black", "fillcolor": "#B1F2EF"},
    "quantum": {"style": "filled", "color": "black", "fillcolor": "#B1F2EF"},
    "dsType": {"style": "rounded,filled,bold", "color": "#00BABC", "fillcolor": "#F5F5F5"},
    "dataset": {"style": "rounded,filled,bold", "color": "#00BABC", "fillcolor": "#F5F5F5"},
}

71 

72 

def _renderDefault(type: str, attribs: dict[str, str], file: io.TextIOBase) -> None:
    """Write a GraphViz default-attribute statement for the given type.

    Emits a line of the form ``<type> [key="val", ...];`` to ``file``.
    """
    formatted = ", ".join(f'{key}="{value}"' for key, value in attribs.items())
    print(f"{type} [{formatted}];", file=file)

77 

78 

def _renderNode(file: io.TextIOBase, nodeName: str, style: str, labels: list[str]) -> None:
    """Render a GV node whose label is an HTML-like single-column table.

    Each entry of ``labels`` becomes one table row; the node's other
    attributes come from the ``_ATTRIBS`` entry selected by ``style``.
    """
    table_open = '<<TABLE BORDER="0" CELLPADDING="5"><TR><TD>'
    table_close = "</TD></TR></TABLE>>"
    row_break = r"</TD></TR><TR><TD>"
    attribs = dict(_ATTRIBS[style])
    attribs["label"] = row_break.join(labels)
    parts = []
    for key, value in attribs.items():
        if key == "label":
            # The label is HTML-like markup, so it is wrapped in <...> rather
            # than double quotes.
            parts.append(f"{key}={table_open}{value}{table_close}")
        else:
            parts.append(f'{key}="{value}"')
    print(f'"{nodeName}" [{", ".join(parts)}];', file=file)

92 

93 

def _renderTaskNode(nodeName: str, taskDef: TaskDef, file: io.TextIOBase, idx: Any = None) -> None:
    """Render a GV node for a task.

    The label shows the task label in bold, the task class name, an optional
    index, and the task's dimensions when connections are available.
    """
    labels = [
        f'<B><FONT POINT-SIZE="{_NODELABELPOINTSIZE}">{html.escape(taskDef.label)}</FONT></B>',
        html.escape(taskDef.taskName),
    ]
    if idx is not None:
        labels.append(f"<I>index:</I>&nbsp;{idx}")
    if taskDef.connections:
        # Join into a single string first; rendering the collection of str
        # directly would add visually noisy quotes.
        dims = ", ".join(sorted(taskDef.connections.dimensions))
        labels.append(f"<I>dimensions:</I>&nbsp;{html.escape(dims)}")
    _renderNode(file, nodeName, "task", labels)

107 

108 

def _renderQuantumNode(
    nodeName: str, taskDef: TaskDef, quantumNode: QuantumNode, file: io.TextIOBase
) -> None:
    """Render a GV node for a quantum.

    The label shows the node ID, the task label, and the quantum's data ID
    entries in sorted key order.
    """
    dataId = quantumNode.quantum.dataId
    assert dataId is not None, "Quantum DataId cannot be None"
    labels = [f"{quantumNode.nodeId}", html.escape(taskDef.label)]
    for key in sorted(dataId.keys()):
        labels.append(f"{key} = {dataId[key]}")
    _renderNode(file, nodeName, "quantum", labels)

118 

119 

def _renderDSTypeNode(name: str, dimensions: list[str], file: io.TextIOBase) -> None:
    """Render a GV node for a dataset type.

    The dataset type name doubles as the node name; dimensions, if any, are
    listed in sorted order on a second label row.
    """
    title = f'<B><FONT POINT-SIZE="{_NODELABELPOINTSIZE}">{html.escape(name)}</FONT></B>'
    labels = [title]
    if dimensions:
        labels.append("<I>dimensions:</I>&nbsp;" + html.escape(", ".join(sorted(dimensions))))
    _renderNode(file, name, "dsType", labels)

126 

127 

def _renderDSNode(nodeName: str, dsRef: DatasetRef, file: io.TextIOBase) -> None:
    """Render a GV node for a dataset.

    The label shows the dataset type name, the run, and the data ID entries
    in sorted key order.
    """
    labels = [html.escape(dsRef.datasetType.name), f"run: {dsRef.run!r}"]
    for key in sorted(dsRef.dataId.keys()):
        labels.append(f"{key} = {dsRef.dataId[key]}")
    _renderNode(file, nodeName, "dataset", labels)

133 

134 

def _renderEdge(fromName: str, toName: str, file: io.TextIOBase, **kwargs: Any) -> None:
    """Render a GV edge; extra keyword arguments become edge attributes."""
    if not kwargs:
        print(f'"{fromName}" -> "{toName}";', file=file)
        return
    attrib = ", ".join(f'{key}="{val}"' for key, val in kwargs.items())
    print(f'"{fromName}" -> "{toName}" [{attrib}];', file=file)

142 

143 

144def _datasetRefId(dsRef: DatasetRef) -> str: 

145 """Make an identifying string for given ref""" 

146 dsId = [dsRef.datasetType.name] 

147 dsId.extend(f"{key} = {dsRef.dataId[key]}" for key in sorted(dsRef.dataId.keys())) 

148 return ":".join(dsId) 

149 

150 

def _makeDSNode(dsRef: DatasetRef, allDatasetRefs: dict[str, str], file: io.TextIOBase) -> str:
    """Make a new node for a dataset if one does not exist yet.

    ``allDatasetRefs`` maps dataset-ref identifiers to already-rendered node
    names; a node is rendered only on first sight of a ref.

    Returns node name.
    """
    refId = _datasetRefId(dsRef)
    try:
        return allDatasetRefs[refId]
    except KeyError:
        pass
    # First time we see this ref: invent a sequential node name and render it.
    nodeName = f"dsref_{len(allDatasetRefs)}"
    allDatasetRefs[refId] = nodeName
    _renderDSNode(nodeName, dsRef, file)
    return nodeName

164 

165 

166# ------------------------ 

167# Exported definitions -- 

168# ------------------------ 

169 

170 

def graph2dot(qgraph: QuantumGraph, file: Any) -> None:
    """Convert QuantumGraph into GraphViz digraph.

    This method is mostly for documentation/presentation purposes.

    Parameters
    ----------
    qgraph : `lsst.pipe.base.QuantumGraph`
        QuantumGraph instance.
    file : `str` or file object
        File where GraphViz graph (DOT language) is written, can be a file name
        or file object.

    Raises
    ------
    `OSError` is raised when output file cannot be open.
    `ImportError` is raised when task class cannot be imported.
    """
    # open a file if needed
    close = False
    if not hasattr(file, "write"):
        file = open(file, "w")
        close = True

    # Use try/finally so a file we opened ourselves is closed even if
    # rendering raises part-way through (previously it leaked on error).
    try:
        print("digraph QuantumGraph {", file=file)
        _renderDefault("graph", _ATTRIBS["defaultGraph"], file)
        _renderDefault("node", _ATTRIBS["defaultNode"], file)
        _renderDefault("edge", _ATTRIBS["defaultEdge"], file)

        # Maps dataset-ref identifier -> node name, shared across all quanta
        # so each dataset is rendered exactly once.
        allDatasetRefs: dict[str, str] = {}
        for taskId, taskDef in enumerate(qgraph.taskGraph):
            quanta = qgraph.getNodesForTask(taskDef)
            for qId, quantumNode in enumerate(quanta):
                # node for a quantum of this task
                taskNodeName = f"task_{taskId}_{qId}"
                _renderQuantumNode(taskNodeName, taskDef, quantumNode, file)

                # quantum inputs: dataset -> quantum edges
                for dsRefs in quantumNode.quantum.inputs.values():
                    for dsRef in dsRefs:
                        nodeName = _makeDSNode(dsRef, allDatasetRefs, file)
                        _renderEdge(nodeName, taskNodeName, file)

                # quantum outputs: quantum -> dataset edges
                for dsRefs in quantumNode.quantum.outputs.values():
                    for dsRef in dsRefs:
                        nodeName = _makeDSNode(dsRef, allDatasetRefs, file)
                        _renderEdge(taskNodeName, nodeName, file)

        print("}", file=file)
    finally:
        if close:
            file.close()

223 

224 

def pipeline2dot(pipeline: Pipeline | Iterable[TaskDef], file: Any) -> None:
    """Convert `~lsst.pipe.base.Pipeline` into GraphViz digraph.

    This method is mostly for documentation/presentation purposes.
    Unlike other methods this method does not validate graph consistency.

    Parameters
    ----------
    pipeline : `lsst.pipe.base.Pipeline`
        Pipeline description.
    file : `str` or file object
        File where GraphViz graph (DOT language) is written, can be a file name
        or file object.

    Raises
    ------
    `OSError` is raised when output file cannot be open.
    `ImportError` is raised when task class cannot be imported.
    `MissingTaskFactoryError` is raised when TaskFactory is needed but not
    provided.
    """
    universe = DimensionUniverse()

    def expand_dimensions(connection: connectionTypes.BaseConnection) -> list[str]:
        """Return expanded list of dimensions, with special skypix treatment.

        Parameters
        ----------
        connection : `lsst.pipe.base.connectionTypes.BaseConnection`
            Connection to examine.

        Returns
        -------
        dimensions : `list` [`str`]
            Expanded list of dimensions.
        """
        dimension_set = set()
        if isinstance(connection, connectionTypes.DimensionedConnection):
            dimension_set = set(connection.dimensions)
        # "skypix" is a placeholder, not a real dimension; keep it aside so
        # universe.conform does not reject it, then append it back.
        skypix_dim = []
        if "skypix" in dimension_set:
            dimension_set.remove("skypix")
            skypix_dim = ["skypix"]
        dimensions = universe.conform(dimension_set)
        return list(dimensions.names) + skypix_dim

    # open a file if needed
    close = False
    if not hasattr(file, "write"):
        file = open(file, "w")
        close = True

    # Use try/finally so a file we opened ourselves is closed even if
    # rendering raises part-way through (previously it leaked on error).
    try:
        print("digraph Pipeline {", file=file)
        _renderDefault("graph", _ATTRIBS["defaultGraph"], file)
        _renderDefault("node", _ATTRIBS["defaultNode"], file)
        _renderDefault("edge", _ATTRIBS["defaultEdge"], file)

        allDatasets: set[str | tuple[str, str]] = set()
        if isinstance(pipeline, Pipeline):
            # TODO: DM-40639 will rewrite this code and finish off the
            # deprecation of toExpandedPipeline.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", category=FutureWarning)
                pipeline = pipeline.toExpandedPipeline()

        # The next two lines are a workaround until DM-29658 at which time
        # metadata connections should start working with the above code
        labelToTaskName = {}
        metadataNodesToLink = set()

        # Loop-invariant: compile the metadata-dataset pattern once instead of
        # once per task (it was previously rebuilt inside the loop).
        metadataRePattern = re.compile("^(.*)_metadata$")

        for idx, taskDef in enumerate(sorted(pipeline, key=lambda x: x.label)):
            # node for a task
            taskNodeName = f"task{idx}"

            # next line is workaround until DM-29658
            labelToTaskName[taskDef.label] = taskNodeName

            _renderTaskNode(taskNodeName, taskDef, file, None)

            for attr in sorted(iterConnections(taskDef.connections, "inputs"), key=lambda x: x.name):
                if attr.name not in allDatasets:
                    dimensions = expand_dimensions(attr)
                    _renderDSTypeNode(attr.name, dimensions, file)
                    allDatasets.add(attr.name)
                nodeName, component = DatasetType.splitDatasetTypeName(attr.name)
                _renderEdge(attr.name, taskNodeName, file)
                # connect component dataset types to the composite type that
                # produced it
                if component is not None and (nodeName, attr.name) not in allDatasets:
                    _renderEdge(nodeName, attr.name, file)
                    allDatasets.add((nodeName, attr.name))
                    if nodeName not in allDatasets:
                        # NOTE(review): the composite is deliberately not added
                        # to allDatasets here so a later producer can still
                        # render it with its own dimensions.
                        dimensions = expand_dimensions(attr)
                        _renderDSTypeNode(nodeName, dimensions, file)
                # The next if block is a workaround until DM-29658 at which
                # time metadata connections should start working with the
                # above code
                if (match := metadataRePattern.match(attr.name)) is not None:
                    matchTaskLabel = match.group(1)
                    metadataNodesToLink.add((matchTaskLabel, attr.name))

            for attr in sorted(
                iterConnections(taskDef.connections, "prerequisiteInputs"), key=lambda x: x.name
            ):
                if attr.name not in allDatasets:
                    dimensions = expand_dimensions(attr)
                    _renderDSTypeNode(attr.name, dimensions, file)
                    allDatasets.add(attr.name)
                # use dashed line for prerequisite edges to distinguish them
                _renderEdge(attr.name, taskNodeName, file, style="dashed")

            for attr in sorted(iterConnections(taskDef.connections, "outputs"), key=lambda x: x.name):
                if attr.name not in allDatasets:
                    dimensions = expand_dimensions(attr)
                    _renderDSTypeNode(attr.name, dimensions, file)
                    allDatasets.add(attr.name)
                _renderEdge(taskNodeName, attr.name, file)

        # This for loop is a workaround until DM-29658 at which time metadata
        # connections should start working with the above code
        for matchLabel, dsTypeName in metadataNodesToLink:
            # only render an edge to metadata if the label is part of the
            # current graph
            if (result := labelToTaskName.get(matchLabel)) is not None:
                _renderEdge(result, dsTypeName, file)

        print("}", file=file)
    finally:
        if close:
            file.close()