Coverage for python/lsst/pipe/base/pipeTools.py: 14%

69 statements

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Module defining a few methods to manipulate or query pipelines."""

from __future__ import annotations

# No one should do import * from this module
__all__ = ["isPipelineOrdered", "orderPipeline"]

# -------------------------------
# Imports of standard modules --
# -------------------------------
import itertools
from collections.abc import Iterable
from typing import TYPE_CHECKING

# -----------------------------
# Imports for other modules --
# -----------------------------
from .connections import iterConnections

if TYPE_CHECKING:
    from .pipeline import Pipeline, TaskDef
    from .taskFactory import TaskFactory

# ----------------------------------
# Local non-exported definitions --
# ----------------------------------

# ------------------------
# Exported definitions --
# ------------------------


class MissingTaskFactoryError(Exception):
    """Exception raised when a client fails to provide a TaskFactory
    instance.
    """

    pass


class DuplicateOutputError(Exception):
    """Exception raised when a Pipeline has more than one task producing
    the same output.
    """

    pass


class PipelineDataCycleError(Exception):
    """Exception raised when a Pipeline has a data dependency cycle."""

    pass


def isPipelineOrdered(pipeline: Pipeline | Iterable[TaskDef], taskFactory: TaskFactory | None = None) -> bool:
    """Check whether tasks in a pipeline are correctly ordered.

    A pipeline is correctly ordered if, for every DatasetType produced by a
    task in the pipeline, all of its consumer tasks appear after the
    producer.

    Parameters
    ----------
    pipeline : `pipe.base.Pipeline`
        Pipeline description.
    taskFactory : `pipe.base.TaskFactory`, optional
        Instance of an object which knows how to import task classes. It is
        only used if pipeline task definitions do not define task classes.

    Returns
    -------
    ordered : `bool`
        `True` for a correctly ordered pipeline, `False` otherwise.

    Raises
    ------
    ImportError
        Raised when a task class cannot be imported.
    DuplicateOutputError
        Raised when there is more than one producer for a dataset type.
    MissingTaskFactoryError
        Raised when a TaskFactory is needed but not provided.
    """
    # Build a map from DatasetType name to its producer's index in the
    # pipeline.
    producerIndex = {}
    for idx, taskDef in enumerate(pipeline):
        for attr in iterConnections(taskDef.connections, "outputs"):
            if attr.name in producerIndex:
                raise DuplicateOutputError(
                    f"DatasetType `{attr.name}' appears more than once as an output"
                )
            producerIndex[attr.name] = idx

    # Check all inputs that are also some task's outputs.
    for idx, taskDef in enumerate(pipeline):
        # Get task input DatasetTypes; this can only be done via the
        # connections class.
        inputs = {name: getattr(taskDef.connections, name) for name in taskDef.connections.inputs}
        for dsTypeDescr in inputs.values():
            # All pre-existing datasets have effective index -1.
            prodIdx = producerIndex.get(dsTypeDescr.name, -1)
            if prodIdx >= idx:
                # Not good: the producer is downstream of this consumer.
                return False

    return True


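# A minimal usage sketch (the task variables below are hypothetical; it
# assumes each TaskDef's connections declare the dataset types described
# in the comments):
#
#     tasks = [isrTask, calibrateTask]  # "isr" produces what "calibrate" reads
#     isPipelineOrdered(tasks)          # True: producers precede consumers
#     isPipelineOrdered(tasks[::-1])    # False: a producer is downstream
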

def orderPipeline(pipeline: list[TaskDef]) -> list[TaskDef]:
    """Re-order tasks in a pipeline to satisfy data dependencies.

    When possible, the new ordering keeps the original relative order of the
    tasks.

    Parameters
    ----------
    pipeline : `list` of `pipe.base.TaskDef`
        Pipeline description.

    Returns
    -------
    ordered : `list` of `pipe.base.TaskDef`
        Correctly ordered pipeline.

    Raises
    ------
    DuplicateOutputError
        Raised when there is more than one producer for a dataset type.
    PipelineDataCycleError
        Raised when the pipeline has dependency cycles.
    MissingTaskFactoryError
        Raised when a `TaskFactory` is needed but not provided.
    """
    # This is a modified version of Kahn's algorithm that preserves order.

    # Build a mapping of the tasks to their inputs and outputs.
    inputs = {}  # maps task index to its input DatasetType names
    outputs = {}  # maps task index to its output DatasetType names
    allInputs = set()  # all inputs of all tasks
    allOutputs = set()  # all outputs of all tasks
    dsTypeTaskLabels: dict[str, str] = {}  # maps DatasetType name to the label of its producer task
    for idx, taskDef in enumerate(pipeline):
        # task outputs
        dsMap = {name: getattr(taskDef.connections, name) for name in taskDef.connections.outputs}
        for dsTypeDescr in dsMap.values():
            if dsTypeDescr.name in allOutputs:
                raise DuplicateOutputError(
                    f"DatasetType `{dsTypeDescr.name}' in task `{taskDef.label}' already appears as an "
                    f"output in task `{dsTypeTaskLabels[dsTypeDescr.name]}'."
                )
            dsTypeTaskLabels[dsTypeDescr.name] = taskDef.label
        outputs[idx] = set(dsTypeDescr.name for dsTypeDescr in dsMap.values())
        allOutputs.update(outputs[idx])

        # task inputs
        connectionInputs = itertools.chain(taskDef.connections.inputs, taskDef.connections.prerequisiteInputs)
        inputs[idx] = set(getattr(taskDef.connections, name).name for name in connectionInputs)
        allInputs.update(inputs[idx])

    # For simplicity, add a pseudo-node which is the producer of all
    # pre-existing inputs; its index is -1.
    preExisting = allInputs - allOutputs
    outputs[-1] = preExisting

    # Queue of nodes with no incoming edges, initially just the pseudo-node.
    queue = [-1]
    result = []
    while queue:
        # Move the node to the final list, dropping the -1 pseudo-node.
        idx = queue.pop(0)
        if idx >= 0:
            result.append(idx)

        # Remove this task's outputs from the other tasks' inputs.
        thisTaskOutputs = outputs.get(idx, set())
        for taskInputs in inputs.values():
            taskInputs -= thisTaskOutputs

        # Find all nodes with no incoming edges and move them to the queue.
        topNodes = [key for key, value in inputs.items() if not value]
        queue += topNodes
        for key in topNodes:
            del inputs[key]

        # Keep the queue sorted to preserve the original relative order.
        queue.sort()

    # Anything left in `inputs` means the graph has data dependency cycles.
    if inputs:
        # Format the cycles in a readable way.
        loops = []
        for idx, inputNames in inputs.items():
            taskName = pipeline[idx].label
            outputNames = outputs[idx]
            edge = f" {inputNames} -> {taskName} -> {outputNames}"
            loops.append(edge)
        raise PipelineDataCycleError("Pipeline has data cycles:\n" + "\n".join(loops))

    return [pipeline[idx] for idx in result]
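

# A minimal usage sketch (hypothetical task variables, as above): an
# out-of-order pipeline comes back in a valid order, while tasks that
# consume each other's outputs leave non-empty entries in `inputs` after
# the Kahn loop drains and so raise PipelineDataCycleError:
#
#     orderPipeline([calibrateTask, isrTask])  # -> [isrTask, calibrateTask]
#     orderPipeline([taskX, taskY])            # X <-> Y cycle: raises
#                                              # PipelineDataCycleError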