Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of ctrl_mpexec. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

"""Module defining a few functions to generate GraphViz diagrams from
pipelines or quantum graphs.
"""

25 

26__all__ = ["graph2dot", "pipeline2dot"] 

27 

28# ------------------------------- 

29# Imports of standard modules -- 

30# ------------------------------- 

31 

32# ----------------------------- 

33# Imports for other modules -- 

34# ----------------------------- 

35from lsst.daf.butler import DimensionUniverse 

36from lsst.pipe.base import iterConnections, Pipeline 

37 

38# ---------------------------------- 

39# Local non-exported definitions -- 

40# ---------------------------------- 

41 

# GraphViz node attributes for each kind of node, keyed by node type.
# Key order is significant: it is the order attributes appear in the output.
_STYLES = {
    "task": {"shape": "box", "style": "filled,bold", "fillcolor": "gray70"},
    "dsType": {"shape": "box", "style": "rounded,filled", "fillcolor": "gray90"},
    "dataset": {"shape": "box", "style": "rounded,filled", "fillcolor": "gray90"},
}

48 

49 

def _renderNode(file, nodeName, style, labels):
    """Render a GraphViz node statement.

    Parameters
    ----------
    file : file object
        Stream that the DOT statement is printed to.
    nodeName : `str`
        Identifier of the node.
    style : `str`
        Key into ``_STYLES`` selecting the node attributes.
    labels : iterable of `str`
        Lines of the node label; they are joined with the DOT line-break
        escape sequence.

    Raises
    ------
    KeyError
        Raised if ``style`` is not a key of ``_STYLES``.
    """
    def quote(value):
        # A double quote is the only character that has to be escaped inside
        # a DOT quoted string; backslashes are left alone so that the r'\n'
        # label separators keep working.
        return str(value).replace('"', r'\"')

    attrib = dict(_STYLES[style], label=r'\n'.join(labels))
    attrib = ", ".join(f'{key}="{quote(val)}"' for key, val in attrib.items())
    print(f'"{quote(nodeName)}" [{attrib}];', file=file)

56 

57 

def _renderTaskNode(nodeName, taskDef, file, idx=None):
    """Render a GraphViz node for a task.

    The label shows the last dot-separated component of
    ``taskDef.taskName``, followed by the index (when ``idx`` is not
    `None`) and the task label (when ``taskDef.label`` is non-empty).
    """
    _, _, shortName = taskDef.taskName.rpartition('.')
    labels = [shortName]
    if idx is not None:
        labels.append(f"index: {idx}")
    if taskDef.label:
        labels.append(f"label: {taskDef.label}")
    _renderNode(file, nodeName, "task", labels)

66 

67 

def _renderDSTypeNode(name, dimensions, file):
    """Render a GraphViz node for a dataset type.

    The dataset type name is used both as the node identifier and as the
    first label line; a second line listing the dimensions is added only
    when ``dimensions`` is non-empty.
    """
    labels = [name]
    if dimensions:
        dimensionList = ", ".join(dimensions)
        labels.append("Dimensions: " + dimensionList)
    _renderNode(file, name, "dsType", labels)

74 

75 

def _renderDSNode(nodeName, dsRef, file):
    """Render a GraphViz node for a single dataset.

    The label consists of the dataset type name followed by one
    ``key = value`` line per data ID key, in sorted key order.
    """
    labels = [dsRef.datasetType.name]
    for key in sorted(dsRef.dataId.keys()):
        labels.append(f"{key} = {dsRef.dataId[key]}")
    _renderNode(file, nodeName, "dataset", labels)

81 

82 

def _renderEdge(fromName, toName, file, **kwargs):
    """Render a GraphViz edge statement.

    Parameters
    ----------
    fromName : `str`
        Identifier of the node the edge starts at.
    toName : `str`
        Identifier of the node the edge points to.
    file : file object
        Stream that the DOT statement is printed to.
    **kwargs
        Optional edge attributes, rendered as ``key="value"`` pairs.
    """
    def quote(value):
        # A double quote is the only character that has to be escaped inside
        # a DOT quoted string.
        return str(value).replace('"', r'\"')

    edge = f'"{quote(fromName)}" -> "{quote(toName)}"'
    if kwargs:
        attrib = ", ".join(f'{key}="{quote(val)}"' for key, val in kwargs.items())
        print(f'{edge} [{attrib}];', file=file)
    else:
        print(f'{edge};', file=file)

90 

91 

def _datasetRefId(dsRef):
    """Build a string identifying the given dataset ref.

    The result is the dataset type name followed by ``key = value`` items
    for every data ID key in sorted order, all joined with colons.
    """
    parts = [dsRef.datasetType.name]
    for key in sorted(dsRef.dataId.keys()):
        parts.append(f"{key} = {dsRef.dataId[key]}")
    return ":".join(parts)

97 

98 

def _makeDSNode(dsRef, allDatasetRefs, file):
    """Return the node name for a dataset, rendering the node on first use.

    ``allDatasetRefs`` maps dataset identifying strings to node names. When
    the dataset has no entry yet, a new ``dsref_<N>`` name is generated
    (``N`` being the current size of the mapping), stored in the mapping
    and the node is written to ``file``.
    """
    refId = _datasetRefId(dsRef)
    known = allDatasetRefs.get(refId)
    if known is not None:
        return known

    nodeName = f"dsref_{len(allDatasetRefs)}"
    allDatasetRefs[refId] = nodeName
    _renderDSNode(nodeName, dsRef, file)
    return nodeName

