pipeTools.py
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Module defining a few methods to manipulate or query pipelines.
"""

# No one should do import * from this module
__all__ = ["isPipelineOrdered", "orderPipeline"]

# -------------------------------
# Imports of standard modules --
# -------------------------------
import itertools

# -----------------------------
# Imports for other modules --
# -----------------------------
from .pipeline import Pipeline
from .connections import iterConnections

# ----------------------------------
# Local non-exported definitions --
# ----------------------------------

def _loadTaskClass(taskDef, taskFactory):
    """Import task class if necessary.

    Raises
    ------
    ImportError
        Raised when the task class cannot be imported.
    MissingTaskFactoryError
        Raised when a TaskFactory is needed but not provided.
    """
    taskClass = taskDef.taskClass
    if not taskClass:
        if not taskFactory:
            raise MissingTaskFactoryError("Task class is not defined and task "
                                          "factory instance is not provided")
        taskClass = taskFactory.loadTaskClass(taskDef.taskName)
    return taskClass

# ------------------------
# Exported definitions --
# ------------------------


class MissingTaskFactoryError(Exception):
    """Exception raised when a client fails to provide a TaskFactory instance.
    """
    pass


class DuplicateOutputError(Exception):
    """Exception raised when a Pipeline has more than one task producing the
    same output.
    """
    pass


class PipelineDataCycleError(Exception):
    """Exception raised when a Pipeline has a data dependency cycle.
    """
    pass

def isPipelineOrdered(pipeline, taskFactory=None):
    """Check whether tasks in a pipeline are correctly ordered.

    A pipeline is correctly ordered if, for every DatasetType produced by a
    task in the pipeline, all of its consumer tasks appear after the
    producer.

    Parameters
    ----------
    pipeline : `pipe.base.Pipeline`
        Pipeline description.
    taskFactory : `pipe.base.TaskFactory`, optional
        Instance of an object which knows how to import task classes. It is
        only used if pipeline task definitions do not define task classes.

    Returns
    -------
    `bool`
        `True` for a correctly ordered pipeline, `False` otherwise.

    Raises
    ------
    ImportError
        Raised when a task class cannot be imported.
    DuplicateOutputError
        Raised when there is more than one producer for a dataset type.
    MissingTaskFactoryError
        Raised when a TaskFactory is needed but not provided.
    """
    # Build a map of DatasetType name to producer's index in the pipeline
    producerIndex = {}
    for idx, taskDef in enumerate(pipeline):

        for attr in iterConnections(taskDef.connections, 'outputs'):
            if attr.name in producerIndex:
                raise DuplicateOutputError("DatasetType `{}' appears more than "
                                           "once as output".format(attr.name))
            producerIndex[attr.name] = idx

    # check all inputs that are also some task's outputs
    for idx, taskDef in enumerate(pipeline):

        # get task input DatasetTypes, this can only be done via class method
        inputs = {name: getattr(taskDef.connections, name) for name in taskDef.connections.inputs}
        for dsTypeDescr in inputs.values():
            # all pre-existing datasets have effective index -1
            prodIdx = producerIndex.get(dsTypeDescr.name, -1)
            if prodIdx >= idx:
                # not good, the producer is downstream of this consumer
                return False

    return True

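# Illustrative usage sketch (not part of the original module): check a
# caller-supplied `Pipeline` (a sequence of `TaskDef` objects) and reorder it
# only when needed.  The variable `pipeline` is assumed to have been built
# elsewhere, e.g. by the pipeline framework or a test fixture.
#
#     from lsst.pipe.base.pipeTools import isPipelineOrdered, orderPipeline
#
#     if not isPipelineOrdered(pipeline):
#         pipeline = orderPipeline(pipeline)
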
def orderPipeline(pipeline, taskFactory=None):
    """Re-order tasks in a pipeline to satisfy data dependencies.

    When possible the new ordering preserves the original relative order of
    the tasks.

    Parameters
    ----------
    pipeline : `pipe.base.Pipeline`
        Pipeline description.
    taskFactory : `pipe.base.TaskFactory`, optional
        Instance of an object which knows how to import task classes. It is
        only used if pipeline task definitions do not define task classes.

    Returns
    -------
    `pipe.base.Pipeline`
        Correctly ordered pipeline.

    Raises
    ------
    ImportError
        Raised when a task class cannot be imported.
    DuplicateOutputError
        Raised when there is more than one producer for a dataset type.
    PipelineDataCycleError
        Raised when the pipeline has dependency cycles.
    MissingTaskFactoryError
        Raised when a TaskFactory is needed but not provided.
    """

    # This is a modified version of Kahn's algorithm that preserves order.
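    # In outline (an annotation added for clarity, mirroring the code below):
    #  1. collect, for every task, the sets of input and output DatasetType
    #     names, checking that no DatasetType has more than one producer;
    #  2. add a pseudo-node with index -1 that "produces" every dataset that
    #     no task in the pipeline makes, so pre-existing inputs do not block
    #     their consumers;
    #  3. repeatedly take the lowest-index node whose remaining input set is
    #     empty, append it to the result, and remove its outputs from every
    #     other task's input set (taking the lowest index first is what
    #     preserves the original relative order);
    #  4. any task whose input set never empties is part of a dependency
    #     cycle and is reported via PipelineDataCycleError.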

    # build mapping of the tasks to their inputs and outputs
    inputs = {}  # maps task index to its input DatasetType names
    outputs = {}  # maps task index to its output DatasetType names
    allInputs = set()  # all inputs of all tasks
    allOutputs = set()  # all outputs of all tasks
    for idx, taskDef in enumerate(pipeline):
        # task outputs
        dsMap = {name: getattr(taskDef.connections, name) for name in taskDef.connections.outputs}
        for dsTypeDescr in dsMap.values():
            if dsTypeDescr.name in allOutputs:
                raise DuplicateOutputError("DatasetType `{}' appears more than "
                                           "once as output".format(dsTypeDescr.name))
        outputs[idx] = set(dsTypeDescr.name for dsTypeDescr in dsMap.values())
        allOutputs.update(outputs[idx])

        # task inputs
        connectionInputs = itertools.chain(taskDef.connections.inputs, taskDef.connections.prerequisiteInputs)
        dsMap = [getattr(taskDef.connections, name).name for name in connectionInputs]
        inputs[idx] = set(dsMap)
        allInputs.update(inputs[idx])

    # for simplicity add pseudo-node which is a producer for all pre-existing
    # inputs, its index is -1
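    # (For example, a dataset that is only read by this pipeline and never
    # written by any of its tasks gets the pseudo-node as its producer, so
    # its consumers start with no unsatisfied dependencies.)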
    preExisting = allInputs - allOutputs
    outputs[-1] = preExisting

    # Set of nodes with no incoming edges, initially set to pseudo-node
    queue = [-1]
    result = []
    while queue:

        # move to final list, drop -1
        idx = queue.pop(0)
        if idx >= 0:
            result.append(idx)

        # remove this task's outputs from other tasks' inputs
        thisTaskOutputs = outputs.get(idx, set())
        for taskInputs in inputs.values():
            taskInputs -= thisTaskOutputs

        # find all nodes with no incoming edges and move them to the queue
        topNodes = [key for key, value in inputs.items() if not value]
        queue += topNodes
        for key in topNodes:
            del inputs[key]

        # keep queue ordered
        queue.sort()

    # if anything is left it means there is a dependency cycle
    if inputs:
        # format it in a usable way
        loops = []
        for idx, inputNames in inputs.items():
            taskName = pipeline[idx].label
            outputNames = outputs[idx]
            edge = " {} -> {} -> {}".format(inputNames, taskName, outputNames)
            loops.append(edge)
        raise PipelineDataCycleError("Pipeline has data cycles:\n" + "\n".join(loops))

    return Pipeline(pipeline[idx] for idx in result)
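
# Illustrative usage sketch (not part of the original module): order a
# caller-supplied pipeline and surface a dependency cycle as a readable
# error.  As above, `pipeline` is assumed to have been built elsewhere.
#
#     from lsst.pipe.base.pipeTools import PipelineDataCycleError, orderPipeline
#
#     try:
#         pipeline = orderPipeline(pipeline)
#     except PipelineDataCycleError as exc:
#         print("Pipeline cannot be ordered:", exc)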