# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Module defining a few methods to manipulate or query pipelines.
"""

# No one should do import * from this module
__all__ = ["isPipelineOrdered", "orderPipeline"]
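
# Typical usage, as an illustrative sketch (``taskDefs`` here is assumed to
# be the list of TaskDef objects describing a pipeline):
#
#     if not isPipelineOrdered(taskDefs):
#         taskDefs = orderPipeline(taskDefs)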

# -------------------------------
# Imports of standard modules --
# -------------------------------
import itertools

# -----------------------------
# Imports for other modules --
# -----------------------------
from .connections import iterConnections

# ----------------------------------
# Local non-exported definitions --
# ----------------------------------

# ------------------------
# Exported definitions --
# ------------------------


class MissingTaskFactoryError(Exception):
    """Exception raised when a client fails to provide a TaskFactory
    instance.
    """
    pass


class DuplicateOutputError(Exception):
    """Exception raised when a pipeline has more than one task for the same
    output.
    """
    pass


class PipelineDataCycleError(Exception):
    """Exception raised when a pipeline has a data dependency cycle.
    """
    pass


def isPipelineOrdered(pipeline, taskFactory=None):
    """Check whether tasks in a pipeline are correctly ordered.

    A pipeline is correctly ordered if, for every DatasetType produced by a
    task in the pipeline, all of its consumer tasks appear after the
    producer.

    Parameters
    ----------
    pipeline : `pipe.base.Pipeline`
        Pipeline description.
    taskFactory : `pipe.base.TaskFactory`, optional
        Instance of an object which knows how to import task classes. It is
        only used if pipeline task definitions do not define task classes.

    Returns
    -------
    `True` for a correctly ordered pipeline, `False` otherwise.

    Raises
    ------
    ImportError
        Raised when a task class cannot be imported.
    DuplicateOutputError
        Raised when there is more than one producer for a dataset type.
    MissingTaskFactoryError
        Raised when a TaskFactory is needed but not provided.
    """
    # Build a map of DatasetType name to producer's index in the pipeline
    producerIndex = {}
    for idx, taskDef in enumerate(pipeline):
        for attr in iterConnections(taskDef.connections, 'outputs'):
            if attr.name in producerIndex:
                raise DuplicateOutputError("DatasetType `{}' appears more than "
                                           "once as output".format(attr.name))
            producerIndex[attr.name] = idx

    # check all inputs that are also some task's outputs
    for idx, taskDef in enumerate(pipeline):
        # get task input DatasetTypes, this can only be done via class method
        inputs = {name: getattr(taskDef.connections, name) for name in taskDef.connections.inputs}
        for dsTypeDescr in inputs.values():
            # all pre-existing datasets have effective index -1
            prodIdx = producerIndex.get(dsTypeDescr.name, -1)
            if prodIdx >= idx:
                # not good, producer is downstream
                return False

    return True
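

# Illustrative sketch, not part of the original module: the same
# producer-index check as ``isPipelineOrdered`` applied to a toy
# representation in which each task is a hypothetical
# (label, inputNames, outputNames) tuple instead of a TaskDef.
def _toyIsOrdered(tasks):
    producerIndex = {}
    for idx, (_, _, outputNames) in enumerate(tasks):
        for name in outputNames:
            producerIndex[name] = idx
    for idx, (_, inputNames, _) in enumerate(tasks):
        for name in inputNames:
            # datasets that no task produces have effective index -1
            if producerIndex.get(name, -1) >= idx:
                return False
    return True

# For example, a two-task pipeline where "calib" consumes the "bias" made
# by "isr" is ordered only when "isr" comes first:
#     _toyIsOrdered([("isr", {"raw"}, {"bias"}), ("calib", {"bias"}, {"flat"})])  -> True
#     _toyIsOrdered([("calib", {"bias"}, {"flat"}), ("isr", {"raw"}, {"bias"})])  -> False
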

def orderPipeline(pipeline):
    """Re-order tasks in a pipeline to satisfy data dependencies.

    When possible, the new ordering keeps the original relative order of the
    tasks.

    Parameters
    ----------
    pipeline : `list` of `pipe.base.TaskDef`
        Pipeline description.

    Returns
    -------
    Correctly ordered pipeline (`list` of `pipe.base.TaskDef` objects).

    Raises
    ------
    DuplicateOutputError
        Raised when there is more than one producer for a dataset type.
    PipelineDataCycleError
        Raised when the pipeline has dependency cycles.
    MissingTaskFactoryError
        Raised when a TaskFactory is needed but not provided.
    """

    # This is a modified version of Kahn's algorithm that preserves order

    # build mapping of the tasks to their inputs and outputs
    inputs = {}  # maps task index to its input DatasetType names
    outputs = {}  # maps task index to its output DatasetType names
    allInputs = set()  # all inputs of all tasks
    allOutputs = set()  # all outputs of all tasks
    for idx, taskDef in enumerate(pipeline):
        # task outputs
        dsMap = {name: getattr(taskDef.connections, name) for name in taskDef.connections.outputs}
        for dsTypeDescr in dsMap.values():
            if dsTypeDescr.name in allOutputs:
                raise DuplicateOutputError("DatasetType `{}' appears more than "
                                           "once as output".format(dsTypeDescr.name))
        outputs[idx] = set(dsTypeDescr.name for dsTypeDescr in dsMap.values())
        allOutputs.update(outputs[idx])

        # task inputs
        connectionInputs = itertools.chain(taskDef.connections.inputs, taskDef.connections.prerequisiteInputs)
        dsMap = [getattr(taskDef.connections, name).name for name in connectionInputs]
        inputs[idx] = set(dsMap)
        allInputs.update(inputs[idx])

    # for simplicity add a pseudo-node which is a producer for all
    # pre-existing inputs; its index is -1
    preExisting = allInputs - allOutputs
    outputs[-1] = preExisting
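    # For example, if some task consumes a hypothetical "raw" dataset that
    # no task in the pipeline produces, then "raw" is pre-existing: popping
    # the pseudo-node first clears it from every pending input set, so
    # tasks that depend only on pre-existing data become ready immediately.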

    # Set of nodes with no incoming edges, initially set to the pseudo-node
    queue = [-1]
    result = []
    while queue:
        # move to final list, drop -1
        idx = queue.pop(0)
        if idx >= 0:
            result.append(idx)

        # remove this task's outputs from other tasks' inputs
        thisTaskOutputs = outputs.get(idx, set())
        for taskInputs in inputs.values():
            taskInputs -= thisTaskOutputs

        # find all nodes with no incoming edges and move them to the queue
        topNodes = [key for key, value in inputs.items() if not value]
        queue += topNodes
        for key in topNodes:
            del inputs[key]

        # keep queue ordered
        queue.sort()
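        # Sorting the queue is what preserves the original relative order:
        # when several tasks become ready in the same round (e.g. tasks 2
        # and 5, hypothetically), the one that appeared earlier in the
        # input pipeline is always popped first.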

    # if there is something left, it means the pipeline has cycles
    if inputs:
        # format it in a usable way
        loops = []
        for idx, inputNames in inputs.items():
            taskName = pipeline[idx].label
            outputNames = outputs[idx]
            edge = " {} -> {} -> {}".format(inputNames, taskName, outputNames)
            loops.append(edge)
        raise PipelineDataCycleError("Pipeline has data cycles:\n" + "\n".join(loops))

    return [pipeline[idx] for idx in result]
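

# Illustrative sketch, not part of the original module: the same
# order-preserving Kahn's algorithm as ``orderPipeline``, applied to toy
# (label, inputNames, outputNames) tuples instead of TaskDef objects.
def _toyOrder(tasks):
    inputs = {idx: set(inputNames) for idx, (_, inputNames, _) in enumerate(tasks)}
    outputs = {idx: set(outputNames) for idx, (_, _, outputNames) in enumerate(tasks)}
    # pseudo-node -1 produces everything that no task in ``tasks`` makes
    outputs[-1] = set().union(*inputs.values()) - set().union(*outputs.values())
    queue, result = [-1], []
    while queue:
        idx = queue.pop(0)
        if idx >= 0:
            result.append(idx)
        for taskInputs in inputs.values():
            taskInputs -= outputs.get(idx, set())
        ready = [key for key, value in inputs.items() if not value]
        queue += ready
        for key in ready:
            del inputs[key]
        queue.sort()
    if inputs:
        raise PipelineDataCycleError("toy graph has data cycles")
    return [tasks[idx] for idx in result]

# For example, "calib" consumes the "bias" made by "isr", so it is moved
# after "isr" while the relative order of independent tasks is kept:
#     _toyOrder([("calib", {"bias"}, {"flat"}), ("isr", {"raw"}, {"bias"})])
#     -> [("isr", ...), ("calib", ...)]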