lsst.pipe.base  18.0.0-2-g0ee56d7+8
pipeTools.py
Go to the documentation of this file.
1 # This file is part of pipe_base.
2 #
3 # Developed for the LSST Data Management System.
4 # This product includes software developed by the LSST Project
5 # (http://www.lsst.org).
6 # See the COPYRIGHT file at the top-level directory of this distribution
7 # for details of code ownership.
8 #
9 # This program is free software: you can redistribute it and/or modify
10 # it under the terms of the GNU General Public License as published by
11 # the Free Software Foundation, either version 3 of the License, or
12 # (at your option) any later version.
13 #
14 # This program is distributed in the hope that it will be useful,
15 # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 # GNU General Public License for more details.
18 #
19 # You should have received a copy of the GNU General Public License
20 # along with this program. If not, see <http://www.gnu.org/licenses/>.
21 
22 """Module defining few methods to manipulate or query pipelines.
23 """
24 
25 # No one should do import * from this module
26 __all__ = ["isPipelineOrdered", "orderPipeline"]
27 
28 # -------------------------------
29 # Imports of standard modules --
30 # -------------------------------
31 
32 # -----------------------------
33 # Imports for other modules --
34 # -----------------------------
35 from .pipeline import Pipeline
36 
37 # ----------------------------------
38 # Local non-exported definitions --
39 # ----------------------------------
40 
41 
def _loadTaskClass(taskDef, taskFactory):
    """Return the task class for a task definition, loading it on demand.

    Parameters
    ----------
    taskDef
        Task definition. Its ``taskClass`` attribute is returned unchanged
        when it is truthy; otherwise its ``taskName`` is passed to the
        factory.
    taskFactory
        Object providing ``loadTaskClass(taskName)``. It is only consulted
        when ``taskDef.taskClass`` is not set, and may be `None` otherwise.

    Returns
    -------
    The task class, either the one already stored on ``taskDef`` or the one
    returned by ``taskFactory.loadTaskClass``. ``taskDef`` is not modified.

    Raises
    ------
    `ImportError` is raised when task class cannot be imported.
    `MissingTaskFactoryError` is raised when TaskFactory is needed but not
    provided.
    """
    # Fast path: the definition already carries its class.
    knownClass = taskDef.taskClass
    if knownClass:
        return knownClass

    # The class has to be loaded, which is the factory's job.
    if taskFactory:
        return taskFactory.loadTaskClass(taskDef.taskName)

    raise MissingTaskFactoryError("Task class is not defined but task "
                                  "factory instance is not provided")
57 
58 # ------------------------
59 # Exported definitions --
60 # ------------------------
61 
62 
class MissingTaskFactoryError(Exception):
    """Raised when a TaskFactory instance is required but the client did not
    supply one.
    """
67 
68 
class DuplicateOutputError(Exception):
    """Raised when more than one task in a Pipeline produces the same
    output.
    """
74 
75 
class PipelineDataCycleError(Exception):
    """Raised when the tasks of a Pipeline form a data dependency cycle.
    """
80 
81 
def isPipelineOrdered(pipeline, taskFactory=None):
    """Check that every task appears after the producers of its inputs.

    A pipeline is considered correctly ordered when, for each DatasetType
    produced by one of its tasks, all tasks consuming that DatasetType are
    positioned after the producing task.

    Parameters
    ----------
    pipeline : `pipe.base.Pipeline`
        Pipeline description. It is iterated twice, and the ``taskClass``
        attribute of each task definition is (re)assigned with the loaded
        task class.
    taskFactory: `pipe.base.TaskFactory`, optional
        Instance of an object which knows how to import task classes. It is
        only used if pipeline task definitions do not define task classes.

    Returns
    -------
    True for correctly ordered pipeline, False otherwise.

    Raises
    ------
    `ImportError` is raised when task class cannot be imported.
    `DuplicateOutputError` is raised when there is more than one producer for a
    dataset type.
    `MissingTaskFactoryError` is raised when TaskFactory is needed but not
    provided.
    """
    # First pass: remember which position produces each DatasetType name.
    producerPosition = {}
    for position, taskDef in enumerate(pipeline):

        # Both passes call class methods, so load the class once and keep it
        # on the task definition.
        taskDef.taskClass = _loadTaskClass(taskDef, taskFactory)

        # Output DatasetTypes are only available through a class method.
        produced = taskDef.taskClass.getOutputDatasetTypes(taskDef.config)
        for descr in produced.values():
            if descr.name in producerPosition:
                raise DuplicateOutputError("DatasetType `{}' appears more than "
                                           "once as output".format(descr.name))
            producerPosition[descr.name] = position

    # Second pass: an input is out of order when its producer sits at the
    # same or a later position. Inputs nobody produces are pre-existing and
    # get the effective position -1, which never triggers the check.
    for position, taskDef in enumerate(pipeline):

        # Input DatasetTypes are only available through a class method.
        consumed = taskDef.taskClass.getInputDatasetTypes(taskDef.config)
        if any(producerPosition.get(descr.name, -1) >= position
               for descr in consumed.values()):
            # producer is downstream of (or is) the consumer
            return False

    return True