112 

113# ------------------------ 

114# Exported definitions -- 

115# ------------------------ 

116 

117 

def graph2dot(qgraph, file):
    """Convert QuantumGraph into GraphViz digraph.

    This method is mostly for documentation/presentation purposes.

    Parameters
    ----------
    qgraph : `pipe.base.QuantumGraph`
        QuantumGraph instance.
    file : str or file object
        File where GraphViz graph (DOT language) is written, can be a file
        name or file object. A file opened by this function is always closed
        before returning; a file object passed by the caller is left open.

    Raises
    ------
    OSError
        Raised when output file cannot be opened.
    ImportError
        Raised when task class cannot be imported.
    """
    # open a file if needed
    close = False
    if not hasattr(file, "write"):
        file = open(file, "w")
        close = True

    # try/finally guarantees that a file opened above does not leak if
    # rendering raises part-way through.
    try:
        print("digraph QuantumGraph {", file=file)

        # maps dataset identifying string to the name of its rendered node
        allDatasetRefs = {}
        for taskId, taskDef in enumerate(qgraph.taskGraph):

            quanta = qgraph.getQuantaForTask(taskDef)
            for qId, quantum in enumerate(quanta):

                # node for a task, one per quantum
                taskNodeName = f"task_{taskId}_{qId}"
                _renderTaskNode(taskNodeName, taskDef, file)

                # quantum inputs: edges point from dataset to task
                for dsRefs in quantum.inputs.values():
                    for dsRef in dsRefs:
                        nodeName = _makeDSNode(dsRef, allDatasetRefs, file)
                        _renderEdge(nodeName, taskNodeName, file)

                # quantum outputs: edges point from task to dataset
                for dsRefs in quantum.outputs.values():
                    for dsRef in dsRefs:
                        nodeName = _makeDSNode(dsRef, allDatasetRefs, file)
                        _renderEdge(taskNodeName, nodeName, file)

        print("}", file=file)
    finally:
        if close:
            file.close()

169 

170 

def pipeline2dot(pipeline, file):
    """Convert Pipeline into GraphViz digraph.

    This method is mostly for documentation/presentation purposes.
    Unlike other methods this method does not validate graph consistency.

    Parameters
    ----------
    pipeline : `pipe.base.Pipeline`
        Pipeline description. A `Pipeline` instance is expanded with
        ``toExpandedPipeline()``; any other object is iterated over
        directly.
    file : str or file object
        File where GraphViz graph (DOT language) is written, can be a file
        name or file object. A file opened by this function is always closed
        before returning; a file object passed by the caller is left open.

    Raises
    ------
    OSError
        Raised when output file cannot be opened.
    ImportError
        Raised when task class cannot be imported.
    MissingTaskFactoryError
        Raised when TaskFactory is needed but not provided.
    """
    universe = DimensionUniverse()

    def expand_dimensions(dimensions):
        """Return expanded list of dimensions, with special skypix treatment.

        Parameters
        ----------
        dimensions : `list` [`str`]

        Returns
        -------
        dimensions : `list` [`str`]
        """
        dimensions = set(dimensions)
        skypix_dim = []
        # "skypix" is taken out before the universe lookup and appended
        # back unchanged afterwards.
        if "skypix" in dimensions:
            dimensions.remove("skypix")
            skypix_dim = ["skypix"]
        dimensions = universe.extract(dimensions)
        return list(dimensions.names) + skypix_dim

    # Connection kinds in rendering order:
    # (connection type, dataset is a task input, extra edge attributes).
    # Dashed lines are used for prerequisite edges to distinguish them.
    connectionKinds = (
        ("inputs", True, {}),
        ("prerequisiteInputs", True, {"style": "dashed"}),
        ("outputs", False, {}),
    )

    # open a file if needed
    close = False
    if not hasattr(file, "write"):
        file = open(file, "w")
        close = True

    # try/finally guarantees that a file opened above does not leak if
    # pipeline expansion or rendering raises part-way through.
    try:
        print("digraph Pipeline {", file=file)

        # names of dataset types whose nodes have already been rendered
        allDatasets = set()
        if isinstance(pipeline, Pipeline):
            pipeline = pipeline.toExpandedPipeline()
        for idx, taskDef in enumerate(pipeline):

            # node for a task
            taskNodeName = f"task{idx}"
            _renderTaskNode(taskNodeName, taskDef, file, idx)

            for connectionType, isInput, edgeAttrs in connectionKinds:
                for attr in iterConnections(taskDef.connections, connectionType):
                    if attr.name not in allDatasets:
                        dimensions = expand_dimensions(attr.dimensions)
                        _renderDSTypeNode(attr.name, dimensions, file)
                        allDatasets.add(attr.name)
                    if isInput:
                        _renderEdge(attr.name, taskNodeName, file, **edgeAttrs)
                    else:
                        _renderEdge(taskNodeName, attr.name, file, **edgeAttrs)

        print("}", file=file)
    finally:
        if close:
            file.close()