Coverage for python/lsst/pipe/base/pipeline.py : 19%

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

"""Module defining Pipeline class and related methods.
"""

__all__ = ["Pipeline", "TaskDef", "TaskDatasetTypes", "PipelineDatasetTypes"]

# -------------------------------
# Imports of standard modules --
# -------------------------------
from dataclasses import dataclass
from types import MappingProxyType
from typing import Mapping, Union, Generator, TYPE_CHECKING

import copy
import os

# -----------------------------
# Imports for other modules --
from lsst.daf.butler import DatasetType, NamedValueSet, Registry, SkyPixDimension
from lsst.utils import doImport
from .configOverrides import ConfigOverrides
from .connections import iterConnections
from .pipelineTask import PipelineTask

from . import pipelineIR
from . import pipeTools

if TYPE_CHECKING:  # Imports needed only for type annotations; may be circular.
    from lsst.obs.base.instrument import Instrument

# ----------------------------------
# Local non-exported definitions --
# ----------------------------------

# ------------------------
# Exported definitions --
# ------------------------


class TaskDef:
    """TaskDef is a collection of information about a task needed by a
    Pipeline.

    The information includes task name, configuration object and optional
    task class. This class is just a collection of attributes and it exposes
    all of them so that attributes could potentially be modified in place
    (e.g. if configuration needs extra overrides).

    Attributes
    ----------
    taskName : `str`
        `PipelineTask` class name; currently it is not specified whether this
        is a fully-qualified name or a partial name (e.g. ``module.TaskClass``).
        The framework should be prepared to handle all cases.
    config : `lsst.pex.config.Config`
        Instance of the configuration class corresponding to this task class,
        usually with all overrides applied.
    taskClass : `type` or ``None``
        `PipelineTask` class object, can be ``None``. If ``None`` then the
        framework will have to locate and load the class.
    label : `str`, optional
        Task label, usually a short string unique in a pipeline.
    """
    def __init__(self, taskName, config, taskClass=None, label=""):
        self.taskName = taskName
        self.config = config
        self.taskClass = taskClass
        self.label = label
        self.connections = config.connections.ConnectionsClass(config=config)

    @property
    def configDatasetName(self):
        """Name of a dataset type for configuration of this task (`str`)
        """
        return self.label + "_config"

    @property
    def metadataDatasetName(self):
        """Name of a dataset type for metadata of this task, `None` if
        metadata is not to be saved (`str`)
        """
        if self.config.saveMetadata:
            return self.label + "_metadata"
        else:
            return None

    def __str__(self):
        rep = "TaskDef(" + self.taskName
        if self.label:
            rep += ", label=" + self.label
        rep += ")"
        return rep