136 
137 
def orderPipeline(pipeline, taskFactory=None):
    """Return a copy of the pipeline with tasks sorted by data dependencies.

    The new ordering keeps the original relative order of the tasks whenever
    the dependencies allow it.

    Parameters
    ----------
    pipeline : `pipe.base.Pipeline`
        Pipeline description. It is iterated once and also indexed by task
        position.
    taskFactory: `pipe.base.TaskFactory`, optional
        Instance of an object which knows how to import task classes. It is
        only used if pipeline task definitions do not define task classes.

    Returns
    -------
    Correctly ordered pipeline (`pipe.base.Pipeline` instance).

    Raises
    ------
    `ImportError` is raised when task class cannot be imported.
    `DuplicateOutputError` is raised when there is more than one producer for a
    dataset type.
    `PipelineDataCycleError` is also raised when pipeline has dependency cycles.
    `MissingTaskFactoryError` is raised when TaskFactory is needed but not
    provided.
    """

    # Order-preserving variant of Kahn's topological sort: among the tasks
    # whose inputs are all satisfied, the lowest original index goes first.

    pendingInputs = {}    # task index -> input DatasetType names still unmet
    producedBy = {}       # task index -> output DatasetType names
    everyInput = set()    # union of the inputs of all tasks
    everyOutput = set()   # union of the outputs of all tasks
    for position, taskDef in enumerate(pipeline):

        # the class is needed to query DatasetTypes, load it if necessary
        taskClass = _loadTaskClass(taskDef, taskFactory)

        # outputs: reject names already produced by an earlier task
        outputNames = [descr.name for descr in
                       taskClass.getOutputDatasetTypes(taskDef.config).values()]
        for name in outputNames:
            if name in everyOutput:
                raise DuplicateOutputError("DatasetType `{}' appears more than "
                                           "once as output".format(name))
        producedBy[position] = set(outputNames)
        everyOutput |= producedBy[position]

        # inputs
        pendingInputs[position] = {
            descr.name for descr in
            taskClass.getInputDatasetTypes(taskDef.config).values()
        }
        everyInput |= pendingInputs[position]

    # Pseudo-node with index -1 acts as the producer of every input that no
    # task in the pipeline makes (i.e. pre-existing datasets).
    producedBy[-1] = everyInput - everyOutput

    # Nodes with no unmet inputs, kept sorted; starts with the pseudo-node.
    ready = [-1]
    ordered = []
    while ready:

        # take the lowest index; the pseudo-node is not part of the result
        current = ready.pop(0)
        if current >= 0:
            ordered.append(current)

        # everything this node produces is now available to the others
        nowAvailable = producedBy.get(current, set())
        for unmet in pendingInputs.values():
            unmet.difference_update(nowAvailable)

        # tasks with nothing left to wait for become ready
        unblocked = [index for index, unmet in pendingInputs.items() if not unmet]
        for index in unblocked:
            del pendingInputs[index]
        ready.extend(unblocked)

        # keep the ready list sorted so original order is preferred
        ready.sort()

    # any task still waiting at this point is part of a cycle
    if pendingInputs:
        # one line per stuck task: unmet inputs -> task label -> outputs
        loops = [
            " {} -> {} -> {}".format(unmet, pipeline[index].label, producedBy[index])
            for index, unmet in pendingInputs.items()
        ]
        raise PipelineDataCycleError("Pipeline has data cycles:\n" + "\n".join(loops))

    return Pipeline(pipeline[index] for index in ordered)
def isPipelineOrdered(pipeline, taskFactory=None)
Definition: pipeTools.py:82
def orderPipeline(pipeline, taskFactory=None)
Definition: pipeTools.py:138