Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

"""Module defining a few methods to manipulate or query pipelines.

23""" 

24 

25# No one should do import * from this module 

26__all__ = ["isPipelineOrdered", "orderPipeline"] 

27 

28# ------------------------------- 

29# Imports of standard modules -- 

30# ------------------------------- 

import heapq
import itertools

32 

33# ----------------------------- 

34# Imports for other modules -- 

35# ----------------------------- 

36from .connections import iterConnections 

37 

38# ---------------------------------- 

39# Local non-exported definitions -- 

40# ---------------------------------- 

41 

42 

43def _loadTaskClass(taskDef, taskFactory): 

44 """Import task class if necessary. 

45 

46 Raises 

47 ------ 

48 `ImportError` is raised when task class cannot be imported. 

49 `MissingTaskFactoryError` is raised when TaskFactory is needed but not 

50 provided. 

51 """ 

52 taskClass = taskDef.taskClass 

53 if not taskClass: 

54 if not taskFactory: 

55 raise MissingTaskFactoryError("Task class is not defined but task " 

56 "factory instance is not provided") 

57 taskClass = taskFactory.loadTaskClass(taskDef.taskName) 

58 return taskClass 

59 

60# ------------------------ 

61# Exported definitions -- 

62# ------------------------ 

63 

64 

class MissingTaskFactoryError(Exception):
    """Exception raised when client fails to provide TaskFactory instance."""

69 

70 

class DuplicateOutputError(Exception):
    """Exception raised when Pipeline has more than one task for the same
    output.
    """

76 

77 

class PipelineDataCycleError(Exception):
    """Exception raised when Pipeline has data dependency cycle."""

82 

83 

def isPipelineOrdered(pipeline, taskFactory=None):
    """Checks whether tasks in pipeline are correctly ordered.

    Pipeline is correctly ordered if for any DatasetType produced by a task
    in a pipeline all its consumer tasks are located after producer.

    Parameters
    ----------
    pipeline : `pipe.base.Pipeline`
        Pipeline description.
    taskFactory: `pipe.base.TaskFactory`, optional
        Instance of an object which knows how to import task classes. It is
        only used if pipeline task definitions do not define task classes.

    Returns
    -------
    True for correctly ordered pipeline, False otherwise.

    Raises
    ------
    `ImportError` is raised when task class cannot be imported.
    `DuplicateOutputError` is raised when there is more than one producer for a
    dataset type.
    `MissingTaskFactoryError` is raised when TaskFactory is needed but not
    provided.
    """
    # Map every output DatasetType name to the index of its producing task;
    # a second producer for the same name is an error.
    producerIndex = {}
    for taskIndex, taskDef in enumerate(pipeline):
        for connection in iterConnections(taskDef.connections, 'outputs'):
            if connection.name in producerIndex:
                raise DuplicateOutputError("DatasetType `{}' appears more than "
                                           "once as output".format(connection.name))
            producerIndex[connection.name] = taskIndex

    # Every input that some pipeline task produces must be produced at an
    # earlier index; names absent from producerIndex are pre-existing
    # datasets and get the effective producer index -1.
    for taskIndex, taskDef in enumerate(pipeline):
        for connectionName in taskDef.connections.inputs:
            dsTypeDescr = getattr(taskDef.connections, connectionName)
            if producerIndex.get(dsTypeDescr.name, -1) >= taskIndex:
                # producer is at or after the consumer - not ordered
                return False

    return True

133 

134 

def orderPipeline(pipeline):
    """Re-order tasks in pipeline to satisfy data dependencies.

    When possible new ordering keeps original relative order of the tasks.

    Parameters
    ----------
    pipeline : `list` of `pipe.base.TaskDef`
        Pipeline description.

    Returns
    -------
    Correctly ordered pipeline (`list` of `pipe.base.TaskDef` objects).

    Raises
    ------
    `DuplicateOutputError` is raised when there is more than one producer for a
    dataset type.
    `PipelineDataCycleError` is also raised when pipeline has dependency
    cycles. `MissingTaskFactoryError` is raised when TaskFactory is needed but
    not provided.
    """

    # This is a modified version of Kahn's algorithm: among all tasks whose
    # inputs are satisfied we always run the one with the smallest original
    # index, which preserves relative order where dependencies allow it.

    # build mapping of the tasks to their inputs and outputs
    inputs = {}  # maps task index to its input DatasetType names
    outputs = {}  # maps task index to its output DatasetType names
    allInputs = set()  # all inputs of all tasks
    allOutputs = set()  # all outputs of all tasks
    for idx, taskDef in enumerate(pipeline):
        # Task outputs: each DatasetType may have at most one producer in
        # the whole pipeline; a duplicate within a single task's own
        # connections is an error too (the original silently deduped it).
        taskOutputs = set()
        for name in taskDef.connections.outputs:
            dsName = getattr(taskDef.connections, name).name
            if dsName in allOutputs or dsName in taskOutputs:
                raise DuplicateOutputError("DatasetType `{}' appears more than "
                                           "once as output".format(dsName))
            taskOutputs.add(dsName)
        outputs[idx] = taskOutputs
        allOutputs.update(taskOutputs)

        # Task inputs: prerequisites count as ordinary ordering edges.
        connectionInputs = itertools.chain(taskDef.connections.inputs,
                                           taskDef.connections.prerequisiteInputs)
        inputs[idx] = set(getattr(taskDef.connections, name).name
                          for name in connectionInputs)
        allInputs.update(inputs[idx])

    # for simplicity add pseudo-node which is a producer for all pre-existing
    # inputs, its index is -1
    preExisting = allInputs - allOutputs
    outputs[-1] = preExisting

    # Min-heap of ready nodes (no unsatisfied inputs), initially the
    # pseudo-node. heapq always pops the smallest index, replacing the
    # original list pop(0)+sort() with O(log n) operations.
    queue = [-1]
    result = []
    while queue:

        # move to final list, drop -1
        idx = heapq.heappop(queue)
        if idx >= 0:
            result.append(idx)

        # this task's outputs are now satisfied inputs for everyone else
        thisTaskOutputs = outputs.get(idx, set())
        for taskInputs in inputs.values():
            taskInputs -= thisTaskOutputs

        # any task whose input set became empty is ready to run
        readyNodes = [key for key, value in inputs.items() if not value]
        for key in readyNodes:
            del inputs[key]
            heapq.heappush(queue, key)

    # if there is something left it means cycles
    if inputs:
        # format it in usable way
        loops = []
        for idx, inputNames in inputs.items():
            taskName = pipeline[idx].label
            outputNames = outputs[idx]
            edge = " {} -> {} -> {}".format(inputNames, taskName, outputNames)
            loops.append(edge)
        raise PipelineDataCycleError("Pipeline has data cycles:\n" + "\n".join(loops))

    return [pipeline[idx] for idx in result]