Coverage for python/lsst/pipe/base/pipeTools.py: 16%

65 statements  

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

22"""Module defining few methods to manipulate or query pipelines. 

23""" 

24 

25# No one should do import * from this module 

26__all__ = ["isPipelineOrdered", "orderPipeline"] 

# -------------------------------
# Imports of standard modules --
# -------------------------------
import itertools

# -----------------------------
# Imports for other modules --
# -----------------------------
from .connections import iterConnections

# ----------------------------------
# Local non-exported definitions --
# ----------------------------------

# ------------------------
# Exported definitions --
# ------------------------


class MissingTaskFactoryError(Exception):
    """Exception raised when a client fails to provide a TaskFactory
    instance.
    """

    pass


class DuplicateOutputError(Exception):
    """Exception raised when a Pipeline has more than one task for the same
    output.
    """

    pass


class PipelineDataCycleError(Exception):
    """Exception raised when a Pipeline has a data dependency cycle."""

    pass


def isPipelineOrdered(pipeline, taskFactory=None):
    """Check whether tasks in a pipeline are correctly ordered.

    A pipeline is correctly ordered if, for every DatasetType produced by a
    task in the pipeline, all tasks consuming that DatasetType appear after
    the producer.

    Parameters
    ----------
    pipeline : `pipe.base.Pipeline`
        Pipeline description.
    taskFactory : `pipe.base.TaskFactory`, optional
        Instance of an object which knows how to import task classes. It is
        only used if pipeline task definitions do not define task classes.

    Returns
    -------
    is_ordered : `bool`
        `True` for a correctly ordered pipeline, `False` otherwise.

    Raises
    ------
    ImportError
        Raised when a task class cannot be imported.
    DuplicateOutputError
        Raised when there is more than one producer for a dataset type.
    MissingTaskFactoryError
        Raised when a TaskFactory is needed but not provided.
    """
    # Build a map from DatasetType name to the producer's index in the
    # pipeline.
    producerIndex = {}
    for idx, taskDef in enumerate(pipeline):
        for attr in iterConnections(taskDef.connections, "outputs"):
            if attr.name in producerIndex:
                raise DuplicateOutputError(
                    "DatasetType `{}' appears more than once as output".format(attr.name)
                )
            producerIndex[attr.name] = idx

    # Check all inputs that are also some task's outputs.
    for idx, taskDef in enumerate(pipeline):
        # Get the task's input DatasetTypes; this can only be done via the
        # class attributes.
        inputs = {name: getattr(taskDef.connections, name) for name in taskDef.connections.inputs}
        for dsTypeDescr in inputs.values():
            # All pre-existing datasets have an effective index of -1.
            prodIdx = producerIndex.get(dsTypeDescr.name, -1)
            if prodIdx >= idx:
                # Not good -- the producer is downstream.
                return False

    return True
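
# A hypothetical illustration (not from the original source): if taskA
# produces "calexp" and taskB consumes it, then
# isPipelineOrdered([taskA, taskB]) returns True, while
# isPipelineOrdered([taskB, taskA]) returns False, because "calexp" would be
# consumed before it is produced.
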

def orderPipeline(pipeline):
    """Re-order tasks in a pipeline to satisfy data dependencies.

    When possible, the new ordering keeps the original relative order of the
    tasks.

    Parameters
    ----------
    pipeline : `list` of `pipe.base.TaskDef`
        Pipeline description.

    Returns
    -------
    ordered : `list` of `pipe.base.TaskDef`
        Correctly ordered pipeline.

    Raises
    ------
    DuplicateOutputError
        Raised when there is more than one producer for a dataset type.
    PipelineDataCycleError
        Raised when the pipeline has data dependency cycles.
    MissingTaskFactoryError
        Raised when a TaskFactory is needed but not provided.
    """

    # This is a modified version of Kahn's algorithm that preserves order.
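    # (Background, added for clarity: Kahn's algorithm repeatedly removes a
    # node with no remaining unsatisfied dependencies from the graph and
    # appends it to the output; if nodes remain when no such node exists,
    # the graph has a cycle.  The modification below keeps the ready queue
    # sorted by original task index, so ties preserve the input order.)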

    # Build a mapping of the tasks to their inputs and outputs.
    inputs = {}  # maps task index to its input DatasetType names
    outputs = {}  # maps task index to its output DatasetType names
    allInputs = set()  # all inputs of all tasks
    allOutputs = set()  # all outputs of all tasks
    for idx, taskDef in enumerate(pipeline):
        # Task outputs.
        dsMap = {name: getattr(taskDef.connections, name) for name in taskDef.connections.outputs}
        for dsTypeDescr in dsMap.values():
            if dsTypeDescr.name in allOutputs:
                raise DuplicateOutputError(
                    "DatasetType `{}' appears more than once as output".format(dsTypeDescr.name)
                )
        outputs[idx] = set(dsTypeDescr.name for dsTypeDescr in dsMap.values())
        allOutputs.update(outputs[idx])

        # Task inputs, including prerequisite inputs.
        connectionInputs = itertools.chain(taskDef.connections.inputs, taskDef.connections.prerequisiteInputs)
        inputNames = [getattr(taskDef.connections, name).name for name in connectionInputs]
        inputs[idx] = set(inputNames)
        allInputs.update(inputs[idx])

    # For simplicity add a pseudo-node which is a producer for all
    # pre-existing inputs; its index is -1.
    preExisting = allInputs - allOutputs
    outputs[-1] = preExisting
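    # (Added note: any dataset that no task produces is assumed to exist
    # before the pipeline runs; hanging all of those datasets off a single
    # -1 node lets the loop below treat pre-existing data exactly like task
    # outputs, and the pseudo-node itself is dropped from the result by the
    # `idx >= 0` check.)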

    # Set of nodes with no incoming edges, initially just the pseudo-node.
    queue = [-1]
    result = []
    while queue:
        # Move the node to the final list, dropping the -1 pseudo-node.
        idx = queue.pop(0)
        if idx >= 0:
            result.append(idx)

        # Remove this task's outputs from the other tasks' inputs.
        thisTaskOutputs = outputs.get(idx, set())
        for taskInputs in inputs.values():
            taskInputs -= thisTaskOutputs

        # Find all nodes with no incoming edges and move them to the queue.
        topNodes = [key for key, value in inputs.items() if not value]
        queue += topNodes
        for key in topNodes:
            del inputs[key]

        # Keep the queue ordered.
        queue.sort()

    # If anything is left in `inputs`, the pipeline has data cycles.
    if inputs:
        # Format the cycles in a usable way.
        loops = []
        for idx, inputNames in inputs.items():
            taskName = pipeline[idx].label
            outputNames = outputs[idx]
            edge = " {} -> {} -> {}".format(inputNames, taskName, outputNames)
            loops.append(edge)
        raise PipelineDataCycleError("Pipeline has data cycles:\n" + "\n".join(loops))

    return [pipeline[idx] for idx in result]
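

# ---------------------------------------------------------------------------
# A minimal usage sketch, not part of the original module.  The stub factory
# below is hypothetical: it fakes only the attributes orderPipeline()
# actually touches (connections.inputs/outputs/prerequisiteInputs, the
# per-connection `name`, and the TaskDef `label`).  Real `pipe.base.TaskDef`
# objects are far richer, and isPipelineOrdered() additionally relies on
# `iterConnections`, which this stub does not try to satisfy.
if __name__ == "__main__":
    from types import SimpleNamespace

    def _stubTaskDef(label, inputNames, outputNames):
        """Build a stand-in TaskDef with just enough shape for orderPipeline."""
        conn = SimpleNamespace(
            inputs=["in{}".format(i) for i in range(len(inputNames))],
            outputs=["out{}".format(i) for i in range(len(outputNames))],
            prerequisiteInputs=[],
        )
        for attr, dsName in zip(conn.inputs, inputNames):
            setattr(conn, attr, SimpleNamespace(name=dsName))
        for attr, dsName in zip(conn.outputs, outputNames):
            setattr(conn, attr, SimpleNamespace(name=dsName))
        return SimpleNamespace(label=label, connections=conn)

    # Task "b" consumes "ds1", which task "a" produces, so listing "b" first
    # is mis-ordered; orderPipeline() moves "a" in front of it.  The dataset
    # "raw" has no producer, so it is treated as pre-existing input.
    pipeline = [
        _stubTaskDef("b", ["ds1"], ["ds2"]),
        _stubTaskDef("a", ["raw"], ["ds1"]),
    ]
    ordered = orderPipeline(pipeline)
    print("ordered labels:", [taskDef.label for taskDef in ordered])  # ['a', 'b']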