Coverage for python/lsst/pipe/base/pipeTools.py: 11%

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
22"""Module defining few methods to manipulate or query pipelines.
23"""
# No one should do import * from this module
__all__ = ["isPipelineOrdered", "orderPipeline"]

# -------------------------------
#  Imports of standard modules --
# -------------------------------
import itertools

# -----------------------------
#  Imports for other modules --
# -----------------------------
from .connections import iterConnections

# ----------------------------------
#  Local non-exported definitions --
# ----------------------------------

def _loadTaskClass(taskDef, taskFactory):
    """Import task class if necessary.

    Raises
    ------
    ImportError
        Raised when the task class cannot be imported.
    MissingTaskFactoryError
        Raised when a `TaskFactory` is needed but not provided.
    """
    taskClass = taskDef.taskClass
    if not taskClass:
        if not taskFactory:
            raise MissingTaskFactoryError("Task class is not defined and task "
                                          "factory instance is not provided")
        taskClass = taskFactory.loadTaskClass(taskDef.taskName)
    return taskClass
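
# Illustrative only: how `_loadTaskClass` behaves with duck-typed stand-ins
# for a task definition and factory (hypothetical names, not the real
# pipe.base classes):
#
#     class _Factory:
#         def loadTaskClass(self, taskName):
#             return object  # a real factory would import `taskName`
#
#     class _Def:
#         taskClass = None
#         taskName = "some.module.SomeTask"
#
#     _loadTaskClass(_Def(), _Factory())   # returns `object`
#     _loadTaskClass(_Def(), None)         # raises MissingTaskFactoryError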

# ------------------------
#  Exported definitions --
# ------------------------


class MissingTaskFactoryError(Exception):
    """Exception raised when a client fails to provide a `TaskFactory`
    instance.
    """
    pass


class DuplicateOutputError(Exception):
    """Exception raised when a `Pipeline` has more than one task producing
    the same output.
    """
    pass


class PipelineDataCycleError(Exception):
    """Exception raised when a `Pipeline` has a data dependency cycle.
    """
    pass

def isPipelineOrdered(pipeline, taskFactory=None):
    """Check whether tasks in a pipeline are correctly ordered.

    A pipeline is correctly ordered if, for every DatasetType produced by a
    task in the pipeline, all of its consumer tasks appear after the
    producer.

    Parameters
    ----------
    pipeline : `pipe.base.Pipeline`
        Pipeline description.
    taskFactory : `pipe.base.TaskFactory`, optional
        Instance of an object which knows how to import task classes. It is
        only used if pipeline task definitions do not define task classes.

    Returns
    -------
    `bool`
        `True` for a correctly ordered pipeline, `False` otherwise.

    Raises
    ------
    ImportError
        Raised when a task class cannot be imported.
    DuplicateOutputError
        Raised when there is more than one producer for a dataset type.
    MissingTaskFactoryError
        Raised when a `TaskFactory` is needed but not provided.
    """
    # Build a map of DatasetType name to producer's index in the pipeline.
    producerIndex = {}
    for idx, taskDef in enumerate(pipeline):
        for attr in iterConnections(taskDef.connections, 'outputs'):
            if attr.name in producerIndex:
                raise DuplicateOutputError("DatasetType `{}' appears more than "
                                           "once as output".format(attr.name))
            producerIndex[attr.name] = idx

    # Check all inputs that are also someone's outputs.
    for idx, taskDef in enumerate(pipeline):
        # Get the task's input DatasetTypes.
        inputs = {name: getattr(taskDef.connections, name) for name in taskDef.connections.inputs}
        for dsTypeDescr in inputs.values():
            # All pre-existing datasets have effective index -1.
            prodIdx = producerIndex.get(dsTypeDescr.name, -1)
            if prodIdx >= idx:
                # Not good, the producer is downstream.
                return False

    return True
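
# Illustrative only: a minimal check of `isPipelineOrdered` with duck-typed
# stand-ins for TaskDef and its connections (hypothetical names, not the real
# pipe.base classes). This assumes `iterConnections` simply looks up each
# named connection attribute on the connections object.
#
#     class _Conn:
#         def __init__(self, name):
#             self.name = name
#
#     class _Connections:
#         def __init__(self, inputs, outputs):
#             self.inputs = frozenset(inputs)
#             self.outputs = frozenset(outputs)
#             self.prerequisiteInputs = frozenset()
#             for n in list(inputs) + list(outputs):
#                 setattr(self, n, _Conn(n))
#
#     class _TaskDef:
#         def __init__(self, label, inputs, outputs):
#             self.label = label
#             self.connections = _Connections(inputs, outputs)
#
#     producer = _TaskDef("producer", inputs=[], outputs=["raw"])
#     consumer = _TaskDef("consumer", inputs=["raw"], outputs=[])
#     assert isPipelineOrdered([producer, consumer])
#     assert not isPipelineOrdered([consumer, producer])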

def orderPipeline(pipeline):
    """Re-order tasks in a pipeline to satisfy data dependencies.

    When possible, the new ordering keeps the original relative order of
    the tasks.

    Parameters
    ----------
    pipeline : `list` of `pipe.base.TaskDef`
        Pipeline description.

    Returns
    -------
    `list` of `pipe.base.TaskDef`
        Correctly ordered pipeline.

    Raises
    ------
    DuplicateOutputError
        Raised when there is more than one producer for a dataset type.
    PipelineDataCycleError
        Raised when the pipeline has dependency cycles.
    MissingTaskFactoryError
        Raised when a `TaskFactory` is needed but not provided.
    """
    # This is a modified version of Kahn's algorithm that preserves order.

    # Build mappings of the tasks to their inputs and outputs.
    inputs = {}  # maps task index to its input DatasetType names
    outputs = {}  # maps task index to its output DatasetType names
    allInputs = set()  # all inputs of all tasks
    allOutputs = set()  # all outputs of all tasks
    for idx, taskDef in enumerate(pipeline):
        # Task outputs.
        dsMap = {name: getattr(taskDef.connections, name) for name in taskDef.connections.outputs}
        for dsTypeDescr in dsMap.values():
            if dsTypeDescr.name in allOutputs:
                raise DuplicateOutputError("DatasetType `{}' appears more than "
                                           "once as output".format(dsTypeDescr.name))
        outputs[idx] = set(dsTypeDescr.name for dsTypeDescr in dsMap.values())
        allOutputs.update(outputs[idx])

        # Task inputs, including prerequisites.
        connectionInputs = itertools.chain(taskDef.connections.inputs,
                                           taskDef.connections.prerequisiteInputs)
        inputs[idx] = set(getattr(taskDef.connections, name).name for name in connectionInputs)
        allInputs.update(inputs[idx])

    # For simplicity add a pseudo-node which is a producer for all
    # pre-existing inputs; its index is -1.
    preExisting = allInputs - allOutputs
    outputs[-1] = preExisting

    # Set of nodes with no incoming edges, initially just the pseudo-node.
    queue = [-1]
    result = []
    while queue:

        # Move the node to the final list, dropping the pseudo-node.
        idx = queue.pop(0)
        if idx >= 0:
            result.append(idx)

        # Remove this task's outputs from the other tasks' inputs.
        thisTaskOutputs = outputs.get(idx, set())
        for taskInputs in inputs.values():
            taskInputs -= thisTaskOutputs

        # Find all nodes with no incoming edges and move them to the queue.
        topNodes = [key for key, value in inputs.items() if not value]
        queue += topNodes
        for key in topNodes:
            del inputs[key]

        # Keep the queue sorted to preserve the original task order.
        queue.sort()

    # Anything left in `inputs` means the pipeline has cycles.
    if inputs:
        # Format them in a usable way.
        loops = []
        for idx, inputNames in inputs.items():
            taskName = pipeline[idx].label
            outputNames = outputs[idx]
            edge = " {} -> {} -> {}".format(inputNames, taskName, outputNames)
            loops.append(edge)
        raise PipelineDataCycleError("Pipeline has data cycles:\n" + "\n".join(loops))

    return [pipeline[idx] for idx in result]
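
# Illustrative only: `orderPipeline` with the same hypothetical stubs
# sketched after `isPipelineOrdered` above.
#
#     # A shuffled pipeline is restored to a dependency-respecting order.
#     ordered = orderPipeline([consumer, producer])
#     assert [t.label for t in ordered] == ["producer", "consumer"]
#
#     # A task consuming its own output forms a cycle and is reported
#     # instead of looping forever.
#     loop = _TaskDef("loop", inputs=["x"], outputs=["x"])
#     orderPipeline([loop])   # raises PipelineDataCycleError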