Coverage for python/lsst/ctrl/mpexec/dotTools.py: 10%

153 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-08-06 02:30 +0000

1# This file is part of ctrl_mpexec. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22"""Module defining few methods to generate GraphViz diagrams from pipelines 

23or quantum graphs. 

24""" 

25 

26from __future__ import annotations 

27 

28__all__ = ["graph2dot", "pipeline2dot"] 

29 

30# ------------------------------- 

31# Imports of standard modules -- 

32# ------------------------------- 

33import html 

34import io 

35import re 

36from collections.abc import Iterable 

37from typing import TYPE_CHECKING, Any 

38 

39# ----------------------------- 

40# Imports for other modules -- 

41# ----------------------------- 

42from lsst.daf.butler import DatasetType, DimensionUniverse 

43from lsst.pipe.base import Pipeline, connectionTypes, iterConnections 

44 

45if TYPE_CHECKING: 

46 from lsst.daf.butler import DatasetRef 

47 from lsst.pipe.base import QuantumGraph, QuantumNode, TaskDef 

48 

49# ---------------------------------- 

50# Local non-exported definitions -- 

51# ---------------------------------- 

52 

53# Attributes applied to directed graph objects. 

54_NODELABELPOINTSIZE = "18" 

55_ATTRIBS = dict( 

56 defaultGraph=dict(splines="ortho", nodesep="0.5", ranksep="0.75", pad="0.5"), 

57 defaultNode=dict(shape="box", fontname="Monospace", fontsize="14", margin="0.2,0.1", penwidth="3"), 

58 defaultEdge=dict(color="black", arrowsize="1.5", penwidth="1.5"), 

59 task=dict(style="filled", color="black", fillcolor="#B1F2EF"), 

60 quantum=dict(style="filled", color="black", fillcolor="#B1F2EF"), 

61 dsType=dict(style="rounded,filled,bold", color="#00BABC", fillcolor="#F5F5F5"), 

62 dataset=dict(style="rounded,filled,bold", color="#00BABC", fillcolor="#F5F5F5"), 

63) 

64 

65 

66def _renderDefault(type: str, attribs: dict[str, str], file: io.TextIOBase) -> None: 

67 """Set default attributes for a given type.""" 

68 default_attribs = ", ".join([f'{key}="{val}"' for key, val in attribs.items()]) 

69 print(f"{type} [{default_attribs}];", file=file) 

70 

71 

def _renderNode(file: io.TextIOBase, nodeName: str, style: str, labels: list[str]) -> None:
    """Print a GraphViz node statement with an HTML-like table label.

    Each entry of ``labels`` becomes one table row; the remaining node
    attributes come from the ``_ATTRIBS`` entry selected by ``style``.
    """
    table_open = '<<TABLE BORDER="0" CELLPADDING="5"><TR><TD>'
    table_close = "</TD></TR></TABLE>>"
    row_break = r"</TD></TR><TR><TD>"
    attribs = dict(_ATTRIBS[style], label=row_break.join(labels))
    parts = []
    for key, val in attribs.items():
        if key == "label":
            # HTML-like labels are delimited by <...>, not quotes.
            parts.append(f"{key}={table_open}{val}{table_close}")
        else:
            parts.append(f'{key}="{val}"')
    print(f'"{nodeName}" [{", ".join(parts)}];', file=file)

85 

86 

def _renderTaskNode(nodeName: str, taskDef: TaskDef, file: io.TextIOBase, idx: Any = None) -> None:
    """Render a GV node describing a pipeline task.

    The node shows the task label (emphasized), the task class name, an
    optional index, and the task's dimensions when connections are defined.
    """
    labels = [
        f'<B><FONT POINT-SIZE="{_NODELABELPOINTSIZE}">{html.escape(taskDef.label)}</FONT></B>',
        html.escape(taskDef.taskName),
    ]
    if idx is not None:
        labels.append(f"<I>index:</I>&nbsp;{idx}")
    if taskDef.connections:
        # Join the dimension names ourselves so the label does not carry the
        # visually noisy quotes of a printed str collection.
        dims = ", ".join(sorted(taskDef.connections.dimensions))
        labels.append(f"<I>dimensions:</I>&nbsp;{html.escape(dims)}")
    _renderNode(file, nodeName, "task", labels)

100 

101 

def _renderQuantumNode(
    nodeName: str, taskDef: TaskDef, quantumNode: QuantumNode, file: io.TextIOBase
) -> None:
    """Render a GV node for one quantum of a task, labeled with its node ID,
    task label, and the key/value pairs of its data ID.
    """
    dataId = quantumNode.quantum.dataId
    assert dataId is not None, "Quantum DataId cannot be None"
    labels = [f"{quantumNode.nodeId}", html.escape(taskDef.label)]
    labels += [f"{key} = {dataId[key]}" for key in sorted(dataId.keys())]
    _renderNode(file, nodeName, "quantum", labels)

111 

112 

def _renderDSTypeNode(name: str, dimensions: list[str], file: io.TextIOBase) -> None:
    """Render a GV node for a dataset type, labeled with its name and, when
    non-empty, its sorted dimension list.
    """
    header = f'<B><FONT POINT-SIZE="{_NODELABELPOINTSIZE}">{html.escape(name)}</FONT></B>'
    labels = [header]
    if dimensions:
        dims = html.escape(", ".join(sorted(dimensions)))
        labels.append(f"<I>dimensions:</I>&nbsp;{dims}")
    _renderNode(file, name, "dsType", labels)

119 

120 

def _renderDSNode(nodeName: str, dsRef: DatasetRef, file: io.TextIOBase) -> None:
    """Render a GV node for a resolved dataset, labeled with the dataset type
    name, the run, and the data ID key/value pairs.
    """
    labels = [html.escape(dsRef.datasetType.name), f"run: {dsRef.run!r}"]
    for key in sorted(dsRef.dataId.keys()):
        labels.append(f"{key} = {dsRef.dataId[key]}")
    _renderNode(file, nodeName, "dataset", labels)

126 

127 

128def _renderEdge(fromName: str, toName: str, file: io.TextIOBase, **kwargs: Any) -> None: 

129 """Render GV edge""" 

130 if kwargs: 

131 attrib = ", ".join([f'{key}="{val}"' for key, val in kwargs.items()]) 

132 print(f'"{fromName}" -> "{toName}" [{attrib}];', file=file) 

133 else: 

134 print(f'"{fromName}" -> "{toName}";', file=file) 

135 

136 

137def _datasetRefId(dsRef: DatasetRef) -> str: 

138 """Make an identifying string for given ref""" 

139 dsId = [dsRef.datasetType.name] 

140 dsId.extend(f"{key} = {dsRef.dataId[key]}" for key in sorted(dsRef.dataId.keys())) 

141 return ":".join(dsId) 

142 

143 

def _makeDSNode(dsRef: DatasetRef, allDatasetRefs: dict[str, str], file: io.TextIOBase) -> str:
    """Return the GV node name for a dataset, rendering the node first if it
    has not been seen before.

    ``allDatasetRefs`` maps dataset-ref identifier strings to previously
    allocated node names and is updated in place.
    """
    refId = _datasetRefId(dsRef)
    if refId in allDatasetRefs:
        return allDatasetRefs[refId]
    # New dataset: allocate a sequential node name and render it once.
    nodeName = f"dsref_{len(allDatasetRefs)}"
    allDatasetRefs[refId] = nodeName
    _renderDSNode(nodeName, dsRef, file)
    return nodeName

157 

158 

159# ------------------------ 

160# Exported definitions -- 

161# ------------------------ 

162 

163 

def graph2dot(qgraph: QuantumGraph, file: Any) -> None:
    """Convert QuantumGraph into GraphViz digraph.

    This method is mostly for documentation/presentation purposes.

    Parameters
    ----------
    qgraph : `lsst.pipe.base.QuantumGraph`
        QuantumGraph instance.
    file : `str` or file object
        File where GraphViz graph (DOT language) is written, can be a file
        name or file object.

    Raises
    ------
    `OSError` is raised when output file cannot be open.
    `ImportError` is raised when task class cannot be imported.
    """
    # Open a file if we were given a name; remember that we own it so it is
    # closed even if rendering raises (previously it leaked on error).
    close = False
    if not hasattr(file, "write"):
        file = open(file, "w")
        close = True

    try:
        print("digraph QuantumGraph {", file=file)
        _renderDefault("graph", _ATTRIBS["defaultGraph"], file)
        _renderDefault("node", _ATTRIBS["defaultNode"], file)
        _renderDefault("edge", _ATTRIBS["defaultEdge"], file)

        # Maps dataset-ref identifier string to allocated node name so that a
        # dataset shared between quanta is rendered exactly once.
        allDatasetRefs: dict[str, str] = {}
        for taskId, taskDef in enumerate(qgraph.taskGraph):
            quanta = qgraph.getNodesForTask(taskDef)
            for qId, quantumNode in enumerate(quanta):
                # Node for one quantum of this task.
                taskNodeName = f"task_{taskId}_{qId}"
                _renderQuantumNode(taskNodeName, taskDef, quantumNode, file)

                # Quantum inputs: dataset -> quantum edges.
                for dsRefs in quantumNode.quantum.inputs.values():
                    for dsRef in dsRefs:
                        nodeName = _makeDSNode(dsRef, allDatasetRefs, file)
                        _renderEdge(nodeName, taskNodeName, file)

                # Quantum outputs: quantum -> dataset edges.
                for dsRefs in quantumNode.quantum.outputs.values():
                    for dsRef in dsRefs:
                        nodeName = _makeDSNode(dsRef, allDatasetRefs, file)
                        _renderEdge(taskNodeName, nodeName, file)

        print("}", file=file)
    finally:
        if close:
            file.close()

216 

217 

def pipeline2dot(pipeline: Pipeline | Iterable[TaskDef], file: Any) -> None:
    """Convert `~lsst.pipe.base.Pipeline` into GraphViz digraph.

    This method is mostly for documentation/presentation purposes.
    Unlike other methods this method does not validate graph consistency.

    Parameters
    ----------
    pipeline : `lsst.pipe.base.Pipeline`
        Pipeline description.
    file : `str` or file object
        File where GraphViz graph (DOT language) is written, can be a file
        name or file object.

    Raises
    ------
    `OSError` is raised when output file cannot be open.
    `ImportError` is raised when task class cannot be imported.
    `MissingTaskFactoryError` is raised when TaskFactory is needed but not
    provided.
    """
    universe = DimensionUniverse()

    def expand_dimensions(connection: connectionTypes.BaseConnection) -> list[str]:
        """Return expanded list of dimensions, with special skypix treatment.

        Parameters
        ----------
        connection : `connectionTypes.BaseConnection`
            Connection whose dimensions should be expanded.

        Returns
        -------
        dimensions : `list` [`str`]
            Expanded dimension names; "skypix" is re-appended verbatim since
            it is a placeholder, not a dimension known to the universe.
        """
        dimension_set = set()
        if isinstance(connection, connectionTypes.DimensionedConnection):
            dimension_set = set(connection.dimensions)
        skypix_dim = []
        if "skypix" in dimension_set:
            dimension_set.remove("skypix")
            skypix_dim = ["skypix"]
        dimension_graph = universe.extract(dimension_set)
        return list(dimension_graph.names) + skypix_dim

    # Open a file if we were given a name; remember that we own it so it is
    # closed even if rendering raises (previously it leaked on error).
    close = False
    if not hasattr(file, "write"):
        file = open(file, "w")
        close = True

    try:
        print("digraph Pipeline {", file=file)
        _renderDefault("graph", _ATTRIBS["defaultGraph"], file)
        _renderDefault("node", _ATTRIBS["defaultNode"], file)
        _renderDefault("edge", _ATTRIBS["defaultEdge"], file)

        # Dataset type names (and (composite, component) pairs) already
        # rendered or linked, to avoid duplicate declarations.
        allDatasets: set[str | tuple[str, str]] = set()
        if isinstance(pipeline, Pipeline):
            pipeline = pipeline.toExpandedPipeline()

        # The next two lines are a workaround until DM-29658 at which time
        # metadata connections should start working with the above code.
        labelToTaskName = {}
        metadataNodesToLink = set()

        # Matches "<task label>_metadata" dataset type names.  Compiled once,
        # outside the task loop (it was previously recompiled per task).
        metadataRePattern = re.compile("^(.*)_metadata$")

        for idx, taskDef in enumerate(sorted(pipeline, key=lambda x: x.label)):
            # node for a task
            taskNodeName = f"task{idx}"

            # next line is workaround until DM-29658
            labelToTaskName[taskDef.label] = taskNodeName

            _renderTaskNode(taskNodeName, taskDef, file, None)

            for attr in sorted(iterConnections(taskDef.connections, "inputs"), key=lambda x: x.name):
                if attr.name not in allDatasets:
                    dimensions = expand_dimensions(attr)
                    _renderDSTypeNode(attr.name, dimensions, file)
                    allDatasets.add(attr.name)
                nodeName, component = DatasetType.splitDatasetTypeName(attr.name)
                _renderEdge(attr.name, taskNodeName, file)
                # connect component dataset types to the composite type that
                # produced it
                if component is not None and (nodeName, attr.name) not in allDatasets:
                    _renderEdge(nodeName, attr.name, file)
                    allDatasets.add((nodeName, attr.name))
                    if nodeName not in allDatasets:
                        # NOTE(review): the composite node is rendered here
                        # using the component connection's dimensions and is
                        # not added to allDatasets, so another component may
                        # re-declare it later — confirm this is intentional.
                        dimensions = expand_dimensions(attr)
                        _renderDSTypeNode(nodeName, dimensions, file)
                # The next if block is a workaround until DM-29658 at which
                # time metadata connections should start working with the
                # above code.
                if (match := metadataRePattern.match(attr.name)) is not None:
                    matchTaskLabel = match.group(1)
                    metadataNodesToLink.add((matchTaskLabel, attr.name))

            for attr in sorted(
                iterConnections(taskDef.connections, "prerequisiteInputs"), key=lambda x: x.name
            ):
                if attr.name not in allDatasets:
                    dimensions = expand_dimensions(attr)
                    _renderDSTypeNode(attr.name, dimensions, file)
                    allDatasets.add(attr.name)
                # use dashed line for prerequisite edges to distinguish them
                _renderEdge(attr.name, taskNodeName, file, style="dashed")

            for attr in sorted(iterConnections(taskDef.connections, "outputs"), key=lambda x: x.name):
                if attr.name not in allDatasets:
                    dimensions = expand_dimensions(attr)
                    _renderDSTypeNode(attr.name, dimensions, file)
                    allDatasets.add(attr.name)
                _renderEdge(taskNodeName, attr.name, file)

        # This for loop is a workaround until DM-29658 at which time metadata
        # connections should start working with the above code.
        for matchLabel, dsTypeName in metadataNodesToLink:
            # only render an edge to metadata if the label is part of the
            # current graph
            if (result := labelToTaskName.get(matchLabel)) is not None:
                _renderEdge(result, dsTypeName, file)

        print("}", file=file)
    finally:
        if close:
            file.close()