Coverage for python/lsst/pipe/base/pipeline.py : 19%
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

"""Module defining Pipeline class and related methods.
"""

__all__ = ["Pipeline", "TaskDef", "TaskDatasetTypes", "PipelineDatasetTypes"]

# -------------------------------
# Imports of standard modules --
# -------------------------------
from dataclasses import dataclass
from types import MappingProxyType
from typing import Mapping, Union, Generator, TYPE_CHECKING

import copy

# -----------------------------
# Imports for other modules --
from lsst.daf.butler import DatasetType, Registry, SkyPixDimension
from lsst.daf.butler.core.utils import NamedValueSet
from lsst.utils import doImport
from .configOverrides import ConfigOverrides
from .connections import iterConnections
from .pipelineTask import PipelineTask

from . import pipelineIR
from . import pipeTools

if TYPE_CHECKING:  # Imports needed only for type annotations; may be circular.
    from lsst.obs.base.instrument import Instrument

# ----------------------------------
# Local non-exported definitions --
# ----------------------------------

# ------------------------
# Exported definitions --
# ------------------------
class TaskDef:
    """TaskDef is a collection of information about a task needed by a
    Pipeline.

    The information includes the task name, configuration object and optional
    task class. This class is just a collection of attributes and it exposes
    all of them so that attributes could potentially be modified in place
    (e.g. if configuration needs extra overrides).

    Attributes
    ----------
    taskName : `str`
        `PipelineTask` class name; currently it is not specified whether this
        is a fully-qualified or partial name (e.g. ``module.TaskClass``).
        The framework should be prepared to handle all cases.
    config : `lsst.pex.config.Config`
        Instance of the configuration class corresponding to this task class,
        usually with all overrides applied.
    taskClass : `type` or ``None``
        `PipelineTask` class object; can be ``None``. If ``None`` then the
        framework will have to locate and load the class.
    label : `str`, optional
        Task label, usually a short string unique within a pipeline.
    """
    def __init__(self, taskName, config, taskClass=None, label=""):
        self.taskName = taskName
        self.config = config
        self.taskClass = taskClass
        self.label = label
        self.connections = config.connections.ConnectionsClass(config=config)

    @property
    def metadataDatasetName(self):
        """Name of the dataset type for this task's metadata, or `None` if
        metadata is not to be saved (`str`).
        """
        if self.config.saveMetadata:
            return self.label + "_metadata"
        else:
            return None

    def __str__(self):
        rep = "TaskDef(" + self.taskName
        if self.label:
            rep += ", label=" + self.label
        rep += ")"
        return rep
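# Illustrative sketch (not part of the original module): how a TaskDef is
# typically constructed from a concrete PipelineTask.  ``ExampleTask`` and its
# config field are hypothetical placeholders for a real PipelineTask subclass.
#
#     config = ExampleTask.ConfigClass()
#     taskDef = TaskDef(taskName="mypackage.ExampleTask", config=config,
#                       taskClass=ExampleTask, label="example")
#     print(taskDef)                      # TaskDef(mypackage.ExampleTask, label=example)
#     print(taskDef.metadataDatasetName)  # "example_metadata" while config.saveMetadata is True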
class Pipeline:
    """A `Pipeline` is a representation of a series of tasks to run, and the
    configuration for those tasks.

    Parameters
    ----------
    description : `str`
        A description of what this pipeline does.
    """
    def __init__(self, description: str):
        pipeline_dict = {"description": description, "tasks": {}}
        self._pipelineIR = pipelineIR.PipelineIR(pipeline_dict)

    @classmethod
    def fromFile(cls, filename: str) -> Pipeline:
        """Load a pipeline defined in a pipeline yaml file.

        Parameters
        ----------
        filename : `str`
            A path that points to a pipeline defined in yaml format.

        Returns
        -------
        pipeline : `Pipeline`
        """
        pipeline = cls.fromIR(pipelineIR.PipelineIR.from_file(filename))
        return pipeline

    @classmethod
    def fromString(cls, pipeline_string: str) -> Pipeline:
        """Create a pipeline from a string formatted as a pipeline document.

        Parameters
        ----------
        pipeline_string : `str`
            A string that is formatted like a pipeline document.

        Returns
        -------
        pipeline : `Pipeline`
        """
        pipeline = cls.fromIR(pipelineIR.PipelineIR.from_string(pipeline_string))
        return pipeline

    @classmethod
    def fromIR(cls, deserialized_pipeline: pipelineIR.PipelineIR) -> Pipeline:
        """Create a pipeline from an already created `PipelineIR` object.

        Parameters
        ----------
        deserialized_pipeline : `PipelineIR`
            An already created pipeline intermediate representation object.

        Returns
        -------
        pipeline : `Pipeline`
        """
        pipeline = cls.__new__(cls)
        pipeline._pipelineIR = deserialized_pipeline
        return pipeline

    @classmethod
    def fromPipeline(cls, pipeline: Pipeline) -> Pipeline:
        """Create a new pipeline by copying an already existing `Pipeline`.

        Parameters
        ----------
        pipeline : `Pipeline`
            The already created `Pipeline` object to copy.

        Returns
        -------
        pipeline : `Pipeline`
        """
        return cls.fromIR(copy.deepcopy(pipeline._pipelineIR))

    def __str__(self) -> str:
        return str(self._pipelineIR)
    def addInstrument(self, instrument: Union[Instrument, str]):
        """Add an instrument to the pipeline, or replace an instrument that is
        already defined.

        Parameters
        ----------
        instrument : `~lsst.daf.butler.instrument.Instrument` or `str`
            Either a class object derived from `lsst.daf.butler.instrument`
            or a string corresponding to a fully qualified
            `lsst.daf.butler.instrument` name.
        """
        if isinstance(instrument, str):
            pass
        else:
            # TODO: assume that this is a subclass of Instrument, no type checking
            instrument = f"{instrument.__module__}.{instrument.__qualname__}"
        self._pipelineIR.instrument = instrument

    def addTask(self, task: Union[PipelineTask, str], label: str):
        """Add a new task to the pipeline, or replace a task that is already
        associated with the supplied label.

        Parameters
        ----------
        task : `PipelineTask` or `str`
            Either a derived class object of a `PipelineTask` or a string
            corresponding to a fully qualified `PipelineTask` name.
        label : `str`
            A label that is used to identify the `PipelineTask` being added.
        """
        if isinstance(task, str):
            taskName = task
        elif issubclass(task, PipelineTask):
            taskName = f"{task.__module__}.{task.__qualname__}"
        else:
            raise ValueError("task must be either a child class of PipelineTask or a string containing"
                             " a fully qualified name to one")
        if not label:
            # In some cases (e.g. a command line-generated pipeline) tasks can
            # be defined without a label, which is not acceptable; use the
            # task _DefaultName in that case.
            if isinstance(task, str):
                task = doImport(task)
            label = task._DefaultName
        self._pipelineIR.tasks[label] = pipelineIR.TaskIR(label, taskName)
    def removeTask(self, label: str):
        """Remove a task from the pipeline.

        Parameters
        ----------
        label : `str`
            The label used to identify the task that is to be removed.

        Raises
        ------
        KeyError
            If no task with that label exists in the pipeline.
        """
        self._pipelineIR.tasks.pop(label)
    def addConfigOverride(self, label: str, key: str, value: object):
        """Apply a single config override.

        Parameters
        ----------
        label : `str`
            Label of the task.
        key : `str`
            Fully-qualified field name.
        value : object
            Value to be given to a field.
        """
        self._addConfigImpl(label, pipelineIR.ConfigIR(rest={key: value}))

    def addConfigFile(self, label: str, filename: str):
        """Add overrides from a specified file.

        Parameters
        ----------
        label : `str`
            The label used to identify the task whose config is to be
            modified.
        filename : `str`
            Path to the override file.
        """
        self._addConfigImpl(label, pipelineIR.ConfigIR(file=[filename]))

    def addConfigPython(self, label: str, pythonString: str):
        """Add overrides by running a snippet of python code against a config.

        Parameters
        ----------
        label : `str`
            The label used to identify the task whose config is to be
            modified.
        pythonString : `str`
            A string of valid python code to be executed. This is done with
            ``config`` as the only locally accessible value.
        """
        self._addConfigImpl(label, pipelineIR.ConfigIR(python=pythonString))
    def _addConfigImpl(self, label: str, newConfig: pipelineIR.ConfigIR):
        if label not in self._pipelineIR.tasks:
            raise LookupError(f"There are no tasks labeled '{label}' in the pipeline")
        self._pipelineIR.tasks[label].add_or_update_config(newConfig)

    def toFile(self, filename: str):
        """Write this pipeline to the given file as a pipeline yaml document.

        Parameters
        ----------
        filename : `str`
            Path of the file to write to.
        """
        self._pipelineIR.to_file(filename)
    def toExpandedPipeline(self) -> Generator[TaskDef, None, None]:
        """Return a generator of TaskDefs which can be used to create quantum
        graphs.

        Returns
        -------
        generator : generator of `TaskDef`
            The generator returned will be the sorted iterator of tasks which
            are to be used in constructing a quantum graph.

        Raises
        ------
        NotImplementedError
            If a dataId is supplied in a config block. This is in place for
            future use.
        """
        taskDefs = []
        for label, taskIR in self._pipelineIR.tasks.items():
            taskClass = doImport(taskIR.klass)
            taskName = taskClass.__qualname__
            config = taskClass.ConfigClass()
            overrides = ConfigOverrides()
            if self._pipelineIR.instrument is not None:
                overrides.addInstrumentOverride(self._pipelineIR.instrument, taskClass._DefaultName)
            if taskIR.config is not None:
                for configIR in taskIR.config:
                    if configIR.dataId is not None:
                        raise NotImplementedError("Specializing a config on a partial data id is not yet "
                                                  "supported in Pipeline definition")
                    # Only apply the override if it applies to everything
                    if configIR.dataId is None:
                        if configIR.file:
                            for configFile in configIR.file:
                                overrides.addFileOverride(configFile)
                        if configIR.python is not None:
                            overrides.addPythonOverride(configIR.python)
                        for key, value in configIR.rest.items():
                            overrides.addValueOverride(key, value)
            overrides.applyTo(config)
            # This may need to be revisited
            config.validate()
            taskDefs.append(TaskDef(taskName=taskName, config=config, taskClass=taskClass, label=label))

        # Let's evaluate the contracts
        if self._pipelineIR.contracts is not None:
            label_to_config = {x.label: x.config for x in taskDefs}
            for contract in self._pipelineIR.contracts:
                # Execute this on its own line so it can raise a good error
                # message if there were problems with the eval.
                success = eval(contract.contract, None, label_to_config)
                if not success:
                    extra_info = f": {contract.msg}" if contract.msg is not None else ""
                    raise pipelineIR.ContractError(f"Contract(s) '{contract.contract}' were not "
                                                   f"satisfied{extra_info}")

        yield from pipeTools.orderPipeline(taskDefs)

    def __len__(self):
        return len(self._pipelineIR.tasks)

    def __eq__(self, other: "Pipeline"):
        if not isinstance(other, Pipeline):
            return False
        return self._pipelineIR == other._pipelineIR
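# Illustrative sketch (not part of the original module): assembling a small
# pipeline in memory.  The task name and config field below are hypothetical;
# any importable PipelineTask subclass and one of its config fields would do.
#
#     pipeline = Pipeline("A demo pipeline")
#     pipeline.addTask("mypackage.tasks.ExampleTask", label="example")
#     pipeline.addConfigOverride("example", "someField", 42)
#     pipeline.toFile("demo_pipeline.yaml")
#     for taskDef in pipeline.toExpandedPipeline():
#         print(taskDef.label, taskDef.taskName)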
@dataclass(frozen=True)
class TaskDatasetTypes:
    """An immutable struct that extracts and classifies the dataset types used
    by a `PipelineTask`.
    """

    initInputs: NamedValueSet[DatasetType]
    """Dataset types that are needed as inputs in order to construct this Task.

    Task-level `initInputs` may be classified as either
    `~PipelineDatasetTypes.initInputs` or
    `~PipelineDatasetTypes.initIntermediates` at the Pipeline level.
    """

    initOutputs: NamedValueSet[DatasetType]
    """Dataset types that may be written after constructing this Task.

    Task-level `initOutputs` may be classified as either
    `~PipelineDatasetTypes.initOutputs` or
    `~PipelineDatasetTypes.initIntermediates` at the Pipeline level.
    """

    inputs: NamedValueSet[DatasetType]
    """Dataset types that are regular inputs to this Task.

    If an input dataset needed for a Quantum cannot be found in the input
    collection(s) or produced by another Task in the Pipeline, that Quantum
    (and all dependent Quanta) will not be produced.

    Task-level `inputs` may be classified as either
    `~PipelineDatasetTypes.inputs` or `~PipelineDatasetTypes.intermediates`
    at the Pipeline level.
    """

    prerequisites: NamedValueSet[DatasetType]
    """Dataset types that are prerequisite inputs to this Task.

    Prerequisite inputs must exist in the input collection(s) before the
    pipeline is run, but do not constrain the graph - if a prerequisite is
    missing for a Quantum, `PrerequisiteMissingError` is raised.

    Prerequisite inputs are not resolved until the second stage of
    QuantumGraph generation.
    """

    outputs: NamedValueSet[DatasetType]
    """Dataset types that are produced by this Task.

    Task-level `outputs` may be classified as either
    `~PipelineDatasetTypes.outputs` or `~PipelineDatasetTypes.intermediates`
    at the Pipeline level.
    """
    @classmethod
    def fromTaskDef(cls, taskDef: TaskDef, *, registry: Registry) -> TaskDatasetTypes:
        """Extract and classify the dataset types from a single `PipelineTask`.

        Parameters
        ----------
        taskDef : `TaskDef`
            An instance of a `TaskDef` class for a particular `PipelineTask`.
        registry : `Registry`
            Registry used to construct normalized `DatasetType` objects and
            retrieve those that are incomplete.

        Returns
        -------
        types : `TaskDatasetTypes`
            The dataset types used by this task.
        """
        def makeDatasetTypesSet(connectionType, freeze=True):
            """Construct a set of true `DatasetType` objects.

            Parameters
            ----------
            connectionType : `str`
                Name of the connection type to produce a set for; corresponds
                to an attribute of type `list` on the connection class instance.
            freeze : `bool`, optional
                If `True`, call `NamedValueSet.freeze` on the object returned.

            Returns
            -------
            datasetTypes : `NamedValueSet`
                A set of all datasetTypes which correspond to the input
                connection type specified in the connection class of this
                `PipelineTask`.

            Notes
            -----
            This function is a closure over the variables ``registry`` and
            ``taskDef``.
            """
            datasetTypes = NamedValueSet()
            for c in iterConnections(taskDef.connections, connectionType):
                dimensions = set(getattr(c, 'dimensions', set()))
                if "skypix" in dimensions:
                    try:
                        datasetType = registry.getDatasetType(c.name)
                    except LookupError as err:
                        raise LookupError(
                            f"DatasetType '{c.name}' referenced by "
                            f"{type(taskDef.connections).__name__} uses 'skypix' as a dimension "
                            f"placeholder, but does not already exist in the registry. "
                            f"Note that reference catalog names are now used as the dataset "
                            f"type name instead of 'ref_cat'."
                        ) from err
                    rest1 = set(registry.dimensions.extract(dimensions - set(["skypix"])).names)
                    rest2 = set(dim.name for dim in datasetType.dimensions
                                if not isinstance(dim, SkyPixDimension))
                    if rest1 != rest2:
                        raise ValueError(f"Non-skypix dimensions for dataset type {c.name} declared in "
                                         f"connections ({rest1}) are inconsistent with those in "
                                         f"registry's version of this dataset ({rest2}).")
                else:
                    datasetType = DatasetType(c.name, registry.dimensions.extract(dimensions),
                                              c.storageClass)
                    try:
                        registryDatasetType = registry.getDatasetType(c.name)
                    except KeyError:
                        registryDatasetType = datasetType
                    if datasetType != registryDatasetType:
                        raise ValueError(f"Supplied dataset type ({datasetType}) inconsistent with "
                                         f"registry definition ({registryDatasetType})")
                datasetTypes.add(datasetType)
            if freeze:
                datasetTypes.freeze()
            return datasetTypes

        # optionally add output dataset for metadata
        outputs = makeDatasetTypesSet("outputs", freeze=False)
        if taskDef.metadataDatasetName is not None:
            # Metadata is supposed to be of the PropertyList type; its
            # dimensions correspond to a task quantum.
            dimensions = registry.dimensions.extract(taskDef.connections.dimensions)
            outputs |= {DatasetType(taskDef.metadataDatasetName, dimensions, "PropertyList")}
        outputs.freeze()

        return cls(
            initInputs=makeDatasetTypesSet("initInputs"),
            initOutputs=makeDatasetTypesSet("initOutputs"),
            inputs=makeDatasetTypesSet("inputs"),
            prerequisites=makeDatasetTypesSet("prerequisiteInputs"),
            outputs=outputs,
        )
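# Illustrative sketch (not part of the original module): classifying the
# dataset types of a single task.  ``taskDef`` and ``butler`` are assumed to
# exist already (e.g. from Pipeline.toExpandedPipeline and a data butler).
#
#     types = TaskDatasetTypes.fromTaskDef(taskDef, registry=butler.registry)
#     print(sorted(types.inputs.names), sorted(types.outputs.names))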
@dataclass(frozen=True)
class PipelineDatasetTypes:
    """An immutable struct that classifies the dataset types used in a
    `Pipeline`.
    """

    initInputs: NamedValueSet[DatasetType]
    """Dataset types that are needed as inputs in order to construct the Tasks
    in this Pipeline.

    This does not include dataset types that are produced when constructing
    other Tasks in the Pipeline (these are classified as `initIntermediates`).
    """

    initOutputs: NamedValueSet[DatasetType]
    """Dataset types that may be written after constructing the Tasks in this
    Pipeline.

    This does not include dataset types that are also used as inputs when
    constructing other Tasks in the Pipeline (these are classified as
    `initIntermediates`).
    """

    initIntermediates: NamedValueSet[DatasetType]
    """Dataset types that are both used when constructing one or more Tasks
    in the Pipeline and produced as a side-effect of constructing another
    Task in the Pipeline.
    """

    inputs: NamedValueSet[DatasetType]
    """Dataset types that are regular inputs for the full pipeline.

    If an input dataset needed for a Quantum cannot be found in the input
    collection(s), that Quantum (and all dependent Quanta) will not be
    produced.
    """

    prerequisites: NamedValueSet[DatasetType]
    """Dataset types that are prerequisite inputs for the full Pipeline.

    Prerequisite inputs must exist in the input collection(s) before the
    pipeline is run, but do not constrain the graph - if a prerequisite is
    missing for a Quantum, `PrerequisiteMissingError` is raised.

    Prerequisite inputs are not resolved until the second stage of
    QuantumGraph generation.
    """

    intermediates: NamedValueSet[DatasetType]
    """Dataset types that are output by one Task in the Pipeline and consumed
    as inputs by one or more other Tasks in the Pipeline.
    """

    outputs: NamedValueSet[DatasetType]
    """Dataset types that are output by a Task in the Pipeline and not consumed
    by any other Task in the Pipeline.
    """

    byTask: Mapping[str, TaskDatasetTypes]
    """Per-Task dataset types, keyed by label in the `Pipeline`.

    This is guaranteed to be zip-iterable with the `Pipeline` itself (assuming
    neither has been modified since the dataset types were extracted, of
    course).
    """
    @classmethod
    def fromPipeline(cls, pipeline, *, registry: Registry) -> PipelineDatasetTypes:
        """Extract and classify the dataset types from all tasks in a
        `Pipeline`.

        Parameters
        ----------
        pipeline : `Pipeline`
            An ordered collection of tasks that can be run together.
        registry : `Registry`
            Registry used to construct normalized `DatasetType` objects and
            retrieve those that are incomplete.

        Returns
        -------
        types : `PipelineDatasetTypes`
            The dataset types used by this `Pipeline`.

        Raises
        ------
        ValueError
            Raised if Tasks are inconsistent about which datasets are marked
            prerequisite. This indicates that the Tasks cannot be run as part
            of the same `Pipeline`.
        """
        allInputs = NamedValueSet()
        allOutputs = NamedValueSet()
        allInitInputs = NamedValueSet()
        allInitOutputs = NamedValueSet()
        prerequisites = NamedValueSet()
        byTask = dict()
        if isinstance(pipeline, Pipeline):
            pipeline = pipeline.toExpandedPipeline()
        for taskDef in pipeline:
            thisTask = TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
            allInitInputs |= thisTask.initInputs
            allInitOutputs |= thisTask.initOutputs
            allInputs |= thisTask.inputs
            prerequisites |= thisTask.prerequisites
            allOutputs |= thisTask.outputs
            byTask[taskDef.label] = thisTask
        if not prerequisites.isdisjoint(allInputs):
            raise ValueError("{} marked as both prerequisites and regular inputs".format(
                {dt.name for dt in allInputs & prerequisites}
            ))
        if not prerequisites.isdisjoint(allOutputs):
            raise ValueError("{} marked as both prerequisites and outputs".format(
                {dt.name for dt in allOutputs & prerequisites}
            ))
        # Make sure that components which are marked as inputs get treated as
        # intermediates if there is an output which produces the composite
        # containing the component.
        intermediateComponents = NamedValueSet()
        intermediateComposites = NamedValueSet()
        outputNameMapping = {dsType.name: dsType for dsType in allOutputs}
        for dsType in allInputs:
            # Get the name of a possible component.
            name, component = dsType.nameAndComponent()
            # If there is a component name, this is a component DatasetType;
            # if there is an output which produces the parent of this
            # component, treat this input as an intermediate.
            if component is not None:
                if name in outputNameMapping:
                    if outputNameMapping[name].dimensions != dsType.dimensions:
                        raise ValueError(f"Component dataset type {dsType.name} has different "
                                         f"dimensions ({dsType.dimensions}) than its parent "
                                         f"({outputNameMapping[name].dimensions}).")
                    composite = DatasetType(name, dsType.dimensions, outputNameMapping[name].storageClass,
                                            universe=registry.dimensions)
                    intermediateComponents.add(dsType)
                    intermediateComposites.add(composite)

        def checkConsistency(a: NamedValueSet, b: NamedValueSet):
            common = a.names & b.names
            for name in common:
                if a[name] != b[name]:
                    raise ValueError(f"Conflicting definitions for dataset type: {a[name]} != {b[name]}.")

        checkConsistency(allInitInputs, allInitOutputs)
        checkConsistency(allInputs, allOutputs)
        checkConsistency(allInputs, intermediateComposites)
        checkConsistency(allOutputs, intermediateComposites)

        def frozen(s: NamedValueSet) -> NamedValueSet:
            s.freeze()
            return s

        return cls(
            initInputs=frozen(allInitInputs - allInitOutputs),
            initIntermediates=frozen(allInitInputs & allInitOutputs),
            initOutputs=frozen(allInitOutputs - allInitInputs),
            inputs=frozen(allInputs - allOutputs - intermediateComponents),
            intermediates=frozen(allInputs & allOutputs | intermediateComponents),
            outputs=frozen(allOutputs - allInputs - intermediateComposites),
            prerequisites=frozen(prerequisites),
            byTask=MappingProxyType(byTask),  # MappingProxyType -> frozen view of dict for immutability
        )
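# Illustrative sketch (not part of the original module): classifying dataset
# types for a whole pipeline.  ``pipeline`` (a Pipeline or iterable of
# TaskDefs) and ``butler`` (a data butler providing a registry) are assumed.
#
#     datasetTypes = PipelineDatasetTypes.fromPipeline(pipeline, registry=butler.registry)
#     print("pure inputs:   ", sorted(datasetTypes.inputs.names))
#     print("intermediates: ", sorted(datasetTypes.intermediates.names))
#     print("final outputs: ", sorted(datasetTypes.outputs.names))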