Coverage for python/lsst/ctrl/mpexec/dotTools.py: 8%

133 statements  

« prev     ^ index     » next       coverage.py v6.4, created at 2022-06-01 12:18 +0000

1# This file is part of ctrl_mpexec. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22"""Module defining few methods to generate GraphViz diagrams from pipelines 

23or quantum graphs. 

24""" 

25 

# Public API of this module: the two diagram-generation entry points.
__all__ = ["graph2dot", "pipeline2dot"]

27 

28# ------------------------------- 

29# Imports of standard modules -- 

30# ------------------------------- 

31import re 

32 

33# ----------------------------- 

34# Imports for other modules -- 

35# ----------------------------- 

36from lsst.daf.butler import DatasetType, DimensionUniverse 

37from lsst.pipe.base import Pipeline, iterConnections 

38 

39# ---------------------------------- 

40# Local non-exported definitions -- 

41# ---------------------------------- 

42 

43# Node styles indexed by node type. 

44_STYLES = dict( 

45 task=dict(shape="box", style="filled,bold", fillcolor="gray70"), 

46 quantum=dict(shape="box", style="filled,bold", fillcolor="gray70"), 

47 dsType=dict(shape="box", style="rounded,filled", fillcolor="gray90"), 

48 dataset=dict(shape="box", style="rounded,filled", fillcolor="gray90"), 

49) 

50 

51 

def _renderNode(file, nodeName, style, labels):
    """Write a single GraphViz node statement to ``file``.

    Attributes come from the ``_STYLES`` entry selected by ``style``;
    ``labels`` are joined into a multi-line node label.
    """
    attrs = dict(_STYLES[style])
    attrs["label"] = r"\n".join(labels)
    rendered = ", ".join(f'{key}="{val}"' for key, val in attrs.items())
    print(f'"{nodeName}" [{rendered}];', file=file)

58 

59 

def _renderTaskNode(nodeName, taskDef, file, idx=None):
    """Render a GraphViz node describing a pipeline task.

    The label shows the task label and class name, optionally the task
    index, and the task's dimensions when connections are defined.
    """
    lines = [taskDef.label, taskDef.taskName]
    if idx is not None:
        lines += [f"index: {idx}"]
    if taskDef.connections:
        # Join the sorted names ourselves so the label carries no quoted-str
        # noise from printing a collection directly.
        lines += ["dimensions: " + ", ".join(sorted(taskDef.connections.dimensions))]
    _renderNode(file, nodeName, "task", lines)

70 

71 

def _renderQuantumNode(nodeName, taskDef, quantumNode, file):
    """Render a GraphViz node for one quantum of a task.

    The label shows the quantum node ID, the task label, and each data ID
    key/value pair in sorted key order.
    """
    dataId = quantumNode.quantum.dataId
    lines = [str(quantumNode.nodeId), taskDef.label]
    for key in sorted(dataId.keys()):
        lines.append(f"{key} = {dataId[key]}")
    _renderNode(file, nodeName, "quantum", lines)

78 

79 

def _renderDSTypeNode(name, dimensions, file):
    """Render a GraphViz node for a dataset type.

    The node name is the dataset type name itself; dimensions, when
    present, are listed on a second label line in sorted order.
    """
    if dimensions:
        dims = ", ".join(sorted(dimensions))
        node_labels = [name, f"Dimensions: {dims}"]
    else:
        node_labels = [name]
    _renderNode(file, name, "dsType", node_labels)

86 

87 

def _renderDSNode(nodeName, dsRef, file):
    """Render a GraphViz node for a concrete dataset reference.

    The label shows the dataset type name, the run, and each data ID
    key/value pair in sorted key order.
    """
    dataId = dsRef.dataId
    lines = [dsRef.datasetType.name, f"run: {dsRef.run!r}"]
    for key in sorted(dataId.keys()):
        lines.append(f"{key} = {dataId[key]}")
    _renderNode(file, nodeName, "dataset", lines)

93 

94 

95def _renderEdge(fromName, toName, file, **kwargs): 

96 """Render GV edge""" 

97 if kwargs: 

98 attrib = ", ".join([f'{key}="{val}"' for key, val in kwargs.items()]) 

99 print(f'"{fromName}" -> "{toName}" [{attrib}];', file=file) 

100 else: 

101 print(f'"{fromName}" -> "{toName}";', file=file) 

102 

103 

104def _datasetRefId(dsRef): 

105 """Make an identifying string for given ref""" 

106 dsId = [dsRef.datasetType.name] 

107 dsId.extend(f"{key} = {dsRef.dataId[key]}" for key in sorted(dsRef.dataId.keys())) 

108 return ":".join(dsId) 

109 

110 

def _makeDSNode(dsRef, allDatasetRefs, file):
    """Return the GV node name for ``dsRef``, rendering it on first use.

    ``allDatasetRefs`` maps dataset-ref identifier strings to the node
    names already assigned; a new entry (and a new rendered node) is
    created only when the ref has not been seen before.

    Returns node name.
    """
    refId = _datasetRefId(dsRef)
    if refId in allDatasetRefs:
        return allDatasetRefs[refId]
    # First time we see this ref: assign the next sequential node name.
    nodeName = f"dsref_{len(allDatasetRefs)}"
    allDatasetRefs[refId] = nodeName
    _renderDSNode(nodeName, dsRef, file)
    return nodeName

124 

125 

126# ------------------------ 

127# Exported definitions -- 

128# ------------------------ 

129 

130 

def graph2dot(qgraph, file):
    """Convert QuantumGraph into GraphViz digraph.

    This method is mostly for documentation/presentation purposes.

    Parameters
    ----------
    qgraph : `pipe.base.QuantumGraph`
        QuantumGraph instance.
    file : str or file object
        File where GraphViz graph (DOT language) is written, can be a file
        name or file object.

    Raises
    ------
    OSError
        Raised when output file cannot be open.
    ImportError
        Raised when task class cannot be imported.
    """
    # Open a file if needed; remember whether we own it so we only close
    # handles we opened ourselves.
    close = False
    if not hasattr(file, "write"):
        file = open(file, "w")
        close = True

    # try/finally guarantees a file we opened is closed even when rendering
    # raises part-way through (previously the handle leaked on error).
    try:
        print("digraph QuantumGraph {", file=file)

        # Maps dataset-ref identifier -> node name, shared across all quanta
        # so each dataset is rendered exactly once.
        allDatasetRefs = {}
        for taskId, taskDef in enumerate(qgraph.taskGraph):

            quanta = qgraph.getNodesForTask(taskDef)
            for qId, quantumNode in enumerate(quanta):

                # node for a quantum of this task
                taskNodeName = "task_{}_{}".format(taskId, qId)
                _renderQuantumNode(taskNodeName, taskDef, quantumNode, file)

                # quantum inputs: edge from each input dataset to the quantum
                for dsRefs in quantumNode.quantum.inputs.values():
                    for dsRef in dsRefs:
                        nodeName = _makeDSNode(dsRef, allDatasetRefs, file)
                        _renderEdge(nodeName, taskNodeName, file)

                # quantum outputs: edge from the quantum to each output dataset
                for dsRefs in quantumNode.quantum.outputs.values():
                    for dsRef in dsRefs:
                        nodeName = _makeDSNode(dsRef, allDatasetRefs, file)
                        _renderEdge(taskNodeName, nodeName, file)

        print("}", file=file)
    finally:
        if close:
            file.close()

182 

183 

def pipeline2dot(pipeline, file):
    """Convert Pipeline into GraphViz digraph.

    This method is mostly for documentation/presentation purposes.
    Unlike other methods this method does not validate graph consistency.

    Parameters
    ----------
    pipeline : `pipe.base.Pipeline`
        Pipeline description.
    file : str or file object
        File where GraphViz graph (DOT language) is written, can be a file
        name or file object.

    Raises
    ------
    OSError
        Raised when output file cannot be open.
    ImportError
        Raised when task class cannot be imported.
    MissingTaskFactoryError
        Raised when TaskFactory is needed but not provided.
    """
    universe = DimensionUniverse()

    def expand_dimensions(dimensions):
        """Returns expanded list of dimensions, with special skypix treatment.

        Parameters
        ----------
        dimensions : `list` [`str`]
            Dimension names, possibly including the special "skypix"
            placeholder.

        Returns
        -------
        dimensions : `list` [`str`]
            Expanded dimension names, with "skypix" (if present) kept
            unexpanded at the end.
        """
        dimensions = set(dimensions)
        skypix_dim = []
        if "skypix" in dimensions:
            # "skypix" is not a concrete dimension the universe can expand;
            # set it aside and re-append it verbatim.
            dimensions.remove("skypix")
            skypix_dim = ["skypix"]
        dimensions = universe.extract(dimensions)
        return list(dimensions.names) + skypix_dim

    # open a file if needed; remember whether we own it so we only close
    # handles we opened ourselves
    close = False
    if not hasattr(file, "write"):
        file = open(file, "w")
        close = True

    # Loop-invariant hoisted out of the per-task loop: the pattern never
    # changes, so compile it once instead of once per task.
    metadataRePattern = re.compile("^(.*)_metadata$")

    # try/finally guarantees a file we opened is closed even when rendering
    # raises part-way through (previously the handle leaked on error).
    try:
        print("digraph Pipeline {", file=file)

        allDatasets = set()
        if isinstance(pipeline, Pipeline):
            pipeline = pipeline.toExpandedPipeline()

        # The next two lines are a workaround until DM-29658 at which time
        # metadata connections should start working with the above code
        labelToTaskName = {}
        metadataNodesToLink = set()

        for idx, taskDef in enumerate(sorted(pipeline, key=lambda x: x.label)):

            # node for a task
            taskNodeName = "task{}".format(idx)

            # next line is workaround until DM-29658
            labelToTaskName[taskDef.label] = taskNodeName

            _renderTaskNode(taskNodeName, taskDef, file, None)

            for attr in sorted(iterConnections(taskDef.connections, "inputs"), key=lambda x: x.name):
                if attr.name not in allDatasets:
                    dimensions = expand_dimensions(attr.dimensions)
                    _renderDSTypeNode(attr.name, dimensions, file)
                    allDatasets.add(attr.name)
                nodeName, component = DatasetType.splitDatasetTypeName(attr.name)
                _renderEdge(attr.name, taskNodeName, file)
                # connect component dataset types to the composite type that
                # produced it
                if component is not None and (nodeName, attr.name) not in allDatasets:
                    _renderEdge(nodeName, attr.name, file)
                    allDatasets.add((nodeName, attr.name))
                    if nodeName not in allDatasets:
                        dimensions = expand_dimensions(attr.dimensions)
                        _renderDSTypeNode(nodeName, dimensions, file)
                # The next if block is a workaround until DM-29658 at which
                # time metadata connections should start working with the
                # above code
                if (match := metadataRePattern.match(attr.name)) is not None:
                    matchTaskLabel = match.group(1)
                    metadataNodesToLink.add((matchTaskLabel, attr.name))

            for attr in sorted(
                iterConnections(taskDef.connections, "prerequisiteInputs"), key=lambda x: x.name
            ):
                if attr.name not in allDatasets:
                    dimensions = expand_dimensions(attr.dimensions)
                    _renderDSTypeNode(attr.name, dimensions, file)
                    allDatasets.add(attr.name)
                # use dashed line for prerequisite edges to distinguish them
                _renderEdge(attr.name, taskNodeName, file, style="dashed")

            for attr in sorted(iterConnections(taskDef.connections, "outputs"), key=lambda x: x.name):
                if attr.name not in allDatasets:
                    dimensions = expand_dimensions(attr.dimensions)
                    _renderDSTypeNode(attr.name, dimensions, file)
                    allDatasets.add(attr.name)
                _renderEdge(taskNodeName, attr.name, file)

        # This for loop is a workaround until DM-29658 at which time metadata
        # connections should start working with the above code
        for matchLabel, dsTypeName in metadataNodesToLink:
            # only render an edge to metadata if the label is part of the
            # current graph
            if (result := labelToTaskName.get(matchLabel)) is not None:
                _renderEdge(result, dsTypeName, file)

        print("}", file=file)
    finally:
        if close:
            file.close()