Coverage for python/lsst/pipe/base/pipeline.py : 21%

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

"""Module defining Pipeline class and related methods.
"""

__all__ = ["Pipeline", "TaskDef", "TaskDatasetTypes", "PipelineDatasetTypes"]

# -------------------------------
# Imports of standard modules --
# -------------------------------
from dataclasses import dataclass
from types import MappingProxyType
from typing import FrozenSet, Mapping, Union, Generator, TYPE_CHECKING

import copy

# -----------------------------
# Imports for other modules --
from lsst.daf.butler import DatasetType, Registry, SkyPixDimension
from lsst.utils import doImport
from .configOverrides import ConfigOverrides
from .connections import iterConnections
from .pipelineTask import PipelineTask

from . import pipelineIR
from . import pipeTools
if TYPE_CHECKING:  # Imports needed only for type annotations; may be circular.
    from lsst.obs.base.instrument import Instrument
# ----------------------------------
# Local non-exported definitions --
# ----------------------------------

# ------------------------
# Exported definitions --
# ------------------------


class TaskDef:
    """TaskDef is a collection of information about a task needed by a
    Pipeline.

    The information includes the task name, its configuration object, and
    optionally the task class.  This class is just a collection of attributes
    and it exposes all of them so that attributes can be modified in place
    (e.g. if the configuration needs extra overrides).

    Attributes
    ----------
    taskName : `str`
        `PipelineTask` class name; it is currently not specified whether this
        is a fully-qualified or partial name (e.g. ``module.TaskClass``).
        The framework should be prepared to handle all cases.
    config : `lsst.pex.config.Config`
        Instance of the configuration class corresponding to this task class,
        usually with all overrides applied.
    taskClass : `type` or ``None``
        `PipelineTask` class object; can be ``None``.  If ``None`` the
        framework will have to locate and load the class.
    label : `str`, optional
        Task label, usually a short string unique in a pipeline.
    """
    def __init__(self, taskName, config, taskClass=None, label=""):
        self.taskName = taskName
        self.config = config
        self.taskClass = taskClass
        self.label = label
        self.connections = config.connections.ConnectionsClass(config=config)

    @property
    def metadataDatasetName(self):
        """Name of a dataset type for metadata of this task, `None` if
        metadata is not to be saved (`str`).
        """
        if self.config.saveMetadata:
            return self.label + "_metadata"
        else:
            return None

    def __str__(self):
        rep = "TaskDef(" + self.taskName
        if self.label:
            rep += ", label=" + self.label
        rep += ")"
        return rep


class Pipeline:
    """A `Pipeline` is a representation of a series of tasks to run, and the
    configuration for those tasks.

    Parameters
    ----------
    description : `str`
        A description of what this pipeline does.
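
    Examples
    --------
    A minimal usage sketch; the task and field names below are purely
    illustrative and must be replaced with a real `PipelineTask` subclass
    (and one of its config fields) importable on the current Python path::

        pipeline = Pipeline("a demonstration pipeline")
        pipeline.addTask("mypackage.mymodule.MyTask", "myTask")
        pipeline.addConfigOverride("myTask", "someField", 42)
        for taskDef in pipeline.toExpandedPipeline():
            print(taskDef)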
116 """
    def __init__(self, description: str):
        pipeline_dict = {"description": description, "tasks": {}}
        self._pipelineIR = pipelineIR.PipelineIR(pipeline_dict)

    @classmethod
    def fromFile(cls, filename: str) -> Pipeline:
        """Load a pipeline defined in a pipeline YAML file.

        Parameters
        ----------
        filename : `str`
            A path that points to a pipeline defined in YAML format.

        Returns
        -------
        pipeline : `Pipeline`
            The loaded pipeline.
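
        Examples
        --------
        A sketch of typical use; the file name is illustrative::

            pipeline = Pipeline.fromFile("myPipeline.yaml")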
133 """
134 pipeline = cls.fromIR(pipelineIR.PipelineIR.from_file(filename))
135 return pipeline

    @classmethod
    def fromString(cls, pipeline_string: str) -> Pipeline:
        """Create a pipeline from a string formatted as a pipeline document.

        Parameters
        ----------
        pipeline_string : `str`
            A string that is formatted like a pipeline document.

        Returns
        -------
        pipeline : `Pipeline`
            The loaded pipeline.
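
        Examples
        --------
        A minimal sketch, assuming the YAML layout accepted by
        `pipelineIR.PipelineIR`; the task class name is illustrative::

            document = '''
            description: An example pipeline
            tasks:
              myTask: mypackage.mymodule.MyTask
            '''
            pipeline = Pipeline.fromString(document)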
149 """
150 pipeline = cls.fromIR(pipelineIR.PipelineIR.from_string(pipeline_string))
151 return pipeline

    @classmethod
    def fromIR(cls, deserialized_pipeline: pipelineIR.PipelineIR) -> Pipeline:
        """Create a pipeline from an already created `PipelineIR` object.

        Parameters
        ----------
        deserialized_pipeline : `PipelineIR`
            An already created pipeline intermediate representation object.

        Returns
        -------
        pipeline : `Pipeline`
        """
        pipeline = cls.__new__(cls)
        pipeline._pipelineIR = deserialized_pipeline
        return pipeline

    @classmethod
    def fromPipeline(cls, pipeline: Pipeline) -> Pipeline:
        """Create a new pipeline by copying an already existing `Pipeline`.

        Parameters
        ----------
        pipeline : `Pipeline`
            The existing pipeline to copy.

        Returns
        -------
        pipeline : `Pipeline`
        """
        return cls.fromIR(copy.deepcopy(pipeline._pipelineIR))

    def __str__(self) -> str:
        return str(self._pipelineIR)

    def addInstrument(self, instrument: Union[Instrument, str]):
        """Add an instrument to the pipeline, or replace an instrument that is
        already defined.

        Parameters
        ----------
        instrument : `~lsst.obs.base.instrument.Instrument` or `str`
            Either an `Instrument` subclass or a string corresponding to a
            fully qualified `Instrument` name.
        """
        if isinstance(instrument, str):
            pass
        else:
            # TODO: assume that this is a subclass of Instrument, no type
            # checking
            instrument = f"{instrument.__module__}.{instrument.__qualname__}"
        self._pipelineIR.instrument = instrument

    def addTask(self, task: Union[PipelineTask, str], label: str):
        """Add a new task to the pipeline, or replace a task that is already
        associated with the supplied label.

        Parameters
        ----------
        task : `PipelineTask` or `str`
            Either a derived class object of a `PipelineTask` or a string
            corresponding to a fully qualified `PipelineTask` name.
        label : `str`
            A label that is used to identify the `PipelineTask` being added.
        """
        if isinstance(task, str):
            taskName = task
        elif issubclass(task, PipelineTask):
            taskName = f"{task.__module__}.{task.__qualname__}"
        else:
            raise ValueError("task must be either a child class of PipelineTask or a string containing"
                             " a fully qualified name of one")
        if not label:
            # In some cases (e.g. a command-line-generated pipeline) tasks can
            # be defined without a label, which is not acceptable; use the
            # task's _DefaultName in that case.
            if isinstance(task, str):
                task = doImport(task)
            label = task._DefaultName
        self._pipelineIR.tasks[label] = pipelineIR.TaskIR(label, taskName)

    def removeTask(self, label: str):
        """Remove a task from the pipeline.

        Parameters
        ----------
        label : `str`
            The label used to identify the task that is to be removed.

        Raises
        ------
        KeyError
            If no task with that label exists in the pipeline.
        """
        self._pipelineIR.tasks.pop(label)

    def addConfigOverride(self, label: str, key: str, value: object):
        """Apply a single config override.

        Parameters
        ----------
        label : `str`
            Label of the task.
        key : `str`
            Fully-qualified field name.
        value : object
            Value to be given to a field.
        """
        self._addConfigImpl(label, pipelineIR.ConfigIR(rest={key: value}))

    def addConfigFile(self, label: str, filename: str):
        """Add overrides from a specified file.

        Parameters
        ----------
        label : `str`
            The label used to identify the task whose config is to be
            modified.
        filename : `str`
            Path to the override file.
        """
        self._addConfigImpl(label, pipelineIR.ConfigIR(file=[filename]))

    def addConfigPython(self, label: str, pythonString: str):
        """Add overrides by running a snippet of Python code against a config.

        Parameters
        ----------
        label : `str`
            The label used to identify the task whose config is to be
            modified.
        pythonString : `str`
            A string of valid Python code to be executed.  It is run with
            ``config`` as the only accessible local value.
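
        Examples
        --------
        A sketch of typical use; the label and config field are illustrative::

            pipeline.addConfigPython("myTask", "config.someField = 42")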
288 """
289 self._addConfigImpl(label, pipelineIR.ConfigIR(python=pythonString))
291 def _addConfigImpl(self, label: str, newConfig: pipelineIR.ConfigIR):
292 if label not in self._pipelineIR.tasks:
293 raise LookupError(f"There are no tasks labeled '{label}' in the pipeline")
294 self._pipelineIR.tasks[label].add_or_update_config(newConfig)

    def toFile(self, filename: str):
        """Write this pipeline to a file in the pipeline document format.

        Parameters
        ----------
        filename : `str`
            Path of the file to write.
        """
        self._pipelineIR.to_file(filename)

    def toExpandedPipeline(self) -> Generator[TaskDef, None, None]:
        """Return a generator of TaskDefs which can be used to create quantum
        graphs.

        Returns
        -------
        generator : generator of `TaskDef`
            The generator returned will be the sorted iterator of tasks which
            are to be used in constructing a quantum graph.

        Raises
        ------
        NotImplementedError
            If a dataId is supplied in a config block.  This is in place for
            future use.
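
        Examples
        --------
        A sketch of typical use for an already constructed pipeline, printing
        each task in its sorted execution order::

            for taskDef in pipeline.toExpandedPipeline():
                print(taskDef.label, taskDef.taskName)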
314 """
315 taskDefs = []
316 for label, taskIR in self._pipelineIR.tasks.items():
317 taskClass = doImport(taskIR.klass)
318 taskName = taskClass.__qualname__
319 config = taskClass.ConfigClass()
320 overrides = ConfigOverrides()
321 if self._pipelineIR.instrument is not None:
322 overrides.addInstrumentOverride(self._pipelineIR.instrument, taskClass._DefaultName)
323 if taskIR.config is not None:
324 for configIR in taskIR.config:
325 if configIR.dataId is not None:
326 raise NotImplementedError("Specializing a config on a partial data id is not yet "
327 "supported in Pipeline definition")
328 # only apply override if it applies to everything
329 if configIR.dataId is None:
330 if configIR.file:
331 for configFile in configIR.file:
332 overrides.addFileOverride(configFile)
333 if configIR.python is not None:
334 overrides.addPythonOverride(configIR.python)
335 for key, value in configIR.rest.items():
336 overrides.addValueOverride(key, value)
337 overrides.applyTo(config)
338 # This may need to be revisited
339 config.validate()
340 taskDefs.append(TaskDef(taskName=taskName, config=config, taskClass=taskClass, label=label))
342 # lets evaluate the contracts
343 if self._pipelineIR.contracts is not None:
344 label_to_config = {x.label: x.config for x in taskDefs}
345 for contract in self._pipelineIR.contracts:
346 # execute this in its own line so it can raise a good error message if there was problems
347 # with the eval
348 success = eval(contract.contract, None, label_to_config)
349 if not success:
350 extra_info = f": {contract.msg}" if contract.msg is not None else ""
351 raise pipelineIR.ContractError(f"Contract(s) '{contract.contract}' were not "
352 f"satisfied{extra_info}")
354 yield from pipeTools.orderPipeline(taskDefs)

    def __len__(self):
        return len(self._pipelineIR.tasks)

    def __eq__(self, other: "Pipeline"):
        if not isinstance(other, Pipeline):
            return False
        return self._pipelineIR == other._pipelineIR


@dataclass(frozen=True)
class TaskDatasetTypes:
    """An immutable struct that extracts and classifies the dataset types used
    by a `PipelineTask`.
    """

    initInputs: FrozenSet[DatasetType]
    """Dataset types that are needed as inputs in order to construct this Task.

    Task-level `initInputs` may be classified as either
    `~PipelineDatasetTypes.initInputs` or
    `~PipelineDatasetTypes.initIntermediates` at the Pipeline level.
    """

    initOutputs: FrozenSet[DatasetType]
    """Dataset types that may be written after constructing this Task.

    Task-level `initOutputs` may be classified as either
    `~PipelineDatasetTypes.initOutputs` or
    `~PipelineDatasetTypes.initIntermediates` at the Pipeline level.
    """

    inputs: FrozenSet[DatasetType]
    """Dataset types that are regular inputs to this Task.

    If an input dataset needed for a Quantum cannot be found in the input
    collection(s) or produced by another Task in the Pipeline, that Quantum
    (and all dependent Quanta) will not be produced.

    Task-level `inputs` may be classified as either
    `~PipelineDatasetTypes.inputs` or `~PipelineDatasetTypes.intermediates`
    at the Pipeline level.
    """

    prerequisites: FrozenSet[DatasetType]
    """Dataset types that are prerequisite inputs to this Task.

    Prerequisite inputs must exist in the input collection(s) before the
    pipeline is run, but do not constrain the graph - if a prerequisite is
    missing for a Quantum, `PrerequisiteMissingError` is raised.

    Prerequisite inputs are not resolved until the second stage of
    QuantumGraph generation.
    """

    outputs: FrozenSet[DatasetType]
    """Dataset types that are produced by this Task.

    Task-level `outputs` may be classified as either
    `~PipelineDatasetTypes.outputs` or `~PipelineDatasetTypes.intermediates`
    at the Pipeline level.
    """

    @classmethod
    def fromTaskDef(cls, taskDef: TaskDef, *, registry: Registry) -> TaskDatasetTypes:
        """Extract and classify the dataset types from a single `PipelineTask`.

        Parameters
        ----------
        taskDef : `TaskDef`
            An instance of a `TaskDef` class for a particular `PipelineTask`.
        registry : `Registry`
            Registry used to construct normalized `DatasetType` objects and
            retrieve those that are incomplete.

        Returns
        -------
        types : `TaskDatasetTypes`
            The dataset types used by this task.
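
        Examples
        --------
        A sketch of typical use, assuming a ``taskDef`` from an expanded
        pipeline and a data butler providing the registry::

            types = TaskDatasetTypes.fromTaskDef(taskDef, registry=butler.registry)
            print(sorted(d.name for d in types.inputs))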
434 """
435 def makeDatasetTypesSet(connectionType):
436 """Constructs a set of true `DatasetType` objects
438 Parameters
439 ----------
440 connectionType : `str`
441 Name of the connection type to produce a set for, corresponds
442 to an attribute of type `list` on the connection class instance
444 Returns
445 -------
446 datasetTypes : `frozenset`
447 A set of all datasetTypes which correspond to the input
448 connection type specified in the connection class of this
449 `PipelineTask`
451 Notes
452 -----
453 This function is a closure over the variables ``registry`` and
454 ``taskDef``.
455 """
456 datasetTypes = []
457 for c in iterConnections(taskDef.connections, connectionType):
458 dimensions = set(getattr(c, 'dimensions', set()))
459 if "skypix" in dimensions:
460 try:
461 datasetType = registry.getDatasetType(c.name)
462 except LookupError as err:
463 raise LookupError(
464 f"DatasetType '{c.name}' referenced by "
465 f"{type(taskDef.connections).__name__} uses 'skypix' as a dimension "
466 f"placeholder, but does not already exist in the registry. "
467 f"Note that reference catalog names are now used as the dataset "
468 f"type name instead of 'ref_cat'."
469 ) from err
470 rest1 = set(registry.dimensions.extract(dimensions - set(["skypix"])).names)
471 rest2 = set(dim.name for dim in datasetType.dimensions
472 if not isinstance(dim, SkyPixDimension))
473 if rest1 != rest2:
474 raise ValueError(f"Non-skypix dimensions for dataset type {c.name} declared in "
475 f"connections ({rest1}) are inconsistent with those in "
476 f"registry's version of this dataset ({rest2}).")
477 else:
478 datasetType = DatasetType(c.name, registry.dimensions.extract(dimensions),
479 c.storageClass)
480 try:
481 registryDatasetType = registry.getDatasetType(c.name)
482 except KeyError:
483 registryDatasetType = datasetType
484 if datasetType != registryDatasetType:
485 raise ValueError(f"Supplied dataset type ({datasetType}) inconsistent with "
486 f"registry definition ({registryDatasetType})")
487 datasetTypes.append(datasetType)
488 return frozenset(datasetTypes)

        # Optionally add an output dataset for metadata.
        outputs = makeDatasetTypesSet("outputs")
        if taskDef.metadataDatasetName is not None:
            # Metadata is supposed to be of the PropertyList type; its
            # dimensions correspond to a task quantum.
            dimensions = registry.dimensions.extract(taskDef.connections.dimensions)
            outputs |= {DatasetType(taskDef.metadataDatasetName, dimensions, "PropertyList")}

        return cls(
            initInputs=makeDatasetTypesSet("initInputs"),
            initOutputs=makeDatasetTypesSet("initOutputs"),
            inputs=makeDatasetTypesSet("inputs"),
            prerequisites=makeDatasetTypesSet("prerequisiteInputs"),
            outputs=outputs,
        )


@dataclass(frozen=True)
class PipelineDatasetTypes:
    """An immutable struct that classifies the dataset types used in a
    `Pipeline`.
    """

    initInputs: FrozenSet[DatasetType]
    """Dataset types that are needed as inputs in order to construct the Tasks
    in this Pipeline.

    This does not include dataset types that are produced when constructing
    other Tasks in the Pipeline (these are classified as `initIntermediates`).
    """

    initOutputs: FrozenSet[DatasetType]
    """Dataset types that may be written after constructing the Tasks in this
    Pipeline.

    This does not include dataset types that are also used as inputs when
    constructing other Tasks in the Pipeline (these are classified as
    `initIntermediates`).
    """

    initIntermediates: FrozenSet[DatasetType]
    """Dataset types that are both used when constructing one or more Tasks
    in the Pipeline and produced as a side-effect of constructing another
    Task in the Pipeline.
    """

    inputs: FrozenSet[DatasetType]
    """Dataset types that are regular inputs for the full pipeline.

    If an input dataset needed for a Quantum cannot be found in the input
    collection(s), that Quantum (and all dependent Quanta) will not be
    produced.
    """

    prerequisites: FrozenSet[DatasetType]
    """Dataset types that are prerequisite inputs for the full Pipeline.

    Prerequisite inputs must exist in the input collection(s) before the
    pipeline is run, but do not constrain the graph - if a prerequisite is
    missing for a Quantum, `PrerequisiteMissingError` is raised.

    Prerequisite inputs are not resolved until the second stage of
    QuantumGraph generation.
    """

    intermediates: FrozenSet[DatasetType]
    """Dataset types that are output by one Task in the Pipeline and consumed
    as inputs by one or more other Tasks in the Pipeline.
    """

    outputs: FrozenSet[DatasetType]
    """Dataset types that are output by a Task in the Pipeline and not consumed
    by any other Task in the Pipeline.
    """

    byTask: Mapping[str, TaskDatasetTypes]
    """Per-Task dataset types, keyed by label in the `Pipeline`.

    This is guaranteed to be zip-iterable with the `Pipeline` itself (assuming
    neither has been modified since the dataset types were extracted, of
    course).
    """

    @classmethod
    def fromPipeline(cls, pipeline, *, registry: Registry) -> PipelineDatasetTypes:
        """Extract and classify the dataset types from all tasks in a
        `Pipeline`.

        Parameters
        ----------
        pipeline : `Pipeline`
            An ordered collection of tasks that can be run together.
        registry : `Registry`
            Registry used to construct normalized `DatasetType` objects and
            retrieve those that are incomplete.

        Returns
        -------
        types : `PipelineDatasetTypes`
            The dataset types used by this `Pipeline`.

        Raises
        ------
        ValueError
            Raised if Tasks are inconsistent about which datasets are marked
            prerequisite.  This indicates that the Tasks cannot be run as part
            of the same `Pipeline`.
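
        Examples
        --------
        A sketch of typical use, assuming a data butler providing the
        registry::

            datasetTypes = PipelineDatasetTypes.fromPipeline(pipeline, registry=butler.registry)
            print(sorted(d.name for d in datasetTypes.outputs))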
597 """
598 allInputs = set()
599 allOutputs = set()
600 allInitInputs = set()
601 allInitOutputs = set()
602 prerequisites = set()
603 byTask = dict()
604 if isinstance(pipeline, Pipeline):
605 pipeline = pipeline.toExpandedPipeline()
606 for taskDef in pipeline:
607 thisTask = TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
608 allInitInputs.update(thisTask.initInputs)
609 allInitOutputs.update(thisTask.initOutputs)
610 allInputs.update(thisTask.inputs)
611 prerequisites.update(thisTask.prerequisites)
612 allOutputs.update(thisTask.outputs)
613 byTask[taskDef.label] = thisTask
614 if not prerequisites.isdisjoint(allInputs):
615 raise ValueError("{} marked as both prerequisites and regular inputs".format(
616 {dt.name for dt in allInputs & prerequisites}
617 ))
618 if not prerequisites.isdisjoint(allOutputs):
619 raise ValueError("{} marked as both prerequisites and outputs".format(
620 {dt.name for dt in allOutputs & prerequisites}
621 ))
622 # Make sure that components which are marked as inputs get treated as
623 # intermediates if there is an output which produces the composite
624 # containing the component
625 intermediateComponents = set()
626 intermediateComposites = set()
627 outputNameMapping = {dsType.name: dsType for dsType in allOutputs}
628 for dsType in allInputs:
629 # get the name of a possible component
630 name, component = dsType.nameAndComponent()
631 # if there is a component name, that means this is a component
632 # DatasetType, if there is an output which produces the parent of
633 # this component, treat this input as an intermediate
634 if component is not None:
635 if name in outputNameMapping and outputNameMapping[name].dimensions == dsType.dimensions:
636 composite = DatasetType(name, dsType.dimensions, outputNameMapping[name].storageClass,
637 universe=registry.dimensions)
638 intermediateComponents.add(dsType)
639 intermediateComposites.add(composite)
640 return cls(
641 initInputs=frozenset(allInitInputs - allInitOutputs),
642 initIntermediates=frozenset(allInitInputs & allInitOutputs),
643 initOutputs=frozenset(allInitOutputs - allInitInputs),
644 inputs=frozenset(allInputs - allOutputs - intermediateComponents),
645 intermediates=frozenset(allInputs & allOutputs | intermediateComponents),
646 outputs=frozenset(allOutputs - allInputs - intermediateComposites),
647 prerequisites=frozenset(prerequisites),
648 byTask=MappingProxyType(byTask), # MappingProxyType -> frozen view of dict for immutability
649 )