# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
22"""Module defining few methods to manipulate or query pipelines.
23"""

from __future__ import annotations

# No one should do import * from this module
__all__ = ["isPipelineOrdered", "orderPipeline"]

# -------------------------------
# Imports of standard modules --
# -------------------------------
import itertools
from collections.abc import Iterable
from typing import TYPE_CHECKING

# -----------------------------
# Imports for other modules --
# -----------------------------
from .connections import iterConnections

if TYPE_CHECKING:
    from .pipeline import Pipeline, TaskDef
    from .taskFactory import TaskFactory

# ----------------------------------
# Local non-exported definitions --
# ----------------------------------

# ------------------------
# Exported definitions --
# ------------------------


class MissingTaskFactoryError(Exception):
    """Exception raised when a client fails to provide a `TaskFactory`
    instance.
    """

    pass


class DuplicateOutputError(Exception):
    """Exception raised when a Pipeline has more than one task producing the
    same output.
    """

    pass


class PipelineDataCycleError(Exception):
    """Exception raised when a Pipeline has a data dependency cycle."""

    pass


def isPipelineOrdered(pipeline: Pipeline | Iterable[TaskDef], taskFactory: TaskFactory | None = None) -> bool:
    """Check whether tasks in the pipeline are correctly ordered.

    A pipeline is correctly ordered if, for any DatasetType produced by a
    task, all of its consumer tasks appear after the producer.

    Parameters
    ----------
    pipeline : `pipe.base.Pipeline`
        Pipeline description.
    taskFactory : `pipe.base.TaskFactory`, optional
        Instance of an object which knows how to import task classes. It is
        only used if pipeline task definitions do not define task classes.

    Returns
    -------
    is_ordered : `bool`
        True for a correctly ordered pipeline, False otherwise.

    Raises
    ------
    ImportError
        Raised when a task class cannot be imported.
    DuplicateOutputError
        Raised when there is more than one producer for a dataset type.
    MissingTaskFactoryError
        Raised when a `TaskFactory` is needed but not provided.
    """
    # Build a map of DatasetType name to its producer's index in the pipeline.
    producerIndex = {}
    for idx, taskDef in enumerate(pipeline):
        for attr in iterConnections(taskDef.connections, "outputs"):
            if attr.name in producerIndex:
                raise DuplicateOutputError(f"DatasetType `{attr.name}' appears more than once as output")
            producerIndex[attr.name] = idx

    # Check all inputs that are also someone's outputs.
    for idx, taskDef in enumerate(pipeline):
        # Get the task's input DatasetTypes from its connections instance.
        inputs = {name: getattr(taskDef.connections, name) for name in taskDef.connections.inputs}
        for dsTypeDescr in inputs.values():
            # All pre-existing datasets have effective index -1.
            prodIdx = producerIndex.get(dsTypeDescr.name, -1)
            if prodIdx >= idx:
                # Not good: the producer is downstream of the consumer.
                return False

    return True
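

# The sketch below is illustrative only: `_exampleIsPipelineOrdered` and its
# stand-in task definitions are hypothetical, mimicking just enough of the
# `TaskDef.connections` interface used above (`inputs`/`outputs` listing
# connection attribute names, each attribute carrying an object with a
# `name` field) to exercise `isPipelineOrdered`.
def _exampleIsPipelineOrdered() -> None:
    """Demonstrate `isPipelineOrdered` on minimal stand-in task definitions."""
    from types import SimpleNamespace

    def makeTaskDef(label: str, inputNames: list[str], outputNames: list[str]) -> SimpleNamespace:
        # Expose .connections.<kind> as a list of connection attribute names,
        # with each attribute resolving to an object that has a .name, the
        # way `iterConnections` and the dict comprehensions above expect.
        connections = SimpleNamespace(inputs=[], outputs=[], prerequisiteInputs=[])
        for kind, dsNames in (("inputs", inputNames), ("outputs", outputNames)):
            for i, dsName in enumerate(dsNames):
                attr = f"{kind}{i}"
                getattr(connections, kind).append(attr)
                setattr(connections, attr, SimpleNamespace(name=dsName))
        return SimpleNamespace(label=label, connections=connections)

    pipeline = [
        makeTaskDef("isr", inputNames=["raw"], outputNames=["postISRCCD"]),
        makeTaskDef("characterize", inputNames=["postISRCCD"], outputNames=["icSrc"]),
    ]
    # The producer of "postISRCCD" precedes its consumer, so the check passes.
    assert isPipelineOrdered(pipeline)
    # Reversing the list puts the consumer first, so the check fails.
    assert not isPipelineOrdered(list(reversed(pipeline)))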


def orderPipeline(pipeline: list[TaskDef]) -> list[TaskDef]:
    """Re-order tasks in the pipeline to satisfy data dependencies.

    When possible, the new ordering keeps the original relative order of the
    tasks.

    Parameters
    ----------
    pipeline : `list` of `pipe.base.TaskDef`
        Pipeline description.

    Returns
    -------
    ordered : `list` of `pipe.base.TaskDef`
        Correctly ordered pipeline.

    Raises
    ------
    DuplicateOutputError
        Raised when there is more than one producer for a dataset type.
    PipelineDataCycleError
        Raised when the pipeline has dependency cycles.
    MissingTaskFactoryError
        Raised when a `TaskFactory` is needed but not provided.
    """
    # This is a modified version of Kahn's algorithm that preserves order.

    # Build a mapping of the tasks to their inputs and outputs.
    inputs: dict[int, set[str]] = {}  # maps task index to its input DatasetType names
    outputs: dict[int, set[str]] = {}  # maps task index to its output DatasetType names
    allInputs = set()  # all inputs of all tasks
    allOutputs = set()  # all outputs of all tasks
    dsTypeTaskLabels: dict[str, str] = {}  # maps DatasetType name to the label of its producing task
    for idx, taskDef in enumerate(pipeline):
        # Task outputs.
        dsMap = {name: getattr(taskDef.connections, name) for name in taskDef.connections.outputs}
        for dsTypeDescr in dsMap.values():
            if dsTypeDescr.name in allOutputs:
                raise DuplicateOutputError(
                    f"DatasetType `{dsTypeDescr.name}' in task `{taskDef.label}' already appears as an "
                    f"output in task `{dsTypeTaskLabels[dsTypeDescr.name]}'."
                )
            dsTypeTaskLabels[dsTypeDescr.name] = taskDef.label
        outputs[idx] = set(dsTypeDescr.name for dsTypeDescr in dsMap.values())
        allOutputs.update(outputs[idx])

        # Task inputs, including prerequisites.
        connectionInputs = itertools.chain(taskDef.connections.inputs, taskDef.connections.prerequisiteInputs)
        inputs[idx] = set(getattr(taskDef.connections, name).name for name in connectionInputs)
        allInputs.update(inputs[idx])

    # For simplicity, add a pseudo-node which is a producer for all
    # pre-existing inputs; its index is -1.
    preExisting = allInputs - allOutputs
    outputs[-1] = preExisting

    # Set of nodes with no incoming edges; initially just the pseudo-node.
    queue = [-1]
    result = []
    while queue:
        # Move the node to the final list, dropping the pseudo-node (-1).
        idx = queue.pop(0)
        if idx >= 0:
            result.append(idx)

        # Remove this task's outputs from the other tasks' inputs.
        thisTaskOutputs = outputs.get(idx, set())
        for taskInputs in inputs.values():
            taskInputs -= thisTaskOutputs

        # Find all nodes with no incoming edges and move them to the queue.
        topNodes = [key for key, value in inputs.items() if not value]
        queue += topNodes
        for key in topNodes:
            del inputs[key]

        # Keep the queue sorted so the original relative order is preserved.
        queue.sort()

    # If anything is left in `inputs`, the pipeline has dependency cycles.
    if inputs:
        # Format the cycles in a readable way.
        loops = []
        for idx, inputNames in inputs.items():
            taskName = pipeline[idx].label
            outputNames = outputs[idx]
            edge = f"   {inputNames} -> {taskName} -> {outputNames}"
            loops.append(edge)
        raise PipelineDataCycleError("Pipeline has data cycles:\n" + "\n".join(loops))

    return [pipeline[idx] for idx in result]
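

# For reference, the loop above is a stable variant of Kahn's topological
# sort. The self-contained sketch below (a hypothetical `_stableKahn` helper,
# not used elsewhere in this module) shows the same idea on plain
# dictionaries: the sorted queue is what keeps the original relative order
# when several tasks become ready at the same time.
def _stableKahn(inputs: dict[int, set[str]], outputs: dict[int, set[str]]) -> list[int]:
    """Order node indices so that every producer precedes its consumers.

    The arguments mirror the per-task maps built in `orderPipeline`: each
    node index maps to the set of dataset names it consumes or produces.
    """
    inputs = {idx: set(names) for idx, names in inputs.items()}  # work on copies
    outputs = dict(outputs)
    # Pseudo-node -1 produces everything that no node in the graph produces.
    allOutputs = set().union(*outputs.values()) if outputs else set()
    allInputs = set().union(*inputs.values()) if inputs else set()
    outputs[-1] = allInputs - allOutputs
    queue = [-1]
    result: list[int] = []
    while queue:
        idx = queue.pop(0)
        if idx >= 0:
            result.append(idx)
        produced = outputs.get(idx, set())
        for consumed in inputs.values():
            consumed -= produced  # this node's outputs are now available
        ready = [key for key, value in inputs.items() if not value]
        queue += ready
        for key in ready:
            del inputs[key]
        queue.sort()  # lowest original index first: stable ordering
    if inputs:
        raise PipelineDataCycleError(f"dependency cycle among nodes {sorted(inputs)}")
    return result


# Example: node 1 consumes what node 0 produces, while "raw" is pre-existing,
# so the result is [0, 1] regardless of dictionary ordering:
#
#     _stableKahn({0: {"raw"}, 1: {"calexp"}}, {0: {"calexp"}, 1: {"coadd"}})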