Coverage for python/lsst/ctrl/mpexec/dotTools.py: 10%

142 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-11-30 02:36 -0800

1# This file is part of ctrl_mpexec. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22"""Module defining few methods to generate GraphViz diagrams from pipelines 

23or quantum graphs. 

24""" 

25 

26from __future__ import annotations 

27 

# Public API: only the two diagram generators are exported.
__all__ = ["graph2dot", "pipeline2dot"]

29 

30# ------------------------------- 

31# Imports of standard modules -- 

32# ------------------------------- 

33import io 

34import re 

35from typing import TYPE_CHECKING, Any, Iterable, Union 

36 

37# ----------------------------- 

38# Imports for other modules -- 

39# ----------------------------- 

40from lsst.daf.butler import DatasetType, DimensionUniverse 

41from lsst.pipe.base import Pipeline, connectionTypes, iterConnections 

42 

43if TYPE_CHECKING: 

44 from lsst.daf.butler import DatasetRef 

45 from lsst.pipe.base import QuantumGraph, QuantumNode, TaskDef 

46 

47# ---------------------------------- 

48# Local non-exported definitions -- 

49# ---------------------------------- 

50 

# Node styles indexed by node type.
# Tasks/quanta render as bold filled boxes; dataset types/datasets as
# lighter rounded boxes, so the two kinds are visually distinct in the graph.
_STYLES = {
    "task": {"shape": "box", "style": "filled,bold", "fillcolor": "gray70"},
    "quantum": {"shape": "box", "style": "filled,bold", "fillcolor": "gray70"},
    "dsType": {"shape": "box", "style": "rounded,filled", "fillcolor": "gray90"},
    "dataset": {"shape": "box", "style": "rounded,filled", "fillcolor": "gray90"},
}

58 

59 

def _renderNode(file: io.TextIOBase, nodeName: str, style: str, labels: list[str]) -> None:
    """Write a single GraphViz node declaration to ``file``.

    The node's attributes come from the ``_STYLES`` entry for ``style``,
    with ``labels`` joined into a multi-line GraphViz label.
    """
    attributes = dict(_STYLES[style], label=r"\n".join(labels))
    rendered = ", ".join(f'{name}="{value}"' for name, value in attributes.items())
    print(f'"{nodeName}" [{rendered}];', file=file)

66 

67 

def _renderTaskNode(nodeName: str, taskDef: TaskDef, file: io.TextIOBase, idx: Any = None) -> None:
    """Write a GraphViz node describing a pipeline task.

    The label shows the task label and class name, optionally an index,
    and the task's connection dimensions when present.
    """
    lines = [taskDef.label, taskDef.taskName]
    if idx is not None:
        lines.append(f"index: {idx}")
    if taskDef.connections:
        # Join the dimension names ourselves rather than printing the
        # collection directly, which would add visually noisy quotes.
        lines.append("dimensions: " + ", ".join(sorted(taskDef.connections.dimensions)))
    _renderNode(file, nodeName, "task", lines)

78 

79 

def _renderQuantumNode(
    nodeName: str, taskDef: TaskDef, quantumNode: QuantumNode, file: io.TextIOBase
) -> None:
    """Write a GraphViz node describing a single quantum of a task.

    The label shows the quantum's node ID, the task label, and every
    key/value pair of the quantum's data ID in sorted key order.
    """
    dataId = quantumNode.quantum.dataId
    assert dataId is not None, "Quantum DataId cannot be None"
    lines = [f"{quantumNode.nodeId}", taskDef.label]
    for key in sorted(dataId.keys()):
        lines.append(f"{key} = {dataId[key]}")
    _renderNode(file, nodeName, "quantum", lines)

89 

90 

def _renderDSTypeNode(name: str, dimensions: list[str], file: io.TextIOBase) -> None:
    """Write a GraphViz node describing a dataset type.

    The node name is the dataset type name itself; a sorted dimension
    list is appended to the label when any dimensions are given.
    """
    extra = ["Dimensions: " + ", ".join(sorted(dimensions))] if dimensions else []
    _renderNode(file, name, "dsType", [name] + extra)

97 

98 

def _renderDSNode(nodeName: str, dsRef: DatasetRef, file: io.TextIOBase) -> None:
    """Write a GraphViz node describing one dataset reference.

    The label shows the dataset type name, the run, and every key/value
    pair of the data ID in sorted key order.
    """
    lines = [dsRef.datasetType.name, f"run: {dsRef.run!r}"]
    for key in sorted(dsRef.dataId.keys()):
        lines.append(f"{key} = {dsRef.dataId[key]}")
    _renderNode(file, nodeName, "dataset", lines)

104 

105 

def _renderEdge(fromName: str, toName: str, file: io.TextIOBase, **kwargs: Any) -> None:
    """Write a GraphViz edge from ``fromName`` to ``toName``.

    Any keyword arguments become edge attributes (e.g. ``style="dashed"``).
    """
    if not kwargs:
        print(f'"{fromName}" -> "{toName}";', file=file)
        return
    attrib = ", ".join(f'{key}="{val}"' for key, val in kwargs.items())
    print(f'"{fromName}" -> "{toName}" [{attrib}];', file=file)

113 

114 

def _datasetRefId(dsRef: DatasetRef) -> str:
    """Build a string uniquely identifying ``dsRef``.

    The identifier is the dataset type name followed by the data ID's
    key/value pairs (sorted by key), joined with ``:``.
    """
    parts = [dsRef.datasetType.name] + [
        f"{key} = {dsRef.dataId[key]}" for key in sorted(dsRef.dataId.keys())
    ]
    return ":".join(parts)

120 

121 

def _makeDSNode(dsRef: DatasetRef, allDatasetRefs: dict[str, str], file: io.TextIOBase) -> str:
    """Make new node for dataset if it does not exist.

    ``allDatasetRefs`` maps dataset-ref identifiers to node names and is
    updated in place; a GV node is rendered only on first encounter.

    Returns node name.
    """
    key = _datasetRefId(dsRef)
    if key not in allDatasetRefs:
        # Node names are sequential: dsref_0, dsref_1, ...
        newName = f"dsref_{len(allDatasetRefs)}"
        allDatasetRefs[key] = newName
        _renderDSNode(newName, dsRef, file)
    return allDatasetRefs[key]

135 

136 

137# ------------------------ 

138# Exported definitions -- 

139# ------------------------ 

140 

141 

def graph2dot(qgraph: QuantumGraph, file: Any) -> None:
    """Convert QuantumGraph into GraphViz digraph.

    This method is mostly for documentation/presentation purposes.

    Parameters
    ----------
    qgraph: `pipe.base.QuantumGraph`
        QuantumGraph instance.
    file : str or file object
        File where GraphViz graph (DOT language) is written, can be a file name
        or file object.

    Raises
    ------
    `OSError` is raised when output file cannot be open.
    `ImportError` is raised when task class cannot be imported.
    """
    # open a file if needed
    close = False
    if not hasattr(file, "write"):
        file = open(file, "w")
        close = True

    # try/finally guarantees a file we opened ourselves is closed even if
    # rendering raises part-way through (previously it leaked on error).
    try:
        print("digraph QuantumGraph {", file=file)

        # Maps dataset-ref identifier -> GV node name, shared across quanta
        # so each dataset is rendered once and edges converge on it.
        allDatasetRefs: dict[str, str] = {}
        for taskId, taskDef in enumerate(qgraph.taskGraph):

            quanta = qgraph.getNodesForTask(taskDef)
            for qId, quantumNode in enumerate(quanta):

                # node for a task
                taskNodeName = "task_{}_{}".format(taskId, qId)
                _renderQuantumNode(taskNodeName, taskDef, quantumNode, file)

                # quantum inputs: dataset -> quantum edges
                for dsRefs in quantumNode.quantum.inputs.values():
                    for dsRef in dsRefs:
                        nodeName = _makeDSNode(dsRef, allDatasetRefs, file)
                        _renderEdge(nodeName, taskNodeName, file)

                # quantum outputs: quantum -> dataset edges
                for dsRefs in quantumNode.quantum.outputs.values():
                    for dsRef in dsRefs:
                        nodeName = _makeDSNode(dsRef, allDatasetRefs, file)
                        _renderEdge(taskNodeName, nodeName, file)

        print("}", file=file)
    finally:
        if close:
            file.close()

193 

194 

def pipeline2dot(pipeline: Union[Pipeline, Iterable[TaskDef]], file: Any) -> None:
    """Convert Pipeline into GraphViz digraph.

    This method is mostly for documentation/presentation purposes.
    Unlike other methods this method does not validate graph consistency.

    Parameters
    ----------
    pipeline : `pipe.base.Pipeline`
        Pipeline description.
    file : str or file object
        File where GraphViz graph (DOT language) is written, can be a file name
        or file object.

    Raises
    ------
    `OSError` is raised when output file cannot be open.
    `ImportError` is raised when task class cannot be imported.
    `MissingTaskFactoryError` is raised when TaskFactory is needed but not
    provided.
    """
    universe = DimensionUniverse()

    def expand_dimensions(connection: connectionTypes.BaseConnection) -> list[str]:
        """Returns expanded list of dimensions, with special skypix treatment.

        Parameters
        ----------
        connection : `connectionTypes.BaseConnection`
            Connection whose dimensions (if any) are expanded.

        Returns
        -------
        dimensions : `list` [`str`]
        """
        dimension_set = set()
        if isinstance(connection, connectionTypes.DimensionedConnection):
            dimension_set = set(connection.dimensions)
        # "skypix" is a placeholder, not a real dimension; keep it aside so
        # universe.extract() does not choke on it, then re-append it.
        skypix_dim = []
        if "skypix" in dimension_set:
            dimension_set.remove("skypix")
            skypix_dim = ["skypix"]
        dimension_graph = universe.extract(dimension_set)
        return list(dimension_graph.names) + skypix_dim

    # open a file if needed
    close = False
    if not hasattr(file, "write"):
        file = open(file, "w")
        close = True

    # try/finally guarantees a file we opened ourselves is closed even if
    # rendering raises part-way through (previously it leaked on error).
    try:
        print("digraph Pipeline {", file=file)

        allDatasets: set[Union[str, tuple[str, str]]] = set()
        if isinstance(pipeline, Pipeline):
            pipeline = pipeline.toExpandedPipeline()

        # The next two lines are a workaround until DM-29658 at which time
        # metadata connections should start working with the above code
        labelToTaskName: dict[str, str] = {}
        metadataNodesToLink: set[tuple[str, str]] = set()

        # Loop-invariant: compile the "<task label>_metadata" pattern once
        # instead of once per task.
        metadataRePattern = re.compile("^(.*)_metadata$")

        for idx, taskDef in enumerate(sorted(pipeline, key=lambda x: x.label)):

            # node for a task
            taskNodeName = "task{}".format(idx)

            # next line is workaround until DM-29658
            labelToTaskName[taskDef.label] = taskNodeName

            _renderTaskNode(taskNodeName, taskDef, file, None)

            for attr in sorted(iterConnections(taskDef.connections, "inputs"), key=lambda x: x.name):
                if attr.name not in allDatasets:
                    dimensions = expand_dimensions(attr)
                    _renderDSTypeNode(attr.name, dimensions, file)
                    allDatasets.add(attr.name)
                nodeName, component = DatasetType.splitDatasetTypeName(attr.name)
                _renderEdge(attr.name, taskNodeName, file)
                # connect component dataset types to the composite type that
                # produced it
                if component is not None and (nodeName, attr.name) not in allDatasets:
                    _renderEdge(nodeName, attr.name, file)
                    allDatasets.add((nodeName, attr.name))
                    if nodeName not in allDatasets:
                        dimensions = expand_dimensions(attr)
                        _renderDSTypeNode(nodeName, dimensions, file)
                # The next if block is a workaround until DM-29658 at which
                # time metadata connections should start working with the
                # above code
                if (match := metadataRePattern.match(attr.name)) is not None:
                    matchTaskLabel = match.group(1)
                    metadataNodesToLink.add((matchTaskLabel, attr.name))

            for attr in sorted(iterConnections(taskDef.connections, "prerequisiteInputs"), key=lambda x: x.name):
                if attr.name not in allDatasets:
                    dimensions = expand_dimensions(attr)
                    _renderDSTypeNode(attr.name, dimensions, file)
                    allDatasets.add(attr.name)
                # use dashed line for prerequisite edges to distinguish them
                _renderEdge(attr.name, taskNodeName, file, style="dashed")

            for attr in sorted(iterConnections(taskDef.connections, "outputs"), key=lambda x: x.name):
                if attr.name not in allDatasets:
                    dimensions = expand_dimensions(attr)
                    _renderDSTypeNode(attr.name, dimensions, file)
                    allDatasets.add(attr.name)
                _renderEdge(taskNodeName, attr.name, file)

        # This for loop is a workaround until DM-29658 at which time metadata
        # connections should start working with the above code
        for matchLabel, dsTypeName in metadataNodesToLink:
            # only render an edge to metadata if the label is part of the
            # current graph
            if (result := labelToTaskName.get(matchLabel)) is not None:
                _renderEdge(result, dsTypeName, file)

        print("}", file=file)
    finally:
        if close:
            file.close()