# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

"""Module defining Pipeline class and related methods.
"""

__all__ = ["Pipeline", "TaskDef", "TaskDatasetTypes", "PipelineDatasetTypes"]

# -------------------------------
# Imports of standard modules --
# -------------------------------
from dataclasses import dataclass
from types import MappingProxyType
from typing import Mapping, Union, Generator, TYPE_CHECKING

import copy

# -----------------------------
# Imports for other modules --
from lsst.daf.butler import DatasetType, NamedValueSet, Registry, SkyPixDimension
from lsst.utils import doImport
from .configOverrides import ConfigOverrides
from .connections import iterConnections
from .pipelineTask import PipelineTask

from . import pipelineIR
from . import pipeTools

if TYPE_CHECKING:  # Imports needed only for type annotations; may be circular.
    from lsst.obs.base.instrument import Instrument

# ----------------------------------
# Local non-exported definitions --
# ----------------------------------

# ------------------------
# Exported definitions --
# ------------------------


class TaskDef:
    """TaskDef is a collection of information about a task needed by a
    Pipeline.

    The information includes task name, configuration object and optional
    task class. This class is just a collection of attributes and it exposes
    all of them so that attributes could potentially be modified in place
    (e.g. if configuration needs extra overrides).

    Attributes
    ----------
    taskName : `str`
        `PipelineTask` class name, currently it is not specified whether this
        is a fully-qualified name or partial name (e.g. ``module.TaskClass``).
        Framework should be prepared to handle all cases.
    config : `lsst.pex.config.Config`
        Instance of the configuration class corresponding to this task class,
        usually with all overrides applied.
    taskClass : `type` or ``None``
        `PipelineTask` class object, can be ``None``. If ``None`` then
        framework will have to locate and load class.
    label : `str`, optional
        Task label, usually a short string unique in a pipeline.
    """
    def __init__(self, taskName, config, taskClass=None, label=""):
        self.taskName = taskName
        self.config = config
        self.taskClass = taskClass
        self.label = label
        self.connections = config.connections.ConnectionsClass(config=config)

    @property
    def configDatasetName(self):
        """Name of a dataset type for configuration of this task (`str`)
        """
        return self.label + "_config"

    @property
    def metadataDatasetName(self):
        """Name of a dataset type for metadata of this task, `None` if
        metadata is not to be saved (`str`)
        """
        if self.config.saveMetadata:
            return self.label + "_metadata"
        else:
            return None

    def __str__(self):
        rep = "TaskDef(" + self.taskName
        if self.label:
            rep += ", label=" + self.label
        rep += ")"
        return rep
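

# Example (illustrative sketch, not from the original module): the per-task
# dataset type names above are derived from the task label, so a task labeled
# "isr" would store its configuration as "isr_config" and, when
# ``saveMetadata`` is enabled, its metadata as "isr_metadata".  The label and
# task name used here are hypothetical.
#
#     taskDef = TaskDef("lsst.ip.isr.IsrTask", config, taskClass=IsrTask, label="isr")
#     taskDef.configDatasetName    # -> "isr_config"
#     taskDef.metadataDatasetName  # -> "isr_metadata", or None if saveMetadata is False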


class Pipeline:
    """A `Pipeline` is a representation of a series of tasks to run, and the
    configuration for those tasks.

    Parameters
    ----------
    description : `str`
        A description of what this pipeline does.
    """
    def __init__(self, description: str) -> None:
        pipeline_dict = {"description": description, "tasks": {}}
        self._pipelineIR = pipelineIR.PipelineIR(pipeline_dict)

    @classmethod
    def fromFile(cls, filename: str) -> Pipeline:
        """Load a pipeline defined in a pipeline yaml file.

        Parameters
        ----------
        filename: `str`
            A path that points to a pipeline defined in yaml format.

        Returns
        -------
        pipeline: `Pipeline`
        """
        pipeline = cls.fromIR(pipelineIR.PipelineIR.from_file(filename))
        return pipeline

    @classmethod
    def fromString(cls, pipeline_string: str) -> Pipeline:
        """Create a pipeline from a string formatted as a pipeline document.

        Parameters
        ----------
        pipeline_string : `str`
            A string formatted like a pipeline document.

        Returns
        -------
        pipeline: `Pipeline`
        """
        pipeline = cls.fromIR(pipelineIR.PipelineIR.from_string(pipeline_string))
        return pipeline
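
    # Example (illustrative sketch, not from the original module): a small
    # pipeline document can be parsed directly from a string.  The task class
    # path is hypothetical, and this assumes the short form where a label maps
    # directly to a class name is accepted by the pipeline document format.
    #
    #     pipeline = Pipeline.fromString(
    #         "description: A one-task example\n"
    #         "tasks:\n"
    #         "  example: mypackage.tasks.ExampleTask\n"
    #     )
    #     len(pipeline)  # -> 1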

    @classmethod
    def fromIR(cls, deserialized_pipeline: pipelineIR.PipelineIR) -> Pipeline:
        """Create a pipeline from an already created `PipelineIR` object.

        Parameters
        ----------
        deserialized_pipeline: `PipelineIR`
            An already created pipeline intermediate representation object.

        Returns
        -------
        pipeline: `Pipeline`
        """
        pipeline = cls.__new__(cls)
        pipeline._pipelineIR = deserialized_pipeline
        return pipeline

    @classmethod
    def fromPipeline(cls, pipeline: Pipeline) -> Pipeline:
        """Create a new pipeline by copying an already existing `Pipeline`.

        Parameters
        ----------
        pipeline: `Pipeline`
            An already created `Pipeline` object to copy.

        Returns
        -------
        pipeline: `Pipeline`
        """
        return cls.fromIR(copy.deepcopy(pipeline._pipelineIR))

    def __str__(self) -> str:
        return str(self._pipelineIR)

    def addInstrument(self, instrument: Union[Instrument, str]):
        """Add an instrument to the pipeline, or replace an instrument that is
        already defined.

        Parameters
        ----------
        instrument : `~lsst.obs.base.instrument.Instrument` or `str`
            Either an `Instrument` subclass object or a string corresponding
            to the fully qualified name of an `Instrument` subclass.
        """
        if isinstance(instrument, str):
            pass
        else:
            # TODO: assume that this is a subclass of Instrument, no type checking
            instrument = f"{instrument.__module__}.{instrument.__qualname__}"
        self._pipelineIR.instrument = instrument

    def addTask(self, task: Union[PipelineTask, str], label: str):
        """Add a new task to the pipeline, or replace a task that is already
        associated with the supplied label.

        Parameters
        ----------
        task: `PipelineTask` or `str`
            Either a derived class object of a `PipelineTask` or a string
            corresponding to a fully qualified `PipelineTask` name.
        label: `str`
            A label that is used to identify the `PipelineTask` being added.
        """
        if isinstance(task, str):
            taskName = task
        elif issubclass(task, PipelineTask):
            taskName = f"{task.__module__}.{task.__qualname__}"
        else:
            raise ValueError("task must be either a child class of PipelineTask or a string containing"
                             " a fully qualified name to one")
        if not label:
            # In some cases (e.g. command line-generated pipelines) tasks can
            # be defined without a label, which is not acceptable; use the
            # task's _DefaultName in that case.
            if isinstance(task, str):
                task = doImport(task)
            label = task._DefaultName
        self._pipelineIR.tasks[label] = pipelineIR.TaskIR(label, taskName)
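
    # Example (illustrative sketch, not from the original module): tasks and an
    # instrument may be given either as importable class objects or as fully
    # qualified name strings; all names below are hypothetical.
    #
    #     pipeline = Pipeline("An example pipeline")
    #     pipeline.addInstrument("lsst.obs.example.ExampleInstrument")
    #     pipeline.addTask("mypackage.tasks.ExampleTask", "example")
    #     pipeline.addTask(AnotherExampleTask, "another")  # a PipelineTask subclass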

    def removeTask(self, label: str):
        """Remove a task from the pipeline.

        Parameters
        ----------
        label : `str`
            The label used to identify the task that is to be removed.

        Raises
        ------
        KeyError
            If no task with that label exists in the pipeline.
        """
        self._pipelineIR.tasks.pop(label)

    def addConfigOverride(self, label: str, key: str, value: object):
        """Apply a single config override.

        Parameters
        ----------
        label : `str`
            Label of the task.
        key: `str`
            Fully-qualified field name.
        value : object
            Value to be given to a field.
        """
        self._addConfigImpl(label, pipelineIR.ConfigIR(rest={key: value}))

    def addConfigFile(self, label: str, filename: str):
        """Add overrides from a specified file.

        Parameters
        ----------
        label : `str`
            The label used to identify the task associated with the config to
            modify.
        filename : `str`
            Path to the override file.
        """
        self._addConfigImpl(label, pipelineIR.ConfigIR(file=[filename]))

    def addConfigPython(self, label: str, pythonString: str):
        """Add overrides by running a snippet of python code against a config.

        Parameters
        ----------
        label : `str`
            The label used to identify the task associated with the config to
            modify.
        pythonString: `str`
            A string which is valid python code to be executed. This is done
            with ``config`` as the only locally accessible value.
        """
        self._addConfigImpl(label, pipelineIR.ConfigIR(python=pythonString))
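
    # Example (illustrative sketch, not from the original module): the three
    # override flavours map onto the methods above; the label, field name and
    # file path are hypothetical.
    #
    #     pipeline.addConfigOverride("example", "doWrite", False)
    #     pipeline.addConfigFile("example", "/path/to/exampleOverrides.py")
    #     pipeline.addConfigPython("example", "config.doWrite = False")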

    def _addConfigImpl(self, label: str, newConfig: pipelineIR.ConfigIR):
        if label not in self._pipelineIR.tasks:
            raise LookupError(f"There are no tasks labeled '{label}' in the pipeline")
        self._pipelineIR.tasks[label].add_or_update_config(newConfig)

    def toFile(self, filename: str):
        self._pipelineIR.to_file(filename)

    def toExpandedPipeline(self) -> Generator[TaskDef, None, None]:
        """Return a generator of TaskDefs which can be used to create quantum
        graphs.

        Returns
        -------
        generator : generator of `TaskDef`
            The generator returned will be the sorted iterator of tasks which
            are to be used in constructing a quantum graph.

        Raises
        ------
        NotImplementedError
            If a dataId is supplied in a config block. This is in place for
            future use.
        """
        taskDefs = []
        for label, taskIR in self._pipelineIR.tasks.items():
            taskClass = doImport(taskIR.klass)
            taskName = taskClass.__qualname__
            config = taskClass.ConfigClass()
            overrides = ConfigOverrides()
            if self._pipelineIR.instrument is not None:
                overrides.addInstrumentOverride(self._pipelineIR.instrument, taskClass._DefaultName)
            if taskIR.config is not None:
                for configIR in taskIR.config:
                    if configIR.dataId is not None:
                        raise NotImplementedError("Specializing a config on a partial data id is not yet "
                                                  "supported in Pipeline definition")
                    # only apply override if it applies to everything
                    if configIR.dataId is None:
                        if configIR.file:
                            for configFile in configIR.file:
                                overrides.addFileOverride(configFile)
                        if configIR.python is not None:
                            overrides.addPythonOverride(configIR.python)
                        for key, value in configIR.rest.items():
                            overrides.addValueOverride(key, value)
            overrides.applyTo(config)
            # This may need to be revisited
            config.validate()
            taskDefs.append(TaskDef(taskName=taskName, config=config, taskClass=taskClass, label=label))

        # Now evaluate the contracts
        if self._pipelineIR.contracts is not None:
            label_to_config = {x.label: x.config for x in taskDefs}
            for contract in self._pipelineIR.contracts:
                # Execute this on its own line so it can raise a good error
                # message if there were problems with the eval
                success = eval(contract.contract, None, label_to_config)
                if not success:
                    extra_info = f": {contract.msg}" if contract.msg is not None else ""
                    raise pipelineIR.ContractError(f"Contract(s) '{contract.contract}' were not "
                                                   f"satisfied{extra_info}")

        yield from pipeTools.orderPipeline(taskDefs)
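
    # Example (illustrative sketch, not from the original module): iterating the
    # expanded pipeline yields TaskDef objects, already ordered for quantum
    # graph construction.
    #
    #     for taskDef in pipeline.toExpandedPipeline():
    #         print(taskDef.label, taskDef.taskName)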

    def __len__(self):
        return len(self._pipelineIR.tasks)

    def __eq__(self, other: "Pipeline"):
        if not isinstance(other, Pipeline):
            return False
        return self._pipelineIR == other._pipelineIR


@dataclass(frozen=True)
class TaskDatasetTypes:
    """An immutable struct that extracts and classifies the dataset types used
    by a `PipelineTask`.
    """

    initInputs: NamedValueSet[DatasetType]
    """Dataset types that are needed as inputs in order to construct this Task.

    Task-level `initInputs` may be classified as either
    `~PipelineDatasetTypes.initInputs` or
    `~PipelineDatasetTypes.initIntermediates` at the Pipeline level.
    """

    initOutputs: NamedValueSet[DatasetType]
    """Dataset types that may be written after constructing this Task.

    Task-level `initOutputs` may be classified as either
    `~PipelineDatasetTypes.initOutputs` or
    `~PipelineDatasetTypes.initIntermediates` at the Pipeline level.
    """

    inputs: NamedValueSet[DatasetType]
    """Dataset types that are regular inputs to this Task.

    If an input dataset needed for a Quantum cannot be found in the input
    collection(s) or produced by another Task in the Pipeline, that Quantum
    (and all dependent Quanta) will not be produced.

    Task-level `inputs` may be classified as either
    `~PipelineDatasetTypes.inputs` or `~PipelineDatasetTypes.intermediates`
    at the Pipeline level.
    """

    prerequisites: NamedValueSet[DatasetType]
    """Dataset types that are prerequisite inputs to this Task.

    Prerequisite inputs must exist in the input collection(s) before the
    pipeline is run, but do not constrain the graph - if a prerequisite is
    missing for a Quantum, `PrerequisiteMissingError` is raised.

    Prerequisite inputs are not resolved until the second stage of
    QuantumGraph generation.
    """

    outputs: NamedValueSet[DatasetType]
    """Dataset types that are produced by this Task.

    Task-level `outputs` may be classified as either
    `~PipelineDatasetTypes.outputs` or `~PipelineDatasetTypes.intermediates`
    at the Pipeline level.
    """

    @classmethod
    def fromTaskDef(cls, taskDef: TaskDef, *, registry: Registry) -> TaskDatasetTypes:
        """Extract and classify the dataset types from a single `PipelineTask`.

        Parameters
        ----------
        taskDef: `TaskDef`
            An instance of a `TaskDef` class for a particular `PipelineTask`.
        registry: `Registry`
            Registry used to construct normalized `DatasetType` objects and
            retrieve those that are incomplete.

        Returns
        -------
        types: `TaskDatasetTypes`
            The dataset types used by this task.
        """
        def makeDatasetTypesSet(connectionType, freeze=True):
            """Constructs a set of true `DatasetType` objects

            Parameters
            ----------
            connectionType : `str`
                Name of the connection type to produce a set for, corresponds
                to an attribute of type `list` on the connection class instance
            freeze : `bool`, optional
                If `True`, call `NamedValueSet.freeze` on the object returned.

            Returns
            -------
            datasetTypes : `NamedValueSet`
                A set of all datasetTypes which correspond to the input
                connection type specified in the connection class of this
                `PipelineTask`

            Notes
            -----
            This function is a closure over the variables ``registry`` and
            ``taskDef``.
            """
            datasetTypes = NamedValueSet()
            for c in iterConnections(taskDef.connections, connectionType):
                dimensions = set(getattr(c, 'dimensions', set()))
                if "skypix" in dimensions:
                    try:
                        datasetType = registry.getDatasetType(c.name)
                    except LookupError as err:
                        raise LookupError(
                            f"DatasetType '{c.name}' referenced by "
                            f"{type(taskDef.connections).__name__} uses 'skypix' as a dimension "
                            f"placeholder, but does not already exist in the registry. "
                            f"Note that reference catalog names are now used as the dataset "
                            f"type name instead of 'ref_cat'."
                        ) from err
                    rest1 = set(registry.dimensions.extract(dimensions - set(["skypix"])).names)
                    rest2 = set(dim.name for dim in datasetType.dimensions
                                if not isinstance(dim, SkyPixDimension))
                    if rest1 != rest2:
                        raise ValueError(f"Non-skypix dimensions for dataset type {c.name} declared in "
                                         f"connections ({rest1}) are inconsistent with those in "
                                         f"registry's version of this dataset ({rest2}).")
                else:
                    # Component dataset types are not explicitly in the
                    # registry. This complicates consistency checks with
                    # registry and requires we work out the composite storage
                    # class.
                    registryDatasetType = None
                    try:
                        registryDatasetType = registry.getDatasetType(c.name)
                    except KeyError:
                        compositeName, componentName = DatasetType.splitDatasetTypeName(c.name)
                        parentStorageClass = DatasetType.PlaceholderParentStorageClass \
                            if componentName else None
                        datasetType = DatasetType(c.name, registry.dimensions.extract(dimensions),
                                                  c.storageClass,
                                                  parentStorageClass=parentStorageClass)
                        registryDatasetType = datasetType
                    else:
                        datasetType = DatasetType(c.name, registry.dimensions.extract(dimensions),
                                                  c.storageClass,
                                                  parentStorageClass=registryDatasetType.parentStorageClass)

                if registryDatasetType and datasetType != registryDatasetType:
                    raise ValueError(f"Supplied dataset type ({datasetType}) inconsistent with "
                                     f"registry definition ({registryDatasetType}) "
                                     f"for {taskDef.label}.")
                datasetTypes.add(datasetType)
            if freeze:
                datasetTypes.freeze()
            return datasetTypes

        # optionally add output dataset for metadata
        outputs = makeDatasetTypesSet("outputs", freeze=False)
        if taskDef.metadataDatasetName is not None:
            # Metadata is supposed to be of the PropertyList type, its
            # dimensions correspond to a task quantum
            dimensions = registry.dimensions.extract(taskDef.connections.dimensions)
            outputs |= {DatasetType(taskDef.metadataDatasetName, dimensions, "PropertyList")}
        outputs.freeze()

        return cls(
            initInputs=makeDatasetTypesSet("initInputs"),
            initOutputs=makeDatasetTypesSet("initOutputs"),
            inputs=makeDatasetTypesSet("inputs"),
            prerequisites=makeDatasetTypesSet("prerequisiteInputs"),
            outputs=outputs,
        )
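
    # Example (illustrative sketch, not from the original module), assuming a
    # butler ``Registry`` instance named ``registry`` and a ``TaskDef`` named
    # ``taskDef`` are already available:
    #
    #     types = TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
    #     print(types.inputs.names, types.outputs.names)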


@dataclass(frozen=True)
class PipelineDatasetTypes:
    """An immutable struct that classifies the dataset types used in a
    `Pipeline`.
    """

    initInputs: NamedValueSet[DatasetType]
    """Dataset types that are needed as inputs in order to construct the Tasks
    in this Pipeline.

    This does not include dataset types that are produced when constructing
    other Tasks in the Pipeline (these are classified as `initIntermediates`).
    """

    initOutputs: NamedValueSet[DatasetType]
    """Dataset types that may be written after constructing the Tasks in this
    Pipeline.

    This does not include dataset types that are also used as inputs when
    constructing other Tasks in the Pipeline (these are classified as
    `initIntermediates`).
    """

    initIntermediates: NamedValueSet[DatasetType]
    """Dataset types that are both used when constructing one or more Tasks
    in the Pipeline and produced as a side-effect of constructing another
    Task in the Pipeline.
    """

    inputs: NamedValueSet[DatasetType]
    """Dataset types that are regular inputs for the full pipeline.

    If an input dataset needed for a Quantum cannot be found in the input
    collection(s), that Quantum (and all dependent Quanta) will not be
    produced.
    """

    prerequisites: NamedValueSet[DatasetType]
    """Dataset types that are prerequisite inputs for the full Pipeline.

    Prerequisite inputs must exist in the input collection(s) before the
    pipeline is run, but do not constrain the graph - if a prerequisite is
    missing for a Quantum, `PrerequisiteMissingError` is raised.

    Prerequisite inputs are not resolved until the second stage of
    QuantumGraph generation.
    """

    intermediates: NamedValueSet[DatasetType]
    """Dataset types that are output by one Task in the Pipeline and consumed
    as inputs by one or more other Tasks in the Pipeline.
    """

    outputs: NamedValueSet[DatasetType]
    """Dataset types that are output by a Task in the Pipeline and not consumed
    by any other Task in the Pipeline.
    """

    byTask: Mapping[str, TaskDatasetTypes]
    """Per-Task dataset types, keyed by label in the `Pipeline`.

    This is guaranteed to be zip-iterable with the `Pipeline` itself (assuming
    neither has been modified since the dataset types were extracted, of
    course).
    """

    @classmethod
    def fromPipeline(cls, pipeline, *, registry: Registry) -> PipelineDatasetTypes:
        """Extract and classify the dataset types from all tasks in a
        `Pipeline`.

        Parameters
        ----------
        pipeline: `Pipeline`
            An ordered collection of tasks that can be run together.
        registry: `Registry`
            Registry used to construct normalized `DatasetType` objects and
            retrieve those that are incomplete.

        Returns
        -------
        types: `PipelineDatasetTypes`
            The dataset types used by this `Pipeline`.

        Raises
        ------
        ValueError
            Raised if Tasks are inconsistent about which datasets are marked
            prerequisite. This indicates that the Tasks cannot be run as part
            of the same `Pipeline`.
        """
        allInputs = NamedValueSet()
        allOutputs = NamedValueSet()
        allInitInputs = NamedValueSet()
        allInitOutputs = NamedValueSet()
        prerequisites = NamedValueSet()
        byTask = dict()
        if isinstance(pipeline, Pipeline):
            pipeline = pipeline.toExpandedPipeline()
        for taskDef in pipeline:
            thisTask = TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
            allInitInputs |= thisTask.initInputs
            allInitOutputs |= thisTask.initOutputs
            allInputs |= thisTask.inputs
            prerequisites |= thisTask.prerequisites
            allOutputs |= thisTask.outputs
            byTask[taskDef.label] = thisTask
        if not prerequisites.isdisjoint(allInputs):
            raise ValueError("{} marked as both prerequisites and regular inputs".format(
                {dt.name for dt in allInputs & prerequisites}
            ))
        if not prerequisites.isdisjoint(allOutputs):
            raise ValueError("{} marked as both prerequisites and outputs".format(
                {dt.name for dt in allOutputs & prerequisites}
            ))
        # Make sure that components which are marked as inputs get treated as
        # intermediates if there is an output which produces the composite
        # containing the component
        intermediateComponents = NamedValueSet()
        intermediateComposites = NamedValueSet()
        outputNameMapping = {dsType.name: dsType for dsType in allOutputs}
        for dsType in allInputs:
            # get the name of a possible component
            name, component = dsType.nameAndComponent()
            # if there is a component name, that means this is a component
            # DatasetType, if there is an output which produces the parent of
            # this component, treat this input as an intermediate
            if component is not None:
                if name in outputNameMapping:
                    if outputNameMapping[name].dimensions != dsType.dimensions:
                        raise ValueError(f"Component dataset type {dsType.name} has different "
                                         f"dimensions ({dsType.dimensions}) than its parent "
                                         f"({outputNameMapping[name].dimensions}).")
                    composite = DatasetType(name, dsType.dimensions, outputNameMapping[name].storageClass,
                                            universe=registry.dimensions)
                    intermediateComponents.add(dsType)
                    intermediateComposites.add(composite)

        def checkConsistency(a: NamedValueSet, b: NamedValueSet):
            common = a.names & b.names
            for name in common:
                if a[name] != b[name]:
                    raise ValueError(f"Conflicting definitions for dataset type: {a[name]} != {b[name]}.")

        checkConsistency(allInitInputs, allInitOutputs)
        checkConsistency(allInputs, allOutputs)
        checkConsistency(allInputs, intermediateComposites)
        checkConsistency(allOutputs, intermediateComposites)

        def frozen(s: NamedValueSet) -> NamedValueSet:
            s.freeze()
            return s

        return cls(
            initInputs=frozen(allInitInputs - allInitOutputs),
            initIntermediates=frozen(allInitInputs & allInitOutputs),
            initOutputs=frozen(allInitOutputs - allInitInputs),
            inputs=frozen(allInputs - allOutputs - intermediateComponents),
            intermediates=frozen(allInputs & allOutputs | intermediateComponents),
            outputs=frozen(allOutputs - allInputs - intermediateComposites),
            prerequisites=frozen(prerequisites),
            byTask=MappingProxyType(byTask),  # MappingProxyType -> frozen view of dict for immutability
        )
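

# Example (illustrative sketch, not from the original module) of the
# classification performed in ``fromPipeline``: a dataset type written by one
# task and read by another ends up in ``intermediates`` (allInputs &
# allOutputs) rather than in ``inputs`` or ``outputs``.  The dataset type name
# "calexp" describes a purely hypothetical scenario here.
#
#     types = PipelineDatasetTypes.fromPipeline(pipeline, registry=registry)
#     "calexp" in types.intermediates.names  # -> True when one task writes it and another reads it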