# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Module defining a few methods to manipulate or query pipelines."""

from __future__ import annotations

# No one should do import * from this module
__all__ = ["isPipelineOrdered", "orderPipeline"]

# -------------------------------
# Imports of standard modules --
# -------------------------------
import itertools
from typing import TYPE_CHECKING, Iterable, List, Optional, Union

# -----------------------------
# Imports for other modules --
# -----------------------------
from .connections import iterConnections

if TYPE_CHECKING:
    from .pipeline import Pipeline, TaskDef
    from .taskFactory import TaskFactory

# ----------------------------------
# Local non-exported definitions --
# ----------------------------------

# ------------------------
# Exported definitions --
# ------------------------


class MissingTaskFactoryError(Exception):
    """Exception raised when a client fails to provide a TaskFactory instance."""

    pass


class DuplicateOutputError(Exception):
    """Exception raised when a Pipeline has more than one task producing the
    same output.
    """

    pass


class PipelineDataCycleError(Exception):
    """Exception raised when a Pipeline has a data dependency cycle."""

    pass


def isPipelineOrdered(
    pipeline: Union[Pipeline, Iterable[TaskDef]], taskFactory: Optional[TaskFactory] = None
) -> bool:
    """Check whether tasks in a pipeline are correctly ordered.

    A pipeline is correctly ordered if, for every DatasetType produced by a
    task in the pipeline, all tasks consuming that DatasetType appear after
    the producer.

    Parameters
    ----------
    pipeline : `pipe.base.Pipeline` or iterable of `pipe.base.TaskDef`
        Pipeline description.
    taskFactory : `pipe.base.TaskFactory`, optional
        Instance of an object which knows how to import task classes. It is
        only used if pipeline task definitions do not define task classes.

    Returns
    -------
    ordered : `bool`
        `True` for a correctly ordered pipeline, `False` otherwise.

    Raises
    ------
    ImportError
        Raised when a task class cannot be imported.
    DuplicateOutputError
        Raised when there is more than one producer for a dataset type.
    MissingTaskFactoryError
        Raised when a TaskFactory is needed but not provided.
    """
    # Build a map from DatasetType name to the index of its producer task in
    # the pipeline.
    producerIndex = {}
    for idx, taskDef in enumerate(pipeline):
        for attr in iterConnections(taskDef.connections, "outputs"):
            if attr.name in producerIndex:
                raise DuplicateOutputError(
                    "DatasetType `{}' appears more than once as output".format(attr.name)
                )
            producerIndex[attr.name] = idx

    # Check all inputs that are also some task's outputs.
    for idx, taskDef in enumerate(pipeline):
        # Get the task's input DatasetTypes from its connections class.
        inputs = {name: getattr(taskDef.connections, name) for name in taskDef.connections.inputs}
        for dsTypeDescr in inputs.values():
            # All pre-existing datasets have effective index -1.
            prodIdx = producerIndex.get(dsTypeDescr.name, -1)
            if prodIdx >= idx:
                # Not good: the producer is downstream of the consumer.
                return False

    return True


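# The ordering criterion above can be illustrated without real Pipeline or
# TaskDef objects. The helper below is a minimal sketch under assumed
# inputs (its name and data model are illustrative, not part of the
# pipe_base API): each toy task is a pair of input and output DatasetType
# name tuples, checked with the same producer-index test used in
# isPipelineOrdered.
def _toyIsOrdered(tasks: List[tuple]) -> bool:
    """Toy ordering check for (input names, output names) task pairs."""
    producerIndex = {}
    for idx, (_, outs) in enumerate(tasks):
        for name in outs:
            producerIndex[name] = idx
    for idx, (ins, _) in enumerate(tasks):
        # Names nobody produces default to index -1 (pre-existing data).
        if any(producerIndex.get(name, -1) >= idx for name in ins):
            return False
    return True


# For example, ``_toyIsOrdered([(("raw",), ("calexp",)), (("calexp",), ("coadd",))])``
# returns True, while listing the coaddition task first would return False.
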

def orderPipeline(pipeline: List[TaskDef]) -> List[TaskDef]:
    """Re-order tasks in a pipeline to satisfy data dependencies.

    When possible, the new ordering preserves the original relative order of
    the tasks.

    Parameters
    ----------
    pipeline : `list` of `pipe.base.TaskDef`
        Pipeline description.

    Returns
    -------
    ordered : `list` of `pipe.base.TaskDef`
        Correctly ordered pipeline.

    Raises
    ------
    DuplicateOutputError
        Raised when there is more than one producer for a dataset type.
    PipelineDataCycleError
        Raised when the pipeline has data dependency cycles.
    MissingTaskFactoryError
        Raised when a TaskFactory is needed but not provided.
    """

    # This is a modified version of Kahn's algorithm that preserves the
    # original relative order of the tasks where possible.

    # Build a mapping of the tasks to their inputs and outputs.
    inputs = {}  # maps task index to its input DatasetType names
    outputs = {}  # maps task index to its output DatasetType names
    allInputs = set()  # all inputs of all tasks
    allOutputs = set()  # all outputs of all tasks
    for idx, taskDef in enumerate(pipeline):
        # Task outputs.
        dsMap = {name: getattr(taskDef.connections, name) for name in taskDef.connections.outputs}
        for dsTypeDescr in dsMap.values():
            if dsTypeDescr.name in allOutputs:
                raise DuplicateOutputError(
                    "DatasetType `{}' appears more than once as output".format(dsTypeDescr.name)
                )
        outputs[idx] = set(dsTypeDescr.name for dsTypeDescr in dsMap.values())
        allOutputs.update(outputs[idx])

        # Task inputs, including prerequisite inputs.
        connectionInputs = itertools.chain(taskDef.connections.inputs, taskDef.connections.prerequisiteInputs)
        inputs[idx] = set(getattr(taskDef.connections, name).name for name in connectionInputs)
        allInputs.update(inputs[idx])

    # For simplicity, add a pseudo-node with index -1 which is the producer
    # of all pre-existing inputs.
    preExisting = allInputs - allOutputs
    outputs[-1] = preExisting

    # Set of nodes with no incoming edges, initially just the pseudo-node.
    queue = [-1]
    result = []
    while queue:
        # Move the node to the final list, dropping the pseudo-node -1.
        idx = queue.pop(0)
        if idx >= 0:
            result.append(idx)

        # Remove this task's outputs from the other tasks' inputs.
        thisTaskOutputs = outputs.get(idx, set())
        for taskInputs in inputs.values():
            taskInputs -= thisTaskOutputs

        # Find all nodes with no remaining incoming edges and move them to
        # the queue.
        topNodes = [key for key, value in inputs.items() if not value]
        queue += topNodes
        for key in topNodes:
            del inputs[key]

        # Keep the queue sorted to preserve the original task order.
        queue.sort()

    # Anything left in `inputs` means there are dependency cycles.
    if inputs:
        # Format the cycles in a readable way.
        loops = []
        for idx, inputNames in inputs.items():
            taskName = pipeline[idx].label
            outputNames = outputs[idx]
            edge = " {} -> {} -> {}".format(inputNames, taskName, outputNames)
            loops.append(edge)
        raise PipelineDataCycleError("Pipeline has data cycles:\n" + "\n".join(loops))

    return [pipeline[idx] for idx in result]
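

# The order-preserving variant of Kahn's algorithm used above can be shown
# on a bare dependency graph. The helper below is a minimal standalone
# sketch under assumed inputs (its name and data model are illustrative,
# not part of the pipe_base API): nodes are task indices and the per-node
# input/output name sets stand in for the task connections. Keeping the
# ready queue sorted is what preserves the original relative order.
def _toyKahnOrder(inputs: dict, outputs: dict) -> List[int]:
    """Topologically sort node indices given input/output name sets."""
    # Copy so the caller's sets are not mutated.
    inputs = {idx: set(names) for idx, names in inputs.items()}
    allInputs = set().union(*inputs.values()) if inputs else set()
    allOutputs = set().union(*outputs.values()) if outputs else set()
    outputs = dict(outputs)
    # Pseudo-node -1 produces everything that no real node produces.
    outputs[-1] = allInputs - allOutputs
    queue = [-1]
    result = []
    while queue:
        idx = queue.pop(0)
        if idx >= 0:
            result.append(idx)
        # Satisfy every input that this node produces.
        produced = outputs.get(idx, set())
        for names in inputs.values():
            names -= produced
        ready = [key for key, names in inputs.items() if not names]
        queue += ready
        for key in ready:
            del inputs[key]
        queue.sort()
    if inputs:
        raise PipelineDataCycleError("toy graph has a data cycle")
    return result


# For example, a consumer listed before its producer gets re-ordered:
# _toyKahnOrder({0: {"b"}, 1: {"a"}}, {0: {"c"}, 1: {"b"}}) == [1, 0]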