Coverage for python/lsst/pipe/base/pipeline.py : 19%

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

"""Module defining Pipeline class and related methods.
"""

__all__ = ["Pipeline", "TaskDef", "TaskDatasetTypes", "PipelineDatasetTypes"]

# -------------------------------
# Imports of standard modules --
# -------------------------------
from dataclasses import dataclass
from types import MappingProxyType
from typing import Mapping, Union, Generator, TYPE_CHECKING

import copy

# -----------------------------
# Imports for other modules --
from lsst.daf.butler import DatasetType, NamedValueSet, Registry, SkyPixDimension
from lsst.utils import doImport
from .configOverrides import ConfigOverrides
from .connections import iterConnections
from .pipelineTask import PipelineTask

from . import pipelineIR
from . import pipeTools

if TYPE_CHECKING:  # Imports needed only for type annotations; may be circular.
    from lsst.obs.base.instrument import Instrument

# ----------------------------------
# Local non-exported definitions --
# ----------------------------------

# ------------------------
# Exported definitions --
# ------------------------


class TaskDef:
61 """TaskDef is a collection of information about task needed by Pipeline.
63 The information includes task name, configuration object and optional
64 task class. This class is just a collection of attributes and it exposes
65 all of them so that attributes could potentially be modified in place
66 (e.g. if configuration needs extra overrides).
68 Attributes
69 ----------
70 taskName : `str`
71 `PipelineTask` class name, currently it is not specified whether this
72 is a fully-qualified name or partial name (e.g. ``module.TaskClass``).
73 Framework should be prepared to handle all cases.
74 config : `lsst.pex.config.Config`
75 Instance of the configuration class corresponding to this task class,
76 usually with all overrides applied.
77 taskClass : `type` or ``None``
78 `PipelineTask` class object, can be ``None``. If ``None`` then
79 framework will have to locate and load class.
80 label : `str`, optional
81 Task label, usually a short string unique in a pipeline.
82 """
83 def __init__(self, taskName, config, taskClass=None, label=""):
84 self.taskName = taskName
85 self.config = config
86 self.taskClass = taskClass
87 self.label = label
88 self.connections = config.connections.ConnectionsClass(config=config)
    @property
    def configDatasetName(self):
        """Name of a dataset type for configuration of this task (`str`).
        """
        return self.label + "_config"

    @property
    def metadataDatasetName(self):
        """Name of a dataset type for metadata of this task, `None` if
        metadata is not to be saved (`str`).
        """
        if self.config.saveMetadata:
            return self.label + "_metadata"
        else:
            return None

    def __str__(self):
        rep = "TaskDef(" + self.taskName
        if self.label:
            rep += ", label=" + self.label
        rep += ")"
        return rep

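    # A minimal sketch of the derived dataset type names (illustrative only;
    # the label "isr" and the task name are hypothetical examples, not part of
    # this module):
    #
    #     taskDef = TaskDef("lsst.example.IsrTask", config, label="isr")
    #     taskDef.configDatasetName    # -> "isr_config"
    #     taskDef.metadataDatasetName  # -> "isr_metadata" if config.saveMetadata else None

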
class Pipeline:
    """A `Pipeline` is a representation of a series of tasks to run, and the
    configuration for those tasks.

    Parameters
    ----------
    description : `str`
        A description of what this pipeline does.
    """
    def __init__(self, description: str):
        pipeline_dict = {"description": description, "tasks": {}}
        self._pipelineIR = pipelineIR.PipelineIR(pipeline_dict)

    @classmethod
    def fromFile(cls, filename: str) -> Pipeline:
        """Load a pipeline defined in a pipeline yaml file.

        Parameters
        ----------
        filename : `str`
            A path that points to a pipeline defined in yaml format.

        Returns
        -------
        pipeline : `Pipeline`
        """
        pipeline = cls.fromIR(pipelineIR.PipelineIR.from_file(filename))
        return pipeline

    @classmethod
    def fromString(cls, pipeline_string: str) -> Pipeline:
        """Create a pipeline from a string formatted as a pipeline document.

        Parameters
        ----------
        pipeline_string : `str`
            A string that is formatted like a pipeline document.

        Returns
        -------
        pipeline : `Pipeline`
        """
        pipeline = cls.fromIR(pipelineIR.PipelineIR.from_string(pipeline_string))
        return pipeline

    @classmethod
    def fromIR(cls, deserialized_pipeline: pipelineIR.PipelineIR) -> Pipeline:
        """Create a pipeline from an already created `PipelineIR` object.

        Parameters
        ----------
        deserialized_pipeline : `PipelineIR`
            An already created pipeline intermediate representation object.

        Returns
        -------
        pipeline : `Pipeline`
        """
        pipeline = cls.__new__(cls)
        pipeline._pipelineIR = deserialized_pipeline
        return pipeline

    @classmethod
    def fromPipeline(cls, pipeline: Pipeline) -> Pipeline:
        """Create a new pipeline by copying an already existing `Pipeline`.

        Parameters
        ----------
        pipeline : `Pipeline`
            An already created `Pipeline` object to copy.

        Returns
        -------
        pipeline : `Pipeline`
        """
        return cls.fromIR(copy.deepcopy(pipeline._pipelineIR))

    def __str__(self) -> str:
        return str(self._pipelineIR)

    def addInstrument(self, instrument: Union[Instrument, str]):
        """Add an instrument to the pipeline, or replace an instrument that is
        already defined.

        Parameters
        ----------
        instrument : `~lsst.obs.base.instrument.Instrument` or `str`
            Either a class derived from
            `~lsst.obs.base.instrument.Instrument` or a string corresponding
            to a fully qualified name of such a class.
        """
        if isinstance(instrument, str):
            pass
        else:
            # TODO: assume that this is a subclass of Instrument, no type
            # checking
            instrument = f"{instrument.__module__}.{instrument.__qualname__}"
        self._pipelineIR.instrument = instrument

    def addTask(self, task: Union[PipelineTask, str], label: str):
        """Add a new task to the pipeline, or replace a task that is already
        associated with the supplied label.

        Parameters
        ----------
        task : `PipelineTask` or `str`
            Either a class derived from `PipelineTask` or a string
            corresponding to a fully qualified `PipelineTask` name.
        label : `str`
            A label that is used to identify the `PipelineTask` being added.
        """
        if isinstance(task, str):
            taskName = task
        elif issubclass(task, PipelineTask):
            taskName = f"{task.__module__}.{task.__qualname__}"
        else:
            raise ValueError("task must be either a child class of PipelineTask or a string containing"
                             " a fully qualified name to one")
        if not label:
            # In some cases (e.g. a command line-generated pipeline) tasks can
            # be defined without a label, which is not acceptable; use the
            # task _DefaultName in that case.
            if isinstance(task, str):
                task = doImport(task)
            label = task._DefaultName
        self._pipelineIR.tasks[label] = pipelineIR.TaskIR(label, taskName)

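    # A minimal sketch of building a pipeline programmatically (illustrative
    # only; "my.instrument.MyCamera" and "lsst.example.ExampleTask" are
    # hypothetical fully qualified names, not real classes):
    #
    #     pipeline = Pipeline("A demonstration pipeline")
    #     pipeline.addInstrument("my.instrument.MyCamera")
    #     pipeline.addTask("lsst.example.ExampleTask", label="example")
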
    def removeTask(self, label: str):
        """Remove a task from the pipeline.

        Parameters
        ----------
        label : `str`
            The label used to identify the task that is to be removed.

        Raises
        ------
        KeyError
            If no task with that label exists in the pipeline.
        """
        self._pipelineIR.tasks.pop(label)

    def addConfigOverride(self, label: str, key: str, value: object):
        """Apply a single config override.

        Parameters
        ----------
        label : `str`
            Label of the task.
        key : `str`
            Fully-qualified field name.
        value : object
            Value to be given to a field.
        """
        self._addConfigImpl(label, pipelineIR.ConfigIR(rest={key: value}))

    def addConfigFile(self, label: str, filename: str):
        """Add overrides from a specified file.

        Parameters
        ----------
        label : `str`
            The label used to identify the task associated with the config to
            modify.
        filename : `str`
            Path to the override file.
        """
        self._addConfigImpl(label, pipelineIR.ConfigIR(file=[filename]))

    def addConfigPython(self, label: str, pythonString: str):
        """Add overrides by running a snippet of python code against a config.

        Parameters
        ----------
        label : `str`
            The label used to identify the task associated with the config to
            modify.
        pythonString : `str`
            A string which is valid python code to be executed. This is done
            with ``config`` as the only local accessible value.
        """
        self._addConfigImpl(label, pipelineIR.ConfigIR(python=pythonString))

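    # A minimal sketch of the three override flavors applied to a task label
    # (illustrative only; the label "example", the field name "doSomething",
    # and the file path are hypothetical):
    #
    #     pipeline.addConfigOverride("example", "doSomething", False)
    #     pipeline.addConfigFile("example", "/path/to/overrides.py")
    #     pipeline.addConfigPython("example", "config.doSomething = False")
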
    def _addConfigImpl(self, label: str, newConfig: pipelineIR.ConfigIR):
        if label not in self._pipelineIR.tasks:
            raise LookupError(f"There are no tasks labeled '{label}' in the pipeline")
        self._pipelineIR.tasks[label].add_or_update_config(newConfig)

    def toFile(self, filename: str):
        """Write this pipeline out to the named file."""
        self._pipelineIR.to_file(filename)

    def toExpandedPipeline(self) -> Generator[TaskDef, None, None]:
        """Return a generator of TaskDefs which can be used to create quantum
        graphs.

        Returns
        -------
        generator : generator of `TaskDef`
            The generator returned will be the sorted iterator of tasks which
            are to be used in constructing a quantum graph.

        Raises
        ------
        NotImplementedError
            If a dataId is supplied in a config block. This is in place for
            future use.
        """
        taskDefs = []
        for label, taskIR in self._pipelineIR.tasks.items():
            taskClass = doImport(taskIR.klass)
            taskName = taskClass.__qualname__
            config = taskClass.ConfigClass()
            overrides = ConfigOverrides()
            if self._pipelineIR.instrument is not None:
                overrides.addInstrumentOverride(self._pipelineIR.instrument, taskClass._DefaultName)
            if taskIR.config is not None:
                for configIR in taskIR.config:
                    if configIR.dataId is not None:
                        raise NotImplementedError("Specializing a config on a partial data id is not yet "
                                                  "supported in Pipeline definition")
                    # Only apply an override if it applies to everything.
                    if configIR.dataId is None:
                        if configIR.file:
                            for configFile in configIR.file:
                                overrides.addFileOverride(configFile)
                        if configIR.python is not None:
                            overrides.addPythonOverride(configIR.python)
                        for key, value in configIR.rest.items():
                            overrides.addValueOverride(key, value)
            overrides.applyTo(config)
            # This may need to be revisited.
            config.validate()
            taskDefs.append(TaskDef(taskName=taskName, config=config, taskClass=taskClass, label=label))

        # Now evaluate the contracts.
        if self._pipelineIR.contracts is not None:
            label_to_config = {x.label: x.config for x in taskDefs}
            for contract in self._pipelineIR.contracts:
                # Execute this in its own line so that it can raise a useful
                # error message if there is a problem with the eval.
                success = eval(contract.contract, None, label_to_config)
                if not success:
                    extra_info = f": {contract.msg}" if contract.msg is not None else ""
                    raise pipelineIR.ContractError(f"Contract(s) '{contract.contract}' were not "
                                                   f"satisfied{extra_info}")

        yield from pipeTools.orderPipeline(taskDefs)

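    # A minimal sketch of expanding a pipeline into ordered TaskDefs
    # (illustrative only; assumes ``pipeline`` was built or loaded as above):
    #
    #     for taskDef in pipeline.toExpandedPipeline():
    #         print(taskDef.label, taskDef.taskName, taskDef.configDatasetName)
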
    def __len__(self):
        return len(self._pipelineIR.tasks)

    def __eq__(self, other: "Pipeline"):
        if not isinstance(other, Pipeline):
            return False
        return self._pipelineIR == other._pipelineIR


@dataclass(frozen=True)
class TaskDatasetTypes:
    """An immutable struct that extracts and classifies the dataset types used
    by a `PipelineTask`.
    """

    initInputs: NamedValueSet[DatasetType]
    """Dataset types that are needed as inputs in order to construct this Task.

    Task-level `initInputs` may be classified as either
    `~PipelineDatasetTypes.initInputs` or
    `~PipelineDatasetTypes.initIntermediates` at the Pipeline level.
    """

    initOutputs: NamedValueSet[DatasetType]
    """Dataset types that may be written after constructing this Task.

    Task-level `initOutputs` may be classified as either
    `~PipelineDatasetTypes.initOutputs` or
    `~PipelineDatasetTypes.initIntermediates` at the Pipeline level.
    """

    inputs: NamedValueSet[DatasetType]
    """Dataset types that are regular inputs to this Task.

    If an input dataset needed for a Quantum cannot be found in the input
    collection(s) or produced by another Task in the Pipeline, that Quantum
    (and all dependent Quanta) will not be produced.

    Task-level `inputs` may be classified as either
    `~PipelineDatasetTypes.inputs` or `~PipelineDatasetTypes.intermediates`
    at the Pipeline level.
    """

    prerequisites: NamedValueSet[DatasetType]
    """Dataset types that are prerequisite inputs to this Task.

    Prerequisite inputs must exist in the input collection(s) before the
    pipeline is run, but do not constrain the graph - if a prerequisite is
    missing for a Quantum, `PrerequisiteMissingError` is raised.

    Prerequisite inputs are not resolved until the second stage of
    QuantumGraph generation.
    """

    outputs: NamedValueSet[DatasetType]
    """Dataset types that are produced by this Task.

    Task-level `outputs` may be classified as either
    `~PipelineDatasetTypes.outputs` or `~PipelineDatasetTypes.intermediates`
    at the Pipeline level.
    """

    @classmethod
    def fromTaskDef(cls, taskDef: TaskDef, *, registry: Registry) -> TaskDatasetTypes:
        """Extract and classify the dataset types from a single `PipelineTask`.

        Parameters
        ----------
        taskDef : `TaskDef`
            An instance of a `TaskDef` class for a particular `PipelineTask`.
        registry : `Registry`
            Registry used to construct normalized `DatasetType` objects and
            retrieve those that are incomplete.

        Returns
        -------
        types : `TaskDatasetTypes`
            The dataset types used by this task.
        """
        def makeDatasetTypesSet(connectionType, freeze=True):
            """Construct a set of true `DatasetType` objects.

            Parameters
            ----------
            connectionType : `str`
                Name of the connection type to produce a set for; corresponds
                to an attribute of type `list` on the connection class
                instance.
            freeze : `bool`, optional
                If `True`, call `NamedValueSet.freeze` on the object returned.

            Returns
            -------
            datasetTypes : `NamedValueSet`
                A set of all dataset types which correspond to the connection
                type specified in the connection class of this `PipelineTask`.

            Notes
            -----
            This function is a closure over the variables ``registry`` and
            ``taskDef``.
            """
            datasetTypes = NamedValueSet()
            for c in iterConnections(taskDef.connections, connectionType):
                dimensions = set(getattr(c, 'dimensions', set()))
                if "skypix" in dimensions:
                    try:
                        datasetType = registry.getDatasetType(c.name)
                    except LookupError as err:
                        raise LookupError(
                            f"DatasetType '{c.name}' referenced by "
                            f"{type(taskDef.connections).__name__} uses 'skypix' as a dimension "
                            f"placeholder, but does not already exist in the registry. "
                            f"Note that reference catalog names are now used as the dataset "
                            f"type name instead of 'ref_cat'."
                        ) from err
                    rest1 = set(registry.dimensions.extract(dimensions - set(["skypix"])).names)
                    rest2 = set(dim.name for dim in datasetType.dimensions
                                if not isinstance(dim, SkyPixDimension))
                    if rest1 != rest2:
                        raise ValueError(f"Non-skypix dimensions for dataset type {c.name} declared in "
                                         f"connections ({rest1}) are inconsistent with those in "
                                         f"registry's version of this dataset ({rest2}).")
                else:
                    datasetType = DatasetType(c.name, registry.dimensions.extract(dimensions),
                                              c.storageClass)
                    try:
                        registryDatasetType = registry.getDatasetType(c.name)
                    except KeyError:
                        registryDatasetType = datasetType
                    if datasetType != registryDatasetType:
                        raise ValueError(f"Supplied dataset type ({datasetType}) inconsistent with "
                                         f"registry definition ({registryDatasetType})")
                datasetTypes.add(datasetType)
            if freeze:
                datasetTypes.freeze()
            return datasetTypes

        # Optionally add an output dataset for metadata.
        outputs = makeDatasetTypesSet("outputs", freeze=False)
        if taskDef.metadataDatasetName is not None:
            # Metadata is supposed to be of the PropertyList type; its
            # dimensions correspond to a task quantum.
            dimensions = registry.dimensions.extract(taskDef.connections.dimensions)
            outputs |= {DatasetType(taskDef.metadataDatasetName, dimensions, "PropertyList")}
        outputs.freeze()

        return cls(
            initInputs=makeDatasetTypesSet("initInputs"),
            initOutputs=makeDatasetTypesSet("initOutputs"),
            inputs=makeDatasetTypesSet("inputs"),
            prerequisites=makeDatasetTypesSet("prerequisiteInputs"),
            outputs=outputs,
        )

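    # A minimal usage sketch (illustrative only; ``taskDef`` and ``registry``
    # are assumed to be an existing TaskDef and a butler Registry):
    #
    #     types = TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
    #     print(types.inputs.names, types.outputs.names)

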
@dataclass(frozen=True)
class PipelineDatasetTypes:
    """An immutable struct that classifies the dataset types used in a
    `Pipeline`.
    """

    initInputs: NamedValueSet[DatasetType]
    """Dataset types that are needed as inputs in order to construct the Tasks
    in this Pipeline.

    This does not include dataset types that are produced when constructing
    other Tasks in the Pipeline (these are classified as `initIntermediates`).
    """

    initOutputs: NamedValueSet[DatasetType]
    """Dataset types that may be written after constructing the Tasks in this
    Pipeline.

    This does not include dataset types that are also used as inputs when
    constructing other Tasks in the Pipeline (these are classified as
    `initIntermediates`).
    """

    initIntermediates: NamedValueSet[DatasetType]
    """Dataset types that are both used when constructing one or more Tasks
    in the Pipeline and produced as a side-effect of constructing another
    Task in the Pipeline.
    """

    inputs: NamedValueSet[DatasetType]
    """Dataset types that are regular inputs for the full pipeline.

    If an input dataset needed for a Quantum cannot be found in the input
    collection(s), that Quantum (and all dependent Quanta) will not be
    produced.
    """

    prerequisites: NamedValueSet[DatasetType]
    """Dataset types that are prerequisite inputs for the full Pipeline.

    Prerequisite inputs must exist in the input collection(s) before the
    pipeline is run, but do not constrain the graph - if a prerequisite is
    missing for a Quantum, `PrerequisiteMissingError` is raised.

    Prerequisite inputs are not resolved until the second stage of
    QuantumGraph generation.
    """

    intermediates: NamedValueSet[DatasetType]
    """Dataset types that are output by one Task in the Pipeline and consumed
    as inputs by one or more other Tasks in the Pipeline.
    """

    outputs: NamedValueSet[DatasetType]
    """Dataset types that are output by a Task in the Pipeline and not consumed
    by any other Task in the Pipeline.
    """

    byTask: Mapping[str, TaskDatasetTypes]
    """Per-Task dataset types, keyed by label in the `Pipeline`.

    This is guaranteed to be zip-iterable with the `Pipeline` itself (assuming
    neither has been modified since the dataset types were extracted, of
    course).
    """

    @classmethod
    def fromPipeline(cls, pipeline, *, registry: Registry) -> PipelineDatasetTypes:
        """Extract and classify the dataset types from all tasks in a
        `Pipeline`.

        Parameters
        ----------
        pipeline : `Pipeline`
            An ordered collection of tasks that can be run together.
        registry : `Registry`
            Registry used to construct normalized `DatasetType` objects and
            retrieve those that are incomplete.

        Returns
        -------
        types : `PipelineDatasetTypes`
            The dataset types used by this `Pipeline`.

        Raises
        ------
        ValueError
            Raised if Tasks are inconsistent about which datasets are marked
            prerequisite. This indicates that the Tasks cannot be run as part
            of the same `Pipeline`.
        """
        allInputs = NamedValueSet()
        allOutputs = NamedValueSet()
        allInitInputs = NamedValueSet()
        allInitOutputs = NamedValueSet()
        prerequisites = NamedValueSet()
        byTask = dict()
        if isinstance(pipeline, Pipeline):
            pipeline = pipeline.toExpandedPipeline()
        for taskDef in pipeline:
            thisTask = TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
            allInitInputs |= thisTask.initInputs
            allInitOutputs |= thisTask.initOutputs
            allInputs |= thisTask.inputs
            prerequisites |= thisTask.prerequisites
            allOutputs |= thisTask.outputs
            byTask[taskDef.label] = thisTask
        if not prerequisites.isdisjoint(allInputs):
            raise ValueError("{} marked as both prerequisites and regular inputs".format(
                {dt.name for dt in allInputs & prerequisites}
            ))
        if not prerequisites.isdisjoint(allOutputs):
            raise ValueError("{} marked as both prerequisites and outputs".format(
                {dt.name for dt in allOutputs & prerequisites}
            ))
        # Make sure that components which are marked as inputs get treated as
        # intermediates if there is an output which produces the composite
        # containing the component.
        intermediateComponents = NamedValueSet()
        intermediateComposites = NamedValueSet()
        outputNameMapping = {dsType.name: dsType for dsType in allOutputs}
        for dsType in allInputs:
            # Get the name of a possible component.
            name, component = dsType.nameAndComponent()
            # If there is a component name, this is a component DatasetType;
            # if there is also an output which produces the parent of this
            # component, treat this input as an intermediate.
            if component is not None:
                if name in outputNameMapping:
                    if outputNameMapping[name].dimensions != dsType.dimensions:
                        raise ValueError(f"Component dataset type {dsType.name} has different "
                                         f"dimensions ({dsType.dimensions}) than its parent "
                                         f"({outputNameMapping[name].dimensions}).")
                    composite = DatasetType(name, dsType.dimensions, outputNameMapping[name].storageClass,
                                            universe=registry.dimensions)
                    intermediateComponents.add(dsType)
                    intermediateComposites.add(composite)

        def checkConsistency(a: NamedValueSet, b: NamedValueSet):
            common = a.names & b.names
            for name in common:
                if a[name] != b[name]:
                    raise ValueError(f"Conflicting definitions for dataset type: {a[name]} != {b[name]}.")

        checkConsistency(allInitInputs, allInitOutputs)
        checkConsistency(allInputs, allOutputs)
        checkConsistency(allInputs, intermediateComposites)
        checkConsistency(allOutputs, intermediateComposites)

        def frozen(s: NamedValueSet) -> NamedValueSet:
            s.freeze()
            return s

        return cls(
            initInputs=frozen(allInitInputs - allInitOutputs),
            initIntermediates=frozen(allInitInputs & allInitOutputs),
            initOutputs=frozen(allInitOutputs - allInitInputs),
            inputs=frozen(allInputs - allOutputs - intermediateComponents),
            intermediates=frozen(allInputs & allOutputs | intermediateComponents),
            outputs=frozen(allOutputs - allInputs - intermediateComposites),
            prerequisites=frozen(prerequisites),
            byTask=MappingProxyType(byTask),  # MappingProxyType -> frozen view of dict for immutability
        )

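# A small self-contained sketch (not from the original source) of the set
# algebra used above to classify pipeline-level dataset types, using plain
# Python sets with hypothetical dataset type names:
#
#     allInputs = {"raw", "calexp"}
#     allOutputs = {"calexp", "src"}
#
#     inputs = allInputs - allOutputs         # {"raw"}
#     intermediates = allInputs & allOutputs  # {"calexp"}
#     outputs = allOutputs - allInputs        # {"src"}
#
# The real method additionally folds component dataset types into
# ``intermediates``, removes their composites from ``outputs``, and keeps
# prerequisites as a separate, disjoint set.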