# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Module defining Pipeline class and related methods.
"""

from __future__ import annotations

__all__ = ["Pipeline", "TaskDef", "TaskDatasetTypes", "PipelineDatasetTypes"]

# -------------------------------
# Imports of standard modules --
# -------------------------------
from dataclasses import dataclass
from types import MappingProxyType
from typing import Mapping, Union, Generator, TYPE_CHECKING

import copy

# -----------------------------
# Imports for other modules --
from lsst.daf.butler import DatasetType, Registry, SkyPixDimension
from lsst.daf.butler.core.utils import NamedValueSet
from lsst.utils import doImport
from .configOverrides import ConfigOverrides
from .connections import iterConnections
from .pipelineTask import PipelineTask

from . import pipelineIR
from . import pipeTools

if TYPE_CHECKING:  # Imports needed only for type annotations; may be circular.
    from lsst.obs.base.instrument import Instrument

# ----------------------------------
# Local non-exported definitions --
# ----------------------------------

# ------------------------
# Exported definitions --
# ------------------------


class TaskDef:
    """TaskDef is a collection of information about a task that is needed by
    a Pipeline.

    The information includes the task name, the configuration object, and an
    optional task class. This class is just a collection of attributes, and
    it exposes all of them so that attributes can be modified in place (e.g.
    if the configuration needs extra overrides).

    Attributes
    ----------
    taskName : `str`
        `PipelineTask` class name; currently it is not specified whether this
        is a fully-qualified name or a partial name (e.g. ``module.TaskClass``).
        The framework should be prepared to handle all cases.
    config : `lsst.pex.config.Config`
        Instance of the configuration class corresponding to this task class,
        usually with all overrides applied.
    taskClass : `type` or ``None``
        `PipelineTask` class object; can be ``None``. If ``None`` then the
        framework will have to locate and load the class.
    label : `str`, optional
        Task label, usually a short string unique within a pipeline.
    """
    def __init__(self, taskName, config, taskClass=None, label=""):
        self.taskName = taskName
        self.config = config
        self.taskClass = taskClass
        self.label = label
        self.connections = config.connections.ConnectionsClass(config=config)

    @property
    def configDatasetName(self):
        """Name of the dataset type for the configuration of this task (`str`).
        """
        return self.label + "_config"

    @property
    def metadataDatasetName(self):
        """Name of the dataset type for the metadata of this task; `None` if
        metadata is not to be saved (`str`).
        """
        if self.config.saveMetadata:
            return self.label + "_metadata"
        else:
            return None

    def __str__(self):
        rep = "TaskDef(" + self.taskName
        if self.label:
            rep += ", label=" + self.label
        rep += ")"
        return rep
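

# Illustrative sketch, not part of the module API: given a TaskDef (e.g. one
# yielded by Pipeline.toExpandedPipeline below), the names of its per-task
# config and metadata dataset types are derived from its label as shown in
# the properties above. The function name here is hypothetical.
def _exampleTaskDefDatasetNames(taskDef: TaskDef) -> None:
    # For a task labeled "exampleTask" this prints "exampleTask_config" and,
    # if config.saveMetadata is True, "exampleTask_metadata" (otherwise None).
    print(taskDef.configDatasetName)
    print(taskDef.metadataDatasetName)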


class Pipeline:
    """A `Pipeline` is a representation of a series of tasks to run, and the
    configuration for those tasks.

    Parameters
    ----------
    description : `str`
        A description of what this pipeline does.
    """
    def __init__(self, description: str) -> None:
        pipeline_dict = {"description": description, "tasks": {}}
        self._pipelineIR = pipelineIR.PipelineIR(pipeline_dict)

    @classmethod
    def fromFile(cls, filename: str) -> Pipeline:
        """Load a pipeline defined in a pipeline YAML file.

        Parameters
        ----------
        filename : `str`
            A path that points to a pipeline defined in YAML format.

        Returns
        -------
        pipeline : `Pipeline`
            The pipeline loaded from the file.
        """
        pipeline = cls.fromIR(pipelineIR.PipelineIR.from_file(filename))
        return pipeline

    @classmethod
    def fromString(cls, pipeline_string: str) -> Pipeline:
        """Create a pipeline from a string formatted as a pipeline document.

        Parameters
        ----------
        pipeline_string : `str`
            A string that is formatted like a pipeline document.

        Returns
        -------
        pipeline : `Pipeline`
            The pipeline created from the string.
        """
        pipeline = cls.fromIR(pipelineIR.PipelineIR.from_string(pipeline_string))
        return pipeline
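
    # Illustrative usage sketch (the exact document schema is defined by
    # pipelineIR, and the task class name below is hypothetical):
    #
    #     pipeline = Pipeline.fromString(
    #         "description: A one-task example\n"
    #         "tasks:\n"
    #         "  exampleTask: mypackage.ExampleTask\n"
    #     )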

    @classmethod
    def fromIR(cls, deserialized_pipeline: pipelineIR.PipelineIR) -> Pipeline:
        """Create a pipeline from an already created `PipelineIR` object.

        Parameters
        ----------
        deserialized_pipeline : `PipelineIR`
            An already created pipeline intermediate representation object.

        Returns
        -------
        pipeline : `Pipeline`
            The pipeline wrapping the supplied intermediate representation.
        """
        pipeline = cls.__new__(cls)
        pipeline._pipelineIR = deserialized_pipeline
        return pipeline

    @classmethod
    def fromPipeline(cls, pipeline: Pipeline) -> Pipeline:
        """Create a new pipeline by copying an already existing `Pipeline`.

        Parameters
        ----------
        pipeline : `Pipeline`
            An already created pipeline to copy.

        Returns
        -------
        pipeline : `Pipeline`
            A deep copy of the supplied pipeline.
        """
        return cls.fromIR(copy.deepcopy(pipeline._pipelineIR))

    def __str__(self) -> str:
        return str(self._pipelineIR)

    def addInstrument(self, instrument: Union[Instrument, str]):
        """Add an instrument to the pipeline, or replace an instrument that is
        already defined.

        Parameters
        ----------
        instrument : `~lsst.obs.base.instrument.Instrument` or `str`
            Either an `~lsst.obs.base.instrument.Instrument` subclass or a
            string corresponding to the fully qualified name of such a class.
        """
        if isinstance(instrument, str):
            pass
        else:
            # TODO: assume that this is a subclass of Instrument, no type checking
            instrument = f"{instrument.__module__}.{instrument.__qualname__}"
        self._pipelineIR.instrument = instrument

    def addTask(self, task: Union[PipelineTask, str], label: str):
        """Add a new task to the pipeline, or replace a task that is already
        associated with the supplied label.

        Parameters
        ----------
        task : `PipelineTask` or `str`
            Either a subclass of `PipelineTask` or a string corresponding to
            a fully qualified `PipelineTask` name.
        label : `str`
            A label that is used to identify the `PipelineTask` being added.
        """
        if isinstance(task, str):
            taskName = task
        elif issubclass(task, PipelineTask):
            taskName = f"{task.__module__}.{task.__qualname__}"
        else:
            raise ValueError("task must be either a child class of PipelineTask or a string containing"
                             " a fully qualified name to one")
        if not label:
            # in some cases (with command line-generated pipeline) tasks can
            # be defined without a label, which is not acceptable; use the
            # task _DefaultName in that case
            if isinstance(task, str):
                task = doImport(task)
            label = task._DefaultName
        self._pipelineIR.tasks[label] = pipelineIR.TaskIR(label, taskName)
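
    # Illustrative sketch of building a pipeline programmatically; the
    # instrument and task class names below are hypothetical:
    #
    #     pipeline = Pipeline("An example pipeline")
    #     pipeline.addInstrument("lsst.obs.example.ExampleInstrument")
    #     pipeline.addTask("mypackage.ExampleTask", label="exampleTask")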

    def removeTask(self, label: str):
        """Remove a task from the pipeline.

        Parameters
        ----------
        label : `str`
            The label used to identify the task that is to be removed.

        Raises
        ------
        KeyError
            If no task with that label exists in the pipeline.
        """
        self._pipelineIR.tasks.pop(label)

    def addConfigOverride(self, label: str, key: str, value: object):
        """Apply a single config override.

        Parameters
        ----------
        label : `str`
            Label of the task.
        key : `str`
            Fully-qualified field name.
        value : object
            Value to be given to a field.
        """
        self._addConfigImpl(label, pipelineIR.ConfigIR(rest={key: value}))

    def addConfigFile(self, label: str, filename: str):
        """Add overrides from a specified file.

        Parameters
        ----------
        label : `str`
            The label used to identify the task associated with the config to
            modify.
        filename : `str`
            Path to the override file.
        """
        self._addConfigImpl(label, pipelineIR.ConfigIR(file=[filename]))

    def addConfigPython(self, label: str, pythonString: str):
        """Add overrides by running a snippet of python code against a config.

        Parameters
        ----------
        label : `str`
            The label used to identify the task associated with the config to
            modify.
        pythonString : `str`
            A string which is valid python code to be executed. This is done
            with config as the only local accessible value.
        """
        self._addConfigImpl(label, pipelineIR.ConfigIR(python=pythonString))
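
    # Illustrative sketch of the three override forms for a task labeled
    # "exampleTask"; the config field name and file path are hypothetical:
    #
    #     pipeline.addConfigOverride("exampleTask", "someField", 42)
    #     pipeline.addConfigFile("exampleTask", "/path/to/overrides.py")
    #     pipeline.addConfigPython("exampleTask", "config.someField = 42")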

    def _addConfigImpl(self, label: str, newConfig: pipelineIR.ConfigIR):
        if label not in self._pipelineIR.tasks:
            raise LookupError(f"There are no tasks labeled '{label}' in the pipeline")
        self._pipelineIR.tasks[label].add_or_update_config(newConfig)

    def toFile(self, filename: str):
        """Serialize this pipeline to the named file."""
        self._pipelineIR.to_file(filename)

    def toExpandedPipeline(self) -> Generator[TaskDef, None, None]:
        """Return a generator of TaskDefs which can be used to create quantum
        graphs.

        Returns
        -------
        generator : generator of `TaskDef`
            The generator returned will be the sorted iterator of tasks which
            are to be used in constructing a quantum graph.

        Raises
        ------
        NotImplementedError
            If a dataId is supplied in a config block. This is in place for
            future use.
        """
        taskDefs = []
        for label, taskIR in self._pipelineIR.tasks.items():
            taskClass = doImport(taskIR.klass)
            taskName = taskClass.__qualname__
            config = taskClass.ConfigClass()
            overrides = ConfigOverrides()
            if self._pipelineIR.instrument is not None:
                overrides.addInstrumentOverride(self._pipelineIR.instrument, taskClass._DefaultName)
            if taskIR.config is not None:
                for configIR in taskIR.config:
                    if configIR.dataId is not None:
                        raise NotImplementedError("Specializing a config on a partial data id is not yet "
                                                  "supported in Pipeline definition")
                    # only apply override if it applies to everything
                    if configIR.dataId is None:
                        if configIR.file:
                            for configFile in configIR.file:
                                overrides.addFileOverride(configFile)
                        if configIR.python is not None:
                            overrides.addPythonOverride(configIR.python)
                        for key, value in configIR.rest.items():
                            overrides.addValueOverride(key, value)
            overrides.applyTo(config)
            # This may need to be revisited
            config.validate()
            taskDefs.append(TaskDef(taskName=taskName, config=config, taskClass=taskClass, label=label))

        # Now evaluate the contracts
        if self._pipelineIR.contracts is not None:
            label_to_config = {x.label: x.config for x in taskDefs}
            for contract in self._pipelineIR.contracts:
                # execute this on its own line so it can raise a good error
                # message if there were problems with the eval
                success = eval(contract.contract, None, label_to_config)
                if not success:
                    extra_info = f": {contract.msg}" if contract.msg is not None else ""
                    raise pipelineIR.ContractError(f"Contract(s) '{contract.contract}' were not "
                                                   f"satisfied{extra_info}")

        yield from pipeTools.orderPipeline(taskDefs)
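
    # Illustrative sketch: iterating the expanded pipeline yields TaskDef
    # objects in an order suitable for constructing a quantum graph.
    #
    #     for taskDef in pipeline.toExpandedPipeline():
    #         print(taskDef.label, taskDef.taskName)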

    def __len__(self):
        return len(self._pipelineIR.tasks)

    def __eq__(self, other: "Pipeline"):
        if not isinstance(other, Pipeline):
            return False
        return self._pipelineIR == other._pipelineIR
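

# Note on contracts: each contract in the pipeline document is a Python
# expression that ``toExpandedPipeline`` evaluates with the task labels bound
# to their config objects; a failing expression raises pipelineIR.ContractError.
# An illustrative, hypothetical contract for two labels "taskA" and "taskB"
# might read:
#
#     taskA.connections.coaddName == taskB.connections.coaddName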


@dataclass(frozen=True)
class TaskDatasetTypes:
    """An immutable struct that extracts and classifies the dataset types used
    by a `PipelineTask`.
    """

    initInputs: NamedValueSet[DatasetType]
    """Dataset types that are needed as inputs in order to construct this Task.

    Task-level `initInputs` may be classified as either
    `~PipelineDatasetTypes.initInputs` or
    `~PipelineDatasetTypes.initIntermediates` at the Pipeline level.
    """

    initOutputs: NamedValueSet[DatasetType]
    """Dataset types that may be written after constructing this Task.

    Task-level `initOutputs` may be classified as either
    `~PipelineDatasetTypes.initOutputs` or
    `~PipelineDatasetTypes.initIntermediates` at the Pipeline level.
    """

    inputs: NamedValueSet[DatasetType]
    """Dataset types that are regular inputs to this Task.

    If an input dataset needed for a Quantum cannot be found in the input
    collection(s) or produced by another Task in the Pipeline, that Quantum
    (and all dependent Quanta) will not be produced.

    Task-level `inputs` may be classified as either
    `~PipelineDatasetTypes.inputs` or `~PipelineDatasetTypes.intermediates`
    at the Pipeline level.
    """

    prerequisites: NamedValueSet[DatasetType]
    """Dataset types that are prerequisite inputs to this Task.

    Prerequisite inputs must exist in the input collection(s) before the
    pipeline is run, but do not constrain the graph - if a prerequisite is
    missing for a Quantum, `PrerequisiteMissingError` is raised.

    Prerequisite inputs are not resolved until the second stage of
    QuantumGraph generation.
    """

    outputs: NamedValueSet[DatasetType]
    """Dataset types that are produced by this Task.

    Task-level `outputs` may be classified as either
    `~PipelineDatasetTypes.outputs` or `~PipelineDatasetTypes.intermediates`
    at the Pipeline level.
    """

    @classmethod
    def fromTaskDef(cls, taskDef: TaskDef, *, registry: Registry) -> TaskDatasetTypes:
        """Extract and classify the dataset types from a single `PipelineTask`.

        Parameters
        ----------
        taskDef : `TaskDef`
            An instance of a `TaskDef` class for a particular `PipelineTask`.
        registry : `Registry`
            Registry used to construct normalized `DatasetType` objects and
            retrieve those that are incomplete.

        Returns
        -------
        types : `TaskDatasetTypes`
            The dataset types used by this task.
        """
        def makeDatasetTypesSet(connectionType, freeze=True):
            """Construct a set of true `DatasetType` objects.

            Parameters
            ----------
            connectionType : `str`
                Name of the connection type to produce a set for; corresponds
                to an attribute of type `list` on the connection class instance.
            freeze : `bool`, optional
                If `True`, call `NamedValueSet.freeze` on the object returned.

            Returns
            -------
            datasetTypes : `NamedValueSet`
                A set of all datasetTypes which correspond to the input
                connection type specified in the connection class of this
                `PipelineTask`.

            Notes
            -----
            This function is a closure over the variables ``registry`` and
            ``taskDef``.
            """
            datasetTypes = NamedValueSet()
            for c in iterConnections(taskDef.connections, connectionType):
                dimensions = set(getattr(c, 'dimensions', set()))
                if "skypix" in dimensions:
                    try:
                        datasetType = registry.getDatasetType(c.name)
                    except LookupError as err:
                        raise LookupError(
                            f"DatasetType '{c.name}' referenced by "
                            f"{type(taskDef.connections).__name__} uses 'skypix' as a dimension "
                            f"placeholder, but does not already exist in the registry. "
                            f"Note that reference catalog names are now used as the dataset "
                            f"type name instead of 'ref_cat'."
                        ) from err
                    rest1 = set(registry.dimensions.extract(dimensions - set(["skypix"])).names)
                    rest2 = set(dim.name for dim in datasetType.dimensions
                                if not isinstance(dim, SkyPixDimension))
                    if rest1 != rest2:
                        raise ValueError(f"Non-skypix dimensions for dataset type {c.name} declared in "
                                         f"connections ({rest1}) are inconsistent with those in "
                                         f"registry's version of this dataset ({rest2}).")
                else:
                    datasetType = DatasetType(c.name, registry.dimensions.extract(dimensions),
                                              c.storageClass)
                try:
                    registryDatasetType = registry.getDatasetType(c.name)
                except KeyError:
                    registryDatasetType = datasetType
                if datasetType != registryDatasetType:
                    raise ValueError(f"Supplied dataset type ({datasetType}) inconsistent with "
                                     f"registry definition ({registryDatasetType})")
                datasetTypes.add(datasetType)
            if freeze:
                datasetTypes.freeze()
            return datasetTypes

        # optionally add output dataset for metadata
        outputs = makeDatasetTypesSet("outputs", freeze=False)
        if taskDef.metadataDatasetName is not None:
            # Metadata is supposed to be of the PropertyList type; its
            # dimensions correspond to a task quantum.
            dimensions = registry.dimensions.extract(taskDef.connections.dimensions)
            outputs |= {DatasetType(taskDef.metadataDatasetName, dimensions, "PropertyList")}
        outputs.freeze()

        return cls(
            initInputs=makeDatasetTypesSet("initInputs"),
            initOutputs=makeDatasetTypesSet("initOutputs"),
            inputs=makeDatasetTypesSet("inputs"),
            prerequisites=makeDatasetTypesSet("prerequisiteInputs"),
            outputs=outputs,
        )
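

# Minimal usage sketch (illustrative, not part of the module API); assumes a
# TaskDef and a Butler Registry are already available:
def _exampleTaskDatasetTypes(taskDef: TaskDef, registry: Registry) -> None:
    types = TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
    # Each attribute is a frozen NamedValueSet of DatasetType objects.
    print([dsType.name for dsType in types.inputs])
    print([dsType.name for dsType in types.outputs])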


@dataclass(frozen=True)
class PipelineDatasetTypes:
    """An immutable struct that classifies the dataset types used in a
    `Pipeline`.
    """

    initInputs: NamedValueSet[DatasetType]
    """Dataset types that are needed as inputs in order to construct the Tasks
    in this Pipeline.

    This does not include dataset types that are produced when constructing
    other Tasks in the Pipeline (these are classified as `initIntermediates`).
    """

    initOutputs: NamedValueSet[DatasetType]
    """Dataset types that may be written after constructing the Tasks in this
    Pipeline.

    This does not include dataset types that are also used as inputs when
    constructing other Tasks in the Pipeline (these are classified as
    `initIntermediates`).
    """

    initIntermediates: NamedValueSet[DatasetType]
    """Dataset types that are both used when constructing one or more Tasks
    in the Pipeline and produced as a side-effect of constructing another
    Task in the Pipeline.
    """

    inputs: NamedValueSet[DatasetType]
    """Dataset types that are regular inputs for the full pipeline.

    If an input dataset needed for a Quantum cannot be found in the input
    collection(s), that Quantum (and all dependent Quanta) will not be
    produced.
    """

    prerequisites: NamedValueSet[DatasetType]
    """Dataset types that are prerequisite inputs for the full Pipeline.

    Prerequisite inputs must exist in the input collection(s) before the
    pipeline is run, but do not constrain the graph - if a prerequisite is
    missing for a Quantum, `PrerequisiteMissingError` is raised.

    Prerequisite inputs are not resolved until the second stage of
    QuantumGraph generation.
    """

    intermediates: NamedValueSet[DatasetType]
    """Dataset types that are output by one Task in the Pipeline and consumed
    as inputs by one or more other Tasks in the Pipeline.
    """

    outputs: NamedValueSet[DatasetType]
    """Dataset types that are output by a Task in the Pipeline and not consumed
    by any other Task in the Pipeline.
    """

    byTask: Mapping[str, TaskDatasetTypes]
    """Per-Task dataset types, keyed by label in the `Pipeline`.

    This is guaranteed to be zip-iterable with the `Pipeline` itself (assuming
    neither has been modified since the dataset types were extracted, of
    course).
    """

    @classmethod
    def fromPipeline(cls, pipeline, *, registry: Registry) -> PipelineDatasetTypes:
        """Extract and classify the dataset types from all tasks in a
        `Pipeline`.

        Parameters
        ----------
        pipeline : `Pipeline` or iterable of `TaskDef`
            An ordered collection of tasks that can be run together. If a
            `Pipeline` is given, it is expanded via
            `Pipeline.toExpandedPipeline`.
        registry : `Registry`
            Registry used to construct normalized `DatasetType` objects and
            retrieve those that are incomplete.

        Returns
        -------
        types : `PipelineDatasetTypes`
            The dataset types used by this `Pipeline`.

        Raises
        ------
        ValueError
            Raised if Tasks are inconsistent about which datasets are marked
            prerequisite. This indicates that the Tasks cannot be run as part
            of the same `Pipeline`.
        """
        allInputs = NamedValueSet()
        allOutputs = NamedValueSet()
        allInitInputs = NamedValueSet()
        allInitOutputs = NamedValueSet()
        prerequisites = NamedValueSet()
        byTask = dict()
        if isinstance(pipeline, Pipeline):
            pipeline = pipeline.toExpandedPipeline()
        for taskDef in pipeline:
            thisTask = TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
            allInitInputs |= thisTask.initInputs
            allInitOutputs |= thisTask.initOutputs
            allInputs |= thisTask.inputs
            prerequisites |= thisTask.prerequisites
            allOutputs |= thisTask.outputs
            byTask[taskDef.label] = thisTask
        if not prerequisites.isdisjoint(allInputs):
            raise ValueError("{} marked as both prerequisites and regular inputs".format(
                {dt.name for dt in allInputs & prerequisites}
            ))
        if not prerequisites.isdisjoint(allOutputs):
            raise ValueError("{} marked as both prerequisites and outputs".format(
                {dt.name for dt in allOutputs & prerequisites}
            ))
        # Make sure that components which are marked as inputs get treated as
        # intermediates if there is an output which produces the composite
        # containing the component
        intermediateComponents = NamedValueSet()
        intermediateComposites = NamedValueSet()
        outputNameMapping = {dsType.name: dsType for dsType in allOutputs}
        for dsType in allInputs:
            # get the name of a possible component
            name, component = dsType.nameAndComponent()
            # if there is a component name, that means this is a component
            # DatasetType; if there is an output which produces the parent of
            # this component, treat this input as an intermediate
            if component is not None:
                if name in outputNameMapping:
                    if outputNameMapping[name].dimensions != dsType.dimensions:
                        raise ValueError(f"Component dataset type {dsType.name} has different "
                                         f"dimensions ({dsType.dimensions}) than its parent "
                                         f"({outputNameMapping[name].dimensions}).")
                    composite = DatasetType(name, dsType.dimensions, outputNameMapping[name].storageClass,
                                            universe=registry.dimensions)
                    intermediateComponents.add(dsType)
                    intermediateComposites.add(composite)
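
        # Illustrative example (dataset type names hypothetical): if one task
        # writes a composite "exampleExposure" and another reads the component
        # "exampleExposure.psf", the component input is reclassified as an
        # intermediate and its parent composite is tracked so that it is not
        # also reported as a pure output in the set algebra below.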

        def checkConsistency(a: NamedValueSet, b: NamedValueSet):
            common = a.names & b.names
            for name in common:
                if a[name] != b[name]:
                    raise ValueError(f"Conflicting definitions for dataset type: {a[name]} != {b[name]}.")

        checkConsistency(allInitInputs, allInitOutputs)
        checkConsistency(allInputs, allOutputs)
        checkConsistency(allInputs, intermediateComposites)
        checkConsistency(allOutputs, intermediateComposites)

        def frozen(s: NamedValueSet) -> NamedValueSet:
            s.freeze()
            return s

        return cls(
            initInputs=frozen(allInitInputs - allInitOutputs),
            initIntermediates=frozen(allInitInputs & allInitOutputs),
            initOutputs=frozen(allInitOutputs - allInitInputs),
            inputs=frozen(allInputs - allOutputs - intermediateComponents),
            intermediates=frozen(allInputs & allOutputs | intermediateComponents),
            outputs=frozen(allOutputs - allInputs - intermediateComposites),
            prerequisites=frozen(prerequisites),
            byTask=MappingProxyType(byTask),  # MappingProxyType -> frozen view of dict for immutability
        )
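

# Minimal usage sketch (illustrative, not part of the module API); assumes a
# Butler Registry is available and the pipeline file path is hypothetical:
def _examplePipelineDatasetTypes(registry: Registry) -> None:
    pipeline = Pipeline.fromFile("example_pipeline.yaml")
    datasetTypes = PipelineDatasetTypes.fromPipeline(pipeline, registry=registry)
    # Pure inputs must already exist in the input collection(s); intermediates
    # are both produced and consumed by tasks within the pipeline.
    print([dsType.name for dsType in datasetTypes.inputs])
    print([dsType.name for dsType in datasetTypes.intermediates])
    print([dsType.name for dsType in datasetTypes.outputs])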