class Pipeline:
    """A `Pipeline` is a representation of a series of tasks to run, and the
    configuration for those tasks.

    Parameters
    ----------
    description : `str`
        A description of what this pipeline does.
    """
    def __init__(self, description: str):
        pipeline_dict = {"description": description, "tasks": {}}
        self._pipelineIR = pipelineIR.PipelineIR(pipeline_dict)

    @classmethod
    def fromFile(cls, filename: str) -> Pipeline:
        """Load a pipeline defined in a pipeline yaml file.

        Parameters
        ----------
        filename : `str`
            A path that points to a pipeline defined in yaml format.

        Returns
        -------
        pipeline : `Pipeline`
        """
        pipeline = cls.fromIR(pipelineIR.PipelineIR.from_file(filename))
        return pipeline

    @classmethod
    def fromString(cls, pipeline_string: str) -> Pipeline:
        """Create a pipeline from a string formatted as a pipeline document.

        Parameters
        ----------
        pipeline_string : `str`
            A string formatted like a pipeline document.

        Returns
        -------
        pipeline : `Pipeline`
        """
        pipeline = cls.fromIR(pipelineIR.PipelineIR.from_string(pipeline_string))
        return pipeline

    @classmethod
    def fromIR(cls, deserialized_pipeline: pipelineIR.PipelineIR) -> Pipeline:
        """Create a pipeline from an already created `PipelineIR` object.

        Parameters
        ----------
        deserialized_pipeline : `PipelineIR`
            An already created pipeline intermediate representation object.

        Returns
        -------
        pipeline : `Pipeline`
        """
        pipeline = cls.__new__(cls)
        pipeline._pipelineIR = deserialized_pipeline
        return pipeline

    @classmethod
    def fromPipeline(cls, pipeline: Pipeline) -> Pipeline:
        """Create a new pipeline by copying an already existing `Pipeline`.

        Parameters
        ----------
        pipeline : `Pipeline`
            The existing `Pipeline` object to copy.

        Returns
        -------
        pipeline : `Pipeline`
        """
        return cls.fromIR(copy.deepcopy(pipeline._pipelineIR))

    def __str__(self) -> str:
        return str(self._pipelineIR)

    def addInstrument(self, instrument: Union[Instrument, str]):
        """Add an instrument to the pipeline, or replace an instrument that is
        already defined.

        Parameters
        ----------
        instrument : `~lsst.obs.base.instrument.Instrument` or `str`
            Either an `~lsst.obs.base.instrument.Instrument` subclass object,
            or a string corresponding to the fully qualified name of such a
            subclass.
        """
        if isinstance(instrument, str):
            pass
        else:
            # TODO: assume that this is a subclass of Instrument, no type checking
            instrument = f"{instrument.__module__}.{instrument.__qualname__}"
        self._pipelineIR.instrument = instrument

    def addTask(self, task: Union[PipelineTask, str], label: str):
        """Add a new task to the pipeline, or replace a task that is already
        associated with the supplied label.

        Parameters
        ----------
        task : `PipelineTask` or `str`
            Either a derived class object of a `PipelineTask` or a string
            corresponding to a fully qualified `PipelineTask` name.
        label : `str`
            A label that is used to identify the `PipelineTask` being added.
        """
        if isinstance(task, str):
            taskName = task
        elif issubclass(task, PipelineTask):
            taskName = f"{task.__module__}.{task.__qualname__}"
        else:
            raise ValueError("task must be either a child class of PipelineTask or a string containing"
                             " a fully qualified name to one")
        if not label:
            # in some cases (with command line-generated pipeline) tasks can
            # be defined without a label; that is not acceptable, so use the
            # task's _DefaultName in that case
            if isinstance(task, str):
                task = doImport(task)
            label = task._DefaultName
        self._pipelineIR.tasks[label] = pipelineIR.TaskIR(label, taskName)

    def removeTask(self, label: str):
        """Remove a task from the pipeline.

        Parameters
        ----------
        label : `str`
            The label used to identify the task that is to be removed.

        Raises
        ------
        KeyError
            If no task with that label exists in the pipeline.

        """
        self._pipelineIR.tasks.pop(label)

    def addConfigOverride(self, label: str, key: str, value: object):
        """Apply single config override.

        Parameters
        ----------
        label : `str`
            Label of the task.
        key : `str`
            Fully-qualified field name.
        value : object
            Value to be given to a field.
        """
        self._addConfigImpl(label, pipelineIR.ConfigIR(rest={key: value}))

    def addConfigFile(self, label: str, filename: str):
        """Add overrides from a specified file.

        Parameters
        ----------
        label : `str`
            The label used to identify the task associated with config to
            modify.
        filename : `str`
            Path to the override file.
        """
        self._addConfigImpl(label, pipelineIR.ConfigIR(file=[filename]))

    def addConfigPython(self, label: str, pythonString: str):
        """Add overrides by running a snippet of python code against a config.

        Parameters
        ----------
        label : `str`
            The label used to identify the task associated with config to
            modify.
        pythonString : `str`
            A string which is valid python code to be executed. This is done
            with config as the only local accessible value.
        """
        self._addConfigImpl(label, pipelineIR.ConfigIR(python=pythonString))

    def _addConfigImpl(self, label: str, newConfig: pipelineIR.ConfigIR):
        if label not in self._pipelineIR.tasks:
            raise LookupError(f"There are no tasks labeled '{label}' in the pipeline")
        self._pipelineIR.tasks[label].add_or_update_config(newConfig)

    def toFile(self, filename: str):
        self._pipelineIR.to_file(filename)

    def toExpandedPipeline(self) -> Generator[TaskDef, None, None]:
        """Returns a generator of TaskDefs which can be used to create quantum
        graphs.

        Returns
        -------
        generator : generator of `TaskDef`
            The generator returned will be the sorted iterator of tasks which
            are to be used in constructing a quantum graph.

        Raises
        ------
        NotImplementedError
            If a dataId is supplied in a config block. This is in place for
            future use.
        """
        taskDefs = []
        for label, taskIR in self._pipelineIR.tasks.items():
            taskClass = doImport(taskIR.klass)
            taskName = taskClass.__qualname__
            config = taskClass.ConfigClass()
            overrides = ConfigOverrides()
            if self._pipelineIR.instrument is not None:
                overrides.addInstrumentOverride(self._pipelineIR.instrument, taskClass._DefaultName)
            if taskIR.config is not None:
                for configIR in taskIR.config:
                    if configIR.dataId is not None:
                        raise NotImplementedError("Specializing a config on a partial data id is not yet "
                                                  "supported in Pipeline definition")
                    # only apply override if it applies to everything
                    if configIR.dataId is None:
                        if configIR.file:
                            for configFile in configIR.file:
                                overrides.addFileOverride(os.path.expandvars(configFile))
                        if configIR.python is not None:
                            overrides.addPythonOverride(configIR.python)
                        for key, value in configIR.rest.items():
                            overrides.addValueOverride(key, value)
            overrides.applyTo(config)
            # This may need to be revisited
            config.validate()
            taskDefs.append(TaskDef(taskName=taskName, config=config, taskClass=taskClass, label=label))

        # let's evaluate the contracts
        if self._pipelineIR.contracts is not None:
            label_to_config = {x.label: x.config for x in taskDefs}
            for contract in self._pipelineIR.contracts:
                # execute this in its own line so it can raise a good error
                # message if there were problems with the eval
                success = eval(contract.contract, None, label_to_config)
                if not success:
                    extra_info = f": {contract.msg}" if contract.msg is not None else ""
                    raise pipelineIR.ContractError(f"Contract(s) '{contract.contract}' were not "
                                                   f"satisfied{extra_info}")

        yield from pipeTools.orderPipeline(taskDefs)

    def __len__(self):
        return len(self._pipelineIR.tasks)

    def __eq__(self, other: "Pipeline"):
        if not isinstance(other, Pipeline):
            return False
        return self._pipelineIR == other._pipelineIR


@dataclass(frozen=True)
class TaskDatasetTypes:
    """An immutable struct that extracts and classifies the dataset types used
    by a `PipelineTask`
    """

    initInputs: NamedValueSet[DatasetType]
    """Dataset types that are needed as inputs in order to construct this Task.

    Task-level `initInputs` may be classified as either
    `~PipelineDatasetTypes.initInputs` or
    `~PipelineDatasetTypes.initIntermediates` at the Pipeline level.
    """

    initOutputs: NamedValueSet[DatasetType]
    """Dataset types that may be written after constructing this Task.

    Task-level `initOutputs` may be classified as either
    `~PipelineDatasetTypes.initOutputs` or
    `~PipelineDatasetTypes.initIntermediates` at the Pipeline level.
    """

    inputs: NamedValueSet[DatasetType]
    """Dataset types that are regular inputs to this Task.

    If an input dataset needed for a Quantum cannot be found in the input
    collection(s) or produced by another Task in the Pipeline, that Quantum
    (and all dependent Quanta) will not be produced.

    Task-level `inputs` may be classified as either
    `~PipelineDatasetTypes.inputs` or `~PipelineDatasetTypes.intermediates`
    at the Pipeline level.
    """

    prerequisites: NamedValueSet[DatasetType]
    """Dataset types that are prerequisite inputs to this Task.

    Prerequisite inputs must exist in the input collection(s) before the
    pipeline is run, but do not constrain the graph - if a prerequisite is
    missing for a Quantum, `PrerequisiteMissingError` is raised.

    Prerequisite inputs are not resolved until the second stage of
    QuantumGraph generation.
    """

    outputs: NamedValueSet[DatasetType]
    """Dataset types that are produced by this Task.

    Task-level `outputs` may be classified as either
    `~PipelineDatasetTypes.outputs` or `~PipelineDatasetTypes.intermediates`
    at the Pipeline level.
    """
    @classmethod
    def fromTaskDef(cls, taskDef: TaskDef, *, registry: Registry) -> TaskDatasetTypes:
        """Extract and classify the dataset types from a single `PipelineTask`.

        Parameters
        ----------
        taskDef : `TaskDef`
            An instance of a `TaskDef` class for a particular `PipelineTask`.
        registry : `Registry`
            Registry used to construct normalized `DatasetType` objects and
            retrieve those that are incomplete.

        Returns
        -------
        types : `TaskDatasetTypes`
            The dataset types used by this task.
        """
        def makeDatasetTypesSet(connectionType, freeze=True):
            """Constructs a set of true `DatasetType` objects

            Parameters
            ----------
            connectionType : `str`
                Name of the connection type to produce a set for, corresponds
                to an attribute of type `list` on the connection class instance
            freeze : `bool`, optional
                If `True`, call `NamedValueSet.freeze` on the object returned.

            Returns
            -------
            datasetTypes : `NamedValueSet`
                A set of all datasetTypes which correspond to the input
                connection type specified in the connection class of this
                `PipelineTask`

            Notes
            -----
            This function is a closure over the variables ``registry`` and
            ``taskDef``.
            """
            datasetTypes = NamedValueSet()
            for c in iterConnections(taskDef.connections, connectionType):
                dimensions = set(getattr(c, 'dimensions', set()))
                if "skypix" in dimensions:
                    try:
                        datasetType = registry.getDatasetType(c.name)
                    except LookupError as err:
                        raise LookupError(
                            f"DatasetType '{c.name}' referenced by "
                            f"{type(taskDef.connections).__name__} uses 'skypix' as a dimension "
                            f"placeholder, but does not already exist in the registry. "
                            f"Note that reference catalog names are now used as the dataset "
                            f"type name instead of 'ref_cat'."
                        ) from err
                    rest1 = set(registry.dimensions.extract(dimensions - set(["skypix"])).names)
                    rest2 = set(dim.name for dim in datasetType.dimensions
                                if not isinstance(dim, SkyPixDimension))
                    if rest1 != rest2:
                        raise ValueError(f"Non-skypix dimensions for dataset type {c.name} declared in "
                                         f"connections ({rest1}) are inconsistent with those in "
                                         f"registry's version of this dataset ({rest2}).")
                else:
                    # Component dataset types are not explicitly in the
                    # registry. This complicates consistency checks with
                    # registry and requires we work out the composite storage
                    # class.
                    registryDatasetType = None
                    try:
                        registryDatasetType = registry.getDatasetType(c.name)
                    except KeyError:
                        compositeName, componentName = DatasetType.splitDatasetTypeName(c.name)
                        parentStorageClass = DatasetType.PlaceholderParentStorageClass \
                            if componentName else None
                        datasetType = DatasetType(c.name, registry.dimensions.extract(dimensions),
                                                  c.storageClass,
                                                  parentStorageClass=parentStorageClass)
                        registryDatasetType = datasetType
                    else:
                        datasetType = DatasetType(c.name, registry.dimensions.extract(dimensions),
                                                  c.storageClass,
                                                  parentStorageClass=registryDatasetType.parentStorageClass)

                    if registryDatasetType and datasetType != registryDatasetType:
                        raise ValueError(f"Supplied dataset type ({datasetType}) inconsistent with "
                                         f"registry definition ({registryDatasetType}) "
                                         f"for {taskDef.label}.")
                datasetTypes.add(datasetType)
            if freeze:
                datasetTypes.freeze()
            return datasetTypes

        # optionally add output dataset for metadata
        outputs = makeDatasetTypesSet("outputs", freeze=False)
        if taskDef.metadataDatasetName is not None:
            # Metadata is supposed to be of the PropertySet type; its dimensions
            # correspond to a task quantum
            dimensions = registry.dimensions.extract(taskDef.connections.dimensions)
            outputs |= {DatasetType(taskDef.metadataDatasetName, dimensions, "PropertySet")}
        outputs.freeze()

        return cls(
            initInputs=makeDatasetTypesSet("initInputs"),
            initOutputs=makeDatasetTypesSet("initOutputs"),
            inputs=makeDatasetTypesSet("inputs"),
            prerequisites=makeDatasetTypesSet("prerequisiteInputs"),
            outputs=outputs,
        )


@dataclass(frozen=True)
class PipelineDatasetTypes:
    """An immutable struct that classifies the dataset types used in a
    `Pipeline`.
    """

    initInputs: NamedValueSet[DatasetType]
    """Dataset types that are needed as inputs in order to construct the Tasks
    in this Pipeline.

    This does not include dataset types that are produced when constructing
    other Tasks in the Pipeline (these are classified as `initIntermediates`).
    """

    initOutputs: NamedValueSet[DatasetType]
    """Dataset types that may be written after constructing the Tasks in this
    Pipeline.

    This does not include dataset types that are also used as inputs when
    constructing other Tasks in the Pipeline (these are classified as
    `initIntermediates`).
    """

    initIntermediates: NamedValueSet[DatasetType]
    """Dataset types that are both used when constructing one or more Tasks
    in the Pipeline and produced as a side-effect of constructing another
    Task in the Pipeline.
    """

    inputs: NamedValueSet[DatasetType]
    """Dataset types that are regular inputs for the full pipeline.

    If an input dataset needed for a Quantum cannot be found in the input
    collection(s), that Quantum (and all dependent Quanta) will not be
    produced.
    """

    prerequisites: NamedValueSet[DatasetType]
    """Dataset types that are prerequisite inputs for the full Pipeline.

    Prerequisite inputs must exist in the input collection(s) before the
    pipeline is run, but do not constrain the graph - if a prerequisite is
    missing for a Quantum, `PrerequisiteMissingError` is raised.

    Prerequisite inputs are not resolved until the second stage of
    QuantumGraph generation.
    """

    intermediates: NamedValueSet[DatasetType]
    """Dataset types that are output by one Task in the Pipeline and consumed
    as inputs by one or more other Tasks in the Pipeline.
    """

    outputs: NamedValueSet[DatasetType]
    """Dataset types that are output by a Task in the Pipeline and not consumed
    by any other Task in the Pipeline.
    """

    byTask: Mapping[str, TaskDatasetTypes]
    """Per-Task dataset types, keyed by label in the `Pipeline`.

    This is guaranteed to be zip-iterable with the `Pipeline` itself (assuming
    neither has been modified since the dataset types were extracted, of
    course).
    """
    @classmethod
    def fromPipeline(cls, pipeline, *, registry: Registry) -> PipelineDatasetTypes:
        """Extract and classify the dataset types from all tasks in a
        `Pipeline`.

        Parameters
        ----------
        pipeline : `Pipeline`
            An ordered collection of tasks that can be run together.
        registry : `Registry`
            Registry used to construct normalized `DatasetType` objects and
            retrieve those that are incomplete.

        Returns
        -------
        types : `PipelineDatasetTypes`
            The dataset types used by this `Pipeline`.

        Raises
        ------
        ValueError
            Raised if Tasks are inconsistent about which datasets are marked
            prerequisite. This indicates that the Tasks cannot be run as part
            of the same `Pipeline`.
        """
        allInputs = NamedValueSet()
        allOutputs = NamedValueSet()
        allInitInputs = NamedValueSet()
        allInitOutputs = NamedValueSet()
        prerequisites = NamedValueSet()
        byTask = dict()
        if isinstance(pipeline, Pipeline):
            pipeline = pipeline.toExpandedPipeline()
        for taskDef in pipeline:
            thisTask = TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
            allInitInputs |= thisTask.initInputs
            allInitOutputs |= thisTask.initOutputs
            allInputs |= thisTask.inputs
            prerequisites |= thisTask.prerequisites
            allOutputs |= thisTask.outputs
            byTask[taskDef.label] = thisTask
        if not prerequisites.isdisjoint(allInputs):
            raise ValueError("{} marked as both prerequisites and regular inputs".format(
                {dt.name for dt in allInputs & prerequisites}
            ))
        if not prerequisites.isdisjoint(allOutputs):
            raise ValueError("{} marked as both prerequisites and outputs".format(
                {dt.name for dt in allOutputs & prerequisites}
            ))
        # Make sure that components which are marked as inputs get treated as
        # intermediates if there is an output which produces the composite
        # containing the component
        intermediateComponents = NamedValueSet()
        intermediateComposites = NamedValueSet()
        outputNameMapping = {dsType.name: dsType for dsType in allOutputs}
        for dsType in allInputs:
            # get the name of a possible component
            name, component = dsType.nameAndComponent()
            # if there is a component name, this is a component DatasetType;
            # if there is an output which produces the parent of this
            # component, treat this input as an intermediate
            if component is not None:
                if name in outputNameMapping:
                    if outputNameMapping[name].dimensions != dsType.dimensions:
                        raise ValueError(f"Component dataset type {dsType.name} has different "
                                         f"dimensions ({dsType.dimensions}) than its parent "
                                         f"({outputNameMapping[name].dimensions}).")
                    composite = DatasetType(name, dsType.dimensions, outputNameMapping[name].storageClass,
                                            universe=registry.dimensions)
                    intermediateComponents.add(dsType)
                    intermediateComposites.add(composite)

        def checkConsistency(a: NamedValueSet, b: NamedValueSet):
            common = a.names & b.names
            for name in common:
                if a[name] != b[name]:
                    raise ValueError(f"Conflicting definitions for dataset type: {a[name]} != {b[name]}.")

        checkConsistency(allInitInputs, allInitOutputs)
        checkConsistency(allInputs, allOutputs)
        checkConsistency(allInputs, intermediateComposites)
        checkConsistency(allOutputs, intermediateComposites)

        def frozen(s: NamedValueSet) -> NamedValueSet:
            s.freeze()
            return s

        return cls(
            initInputs=frozen(allInitInputs - allInitOutputs),
            initIntermediates=frozen(allInitInputs & allInitOutputs),
            initOutputs=frozen(allInitOutputs - allInitInputs),
            inputs=frozen(allInputs - allOutputs - intermediateComponents),
            intermediates=frozen(allInputs & allOutputs | intermediateComponents),
            outputs=frozen(allOutputs - allInputs - intermediateComposites),
            prerequisites=frozen(prerequisites),
            byTask=MappingProxyType(byTask),  # MappingProxyType -> frozen view of dict for immutability
        )