Coverage for python/lsst/pipe/base/pipeline.py : 18%

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

"""Module defining Pipeline class and related methods.
"""

__all__ = ["Pipeline", "TaskDef", "TaskDatasetTypes", "PipelineDatasetTypes", "LabelSpecifier"]

# -------------------------------
# Imports of standard modules --
# -------------------------------
from dataclasses import dataclass
from types import MappingProxyType
from typing import Mapping, Set, Union, Generator, TYPE_CHECKING, Optional

import copy
import os

# -----------------------------
# Imports for other modules --
from lsst.daf.butler import DatasetType, NamedValueSet, Registry, SkyPixDimension
from lsst.utils import doImport
from .configOverrides import ConfigOverrides
from .connections import iterConnections
from .pipelineTask import PipelineTask

from . import pipelineIR
from . import pipeTools

if TYPE_CHECKING:  # Imports needed only for type annotations; may be circular.
    from lsst.obs.base.instrument import Instrument

# ----------------------------------
# Local non-exported definitions --
# ----------------------------------

# ------------------------
# Exported definitions --
# ------------------------


@dataclass
class LabelSpecifier:
    """A structure to specify a subset of labels to load

    This structure may contain a set of labels to be used in subsetting a
    pipeline, or a beginning and end point. Beginning or end may be empty,
    in which case the range will be a half open interval. Unlike python
    iteration bounds, end bounds are *INCLUDED*. Note that range based
    selection is not well defined for pipelines that are not linear in nature,
    and correct behavior is not guaranteed, or may vary from run to run.
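
    Examples
    --------
    A minimal sketch; the task labels used here are purely illustrative:

    >>> spec = LabelSpecifier(labels={"isr", "characterizeImage"})
    >>> spec = LabelSpecifier(begin="isr", end="calibrate")

    Supplying both a label set and a begin/end bound raises `ValueError`.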
    """
    labels: Optional[Set[str]] = None
    begin: Optional[str] = None
    end: Optional[str] = None

    def __post_init__(self):
        if self.labels is not None and (self.begin or self.end):
            raise ValueError("This struct can only be initialized with a labels set or "
                             "a begin (and/or) end specifier")


class TaskDef:
    """TaskDef is a collection of information about a task needed by Pipeline.

    The information includes task name, configuration object and optional
    task class. This class is just a collection of attributes and it exposes
    all of them so that attributes could potentially be modified in place
    (e.g. if configuration needs extra overrides).

    Attributes
    ----------
    taskName : `str`
        `PipelineTask` class name, currently it is not specified whether this
        is a fully-qualified name or partial name (e.g. ``module.TaskClass``).
        Framework should be prepared to handle all cases.
    config : `lsst.pex.config.Config`
        Instance of the configuration class corresponding to this task class,
        usually with all overrides applied. This config will be frozen.
    taskClass : `type` or ``None``
        `PipelineTask` class object, can be ``None``. If ``None`` then
        framework will have to locate and load class.
    label : `str`, optional
        Task label, usually a short string unique in a pipeline.
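
    Examples
    --------
    A minimal sketch, assuming a concrete `PipelineTask` subclass named
    ``MyTask`` (hypothetical) is importable from ``mypackage``:

    >>> from mypackage import MyTask  # doctest: +SKIP
    >>> taskDef = TaskDef("mypackage.MyTask", MyTask.ConfigClass(),
    ...                   taskClass=MyTask, label="myTask")  # doctest: +SKIP
    >>> print(taskDef)  # doctest: +SKIP
    TaskDef(mypackage.MyTask, label=myTask)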
    """
    def __init__(self, taskName, config, taskClass=None, label=""):
        self.taskName = taskName
        config.freeze()
        self.config = config
        self.taskClass = taskClass
        self.label = label
        self.connections = config.connections.ConnectionsClass(config=config)

    @property
    def configDatasetName(self):
        """Name of a dataset type for configuration of this task (`str`)
        """
        return self.label + "_config"

    @property
    def metadataDatasetName(self):
        """Name of a dataset type for metadata of this task, `None` if
        metadata is not to be saved (`str`)
        """
        if self.config.saveMetadata:
            return self.label + "_metadata"
        else:
            return None

    def __str__(self):
        rep = "TaskDef(" + self.taskName
        if self.label:
            rep += ", label=" + self.label
        rep += ")"
        return rep

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, TaskDef):
            return False
        return self.config == other.config and\
            self.taskClass == other.taskClass and\
            self.label == other.label

    def __hash__(self):
        return hash((self.taskClass, self.label))


class Pipeline:
    """A `Pipeline` is a representation of a series of tasks to run, and the
    configuration for those tasks.

    Parameters
    ----------
    description : `str`
        A description of what this pipeline does.
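
    Examples
    --------
    A minimal sketch of building a pipeline programmatically; the task class
    path, label, and configuration field are hypothetical:

    >>> pipeline = Pipeline("A demonstration pipeline")
    >>> pipeline.addTask("mypackage.MyTask", label="myTask")
    >>> pipeline.addConfigOverride("myTask", "someField", 42)
    >>> pipeline.toFile("demo_pipeline.yaml")  # doctest: +SKIP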
    """
    def __init__(self, description: str):
        pipeline_dict = {"description": description, "tasks": {}}
        self._pipelineIR = pipelineIR.PipelineIR(pipeline_dict)

    @classmethod
    def fromFile(cls, filename: str) -> Pipeline:
        """Load a pipeline defined in a pipeline yaml file.

        Parameters
        ----------
        filename: `str`
            A path that points to a pipeline defined in yaml format. This
            filename may also supply additional labels to be used in
            subsetting the loaded Pipeline. These labels are separated from
            the path by a colon, and may be specified as a comma separated
            list, or a range denoted as beginning..end. Beginning or end may
            be empty, in which case the range will be a half open interval.
            Unlike python iteration bounds, end bounds are *INCLUDED*. Note
            that range based selection is not well defined for pipelines that
            are not linear in nature, and correct behavior is not guaranteed,
            or may vary from run to run.

        Returns
        -------
        pipeline: `Pipeline`
            The pipeline loaded from the specified location with appropriate
            (if any) subsetting.

        Notes
        -----
        This method attempts to prune any contracts that contain labels which
        are not in the declared subset of labels. This pruning is done using a
        string based matching due to the nature of contracts and may prune more
        than it should.
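
        Examples
        --------
        A minimal sketch; the file name and task labels are hypothetical:

        >>> pipeline = Pipeline.fromFile("myPipeline.yaml")  # doctest: +SKIP
        >>> subset = Pipeline.fromFile("myPipeline.yaml:isr..calibrate")  # doctest: +SKIP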
        """
        # Split up the filename and any labels that were supplied
        filename, labelSpecifier = cls._parseFileSpecifier(filename)
        pipeline: Pipeline = cls.fromIR(pipelineIR.PipelineIR.from_file(filename))

        # If there are labels supplied, only keep those
        if labelSpecifier is not None:
            pipeline = pipeline.subsetFromLabels(labelSpecifier)
        return pipeline

    def subsetFromLabels(self, labelSpecifier: LabelSpecifier) -> Pipeline:
        """Subset a pipeline to contain only labels specified in labelSpecifier

        Parameters
        ----------
        labelSpecifier : `LabelSpecifier`
            Object containing labels that describes how to subset a pipeline.

        Returns
        -------
        pipeline : `Pipeline`
            A new pipeline object that is a subset of the old pipeline.

        Raises
        ------
        ValueError
            Raised if there is an issue with specified labels.

        Notes
        -----
        This method attempts to prune any contracts that contain labels which
        are not in the declared subset of labels. This pruning is done using a
        string based matching due to the nature of contracts and may prune more
        than it should.
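
        Examples
        --------
        A minimal sketch; ``pipeline`` is an existing `Pipeline` and the
        labels are hypothetical:

        >>> spec = LabelSpecifier(begin="isr", end="calibrate")
        >>> subset = pipeline.subsetFromLabels(spec)  # doctest: +SKIP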
        """
        # Labels supplied as a set
        if labelSpecifier.labels:
            labelSet = labelSpecifier.labels
        # Labels supplied as a range, first create a list of all the labels
        # in the pipeline sorted according to task dependency. Then only
        # keep labels that lie between the supplied bounds
        else:
            # Create a copy of the pipeline to use when assessing the label
            # ordering. Use a dict for fast searching while preserving order.
            # Remove contracts so they do not fail in the expansion step. This
            # is needed because a user may only configure the tasks they intend
            # to run, which may cause some contracts to fail if they will later
            # be dropped
            pipeline = copy.deepcopy(self)
            pipeline._pipelineIR.contracts = []
            labels = {taskdef.label: True for taskdef in pipeline.toExpandedPipeline()}

            # Verify the bounds are in the labels
            if labelSpecifier.begin is not None:
                if labelSpecifier.begin not in labels:
                    raise ValueError(f"Beginning of range subset, {labelSpecifier.begin}, not found in "
                                     "pipeline definition")
            if labelSpecifier.end is not None:
                if labelSpecifier.end not in labels:
                    raise ValueError(f"End of range subset, {labelSpecifier.end}, not found in pipeline "
                                     "definition")

            labelSet = set()
            for label in labels:
                if labelSpecifier.begin is not None:
                    if label != labelSpecifier.begin:
                        continue
                    else:
                        labelSpecifier.begin = None
                labelSet.add(label)
                if labelSpecifier.end is not None and label == labelSpecifier.end:
                    break
        return Pipeline.fromIR(self._pipelineIR.subset_from_labels(labelSet))

    @staticmethod
    def _parseFileSpecifier(fileSpecifier):
        """Split apart a filename path from label subsets
        """
        split = fileSpecifier.split(':')
        # There is only a filename, return just that
        if len(split) == 1:
            return fileSpecifier, None
        # More than one specifier provided, bail out
        if len(split) > 2:
            raise ValueError("Only one : is allowed when specifying a pipeline to load")
        else:
            labelSubset: str
            filename: str
            filename, labelSubset = split[0], split[1]
            # labels supplied as a list
            if ',' in labelSubset:
                if '..' in labelSubset:
                    raise ValueError("Can only specify a list of labels or a range "
                                     "when loading a Pipeline, not both")
                labels = set(labelSubset.split(","))
                specifier = LabelSpecifier(labels=labels)
            # labels supplied as a range
            elif '..' in labelSubset:
                # Try to destructure the labelSubset, this will fail if more
                # than one range is specified
                try:
                    begin, end = labelSubset.split("..")
                except ValueError:
                    raise ValueError("Only one range can be specified when loading a pipeline")
                specifier = LabelSpecifier(begin=begin if begin else None, end=end if end else None)
            # Assume anything else is a single label
            else:
                labels = {labelSubset}
                specifier = LabelSpecifier(labels=labels)

        return filename, specifier

    @classmethod
    def fromString(cls, pipeline_string: str) -> Pipeline:
        """Create a pipeline from a string formatted as a pipeline document.

        Parameters
        ----------
        pipeline_string : `str`
            A string that is formatted like a pipeline document.

        Returns
        -------
        pipeline: `Pipeline`
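
        Examples
        --------
        A minimal sketch of the YAML pipeline document layout; the task class
        path and label are hypothetical:

        >>> document = '''
        ... description: A demonstration pipeline
        ... tasks:
        ...   myTask: mypackage.MyTask
        ... '''
        >>> pipeline = Pipeline.fromString(document)  # doctest: +SKIP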
        """
        pipeline = cls.fromIR(pipelineIR.PipelineIR.from_string(pipeline_string))
        return pipeline

    @classmethod
    def fromIR(cls, deserialized_pipeline: pipelineIR.PipelineIR) -> Pipeline:
        """Create a pipeline from an already created `PipelineIR` object.

        Parameters
        ----------
        deserialized_pipeline: `PipelineIR`
            An already created pipeline intermediate representation object.

        Returns
        -------
        pipeline: `Pipeline`
        """
        pipeline = cls.__new__(cls)
        pipeline._pipelineIR = deserialized_pipeline
        return pipeline

    @classmethod
    def fromPipeline(cls, pipeline: Pipeline) -> Pipeline:
        """Create a new pipeline by copying an already existing `Pipeline`.

        Parameters
        ----------
        pipeline: `Pipeline`
            An already created pipeline object to copy.

        Returns
        -------
        pipeline: `Pipeline`
        """
        return cls.fromIR(copy.deepcopy(pipeline._pipelineIR))

    def __str__(self) -> str:
        return str(self._pipelineIR)

    def addInstrument(self, instrument: Union[Instrument, str]):
        """Add an instrument to the pipeline, or replace an instrument that is
        already defined.

        Parameters
        ----------
        instrument : `~lsst.obs.base.instrument.Instrument` or `str`
            Either a derived class object of a
            `lsst.obs.base.instrument.Instrument` or a string corresponding
            to a fully qualified `lsst.obs.base.instrument.Instrument` name.
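
        Examples
        --------
        A minimal sketch; the instrument class path is illustrative:

        >>> pipeline.addInstrument("lsst.obs.subaru.HyperSuprimeCam")  # doctest: +SKIP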
        """
        if isinstance(instrument, str):
            pass
        else:
            # TODO: assume that this is a subclass of Instrument, no type
            # checking
            instrument = f"{instrument.__module__}.{instrument.__qualname__}"
        self._pipelineIR.instrument = instrument

    def getInstrument(self):
        """Get the instrument from the pipeline.

        Returns
        -------
        instrument : `~lsst.obs.base.instrument.Instrument`, `str`, or None
            A derived class object of a `lsst.obs.base.instrument.Instrument`,
            a string corresponding to a fully qualified
            `lsst.obs.base.instrument.Instrument` name, or None if the
            pipeline does not have an instrument.
        """
        return self._pipelineIR.instrument

    def addTask(self, task: Union[PipelineTask, str], label: str):
        """Add a new task to the pipeline, or replace a task that is already
        associated with the supplied label.

        Parameters
        ----------
        task: `PipelineTask` or `str`
            Either a derived class object of a `PipelineTask` or a string
            corresponding to a fully qualified `PipelineTask` name.
        label: `str`
            A label that is used to identify the `PipelineTask` being added.
        """
        if isinstance(task, str):
            taskName = task
        elif issubclass(task, PipelineTask):
            taskName = f"{task.__module__}.{task.__qualname__}"
        else:
            raise ValueError("task must be either a child class of PipelineTask or a string containing"
                             " a fully qualified name to one")
        if not label:
            # in some cases (with command line-generated pipelines) tasks can
            # be defined without a label, which is not acceptable; use the
            # task's _DefaultName in that case
            if isinstance(task, str):
                task = doImport(task)
            label = task._DefaultName
        self._pipelineIR.tasks[label] = pipelineIR.TaskIR(label, taskName)

    def removeTask(self, label: str):
        """Remove a task from the pipeline.

        Parameters
        ----------
        label : `str`
            The label used to identify the task that is to be removed

        Raises
        ------
        KeyError
            If no task with that label exists in the pipeline

        """
        self._pipelineIR.tasks.pop(label)

    def addConfigOverride(self, label: str, key: str, value: object):
        """Apply single config override.

        Parameters
        ----------
        label : `str`
            Label of the task.
        key: `str`
            Fully-qualified field name.
        value : object
            Value to be given to a field.
        """
        self._addConfigImpl(label, pipelineIR.ConfigIR(rest={key: value}))

    def addConfigFile(self, label: str, filename: str):
        """Add overrides from a specified file.

        Parameters
        ----------
        label : `str`
            The label used to identify the task associated with config to
            modify
        filename : `str`
            Path to the override file.
        """
        self._addConfigImpl(label, pipelineIR.ConfigIR(file=[filename]))

    def addConfigPython(self, label: str, pythonString: str):
        """Add overrides by running a snippet of python code against a config.

        Parameters
        ----------
        label : `str`
            The label used to identify the task associated with config to
            modify.
        pythonString: `str`
            A string which is valid python code to be executed. This is done
            with config as the only local accessible value.
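
        Examples
        --------
        A minimal sketch; the label and configuration field are hypothetical:

        >>> pipeline.addConfigPython("myTask", "config.someField = 42")  # doctest: +SKIP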
        """
        self._addConfigImpl(label, pipelineIR.ConfigIR(python=pythonString))

    def _addConfigImpl(self, label: str, newConfig: pipelineIR.ConfigIR):
        if label not in self._pipelineIR.tasks:
            raise LookupError(f"There are no tasks labeled '{label}' in the pipeline")
        self._pipelineIR.tasks[label].add_or_update_config(newConfig)

    def toFile(self, filename: str):
        self._pipelineIR.to_file(filename)

    def toExpandedPipeline(self) -> Generator[TaskDef, None, None]:
        """Return a generator of TaskDefs which can be used to create quantum
        graphs.

        Returns
        -------
        generator : generator of `TaskDef`
            The generator returned will be the sorted iterator of tasks which
            are to be used in constructing a quantum graph.

        Raises
        ------
        NotImplementedError
            If a dataId is supplied in a config block. This is in place for
            future use
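
        Examples
        --------
        A minimal sketch of iterating over the expanded pipeline:

        >>> for taskDef in pipeline.toExpandedPipeline():  # doctest: +SKIP
        ...     print(taskDef.label, taskDef.taskName)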
        """
        taskDefs = []
        for label, taskIR in self._pipelineIR.tasks.items():
            taskClass = doImport(taskIR.klass)
            taskName = taskClass.__qualname__
            config = taskClass.ConfigClass()
            overrides = ConfigOverrides()
            if self._pipelineIR.instrument is not None:
                overrides.addInstrumentOverride(self._pipelineIR.instrument, taskClass._DefaultName)
            if taskIR.config is not None:
                for configIR in taskIR.config:
                    if configIR.dataId is not None:
                        raise NotImplementedError("Specializing a config on a partial data id is not yet "
                                                  "supported in Pipeline definition")
                    # only apply override if it applies to everything
                    if configIR.dataId is None:
                        if configIR.file:
                            for configFile in configIR.file:
                                overrides.addFileOverride(os.path.expandvars(configFile))
                        if configIR.python is not None:
                            overrides.addPythonOverride(configIR.python)
                        for key, value in configIR.rest.items():
                            overrides.addValueOverride(key, value)
            overrides.applyTo(config)
            # This may need to be revisited
            config.validate()
            taskDefs.append(TaskDef(taskName=taskName, config=config, taskClass=taskClass, label=label))

        # lets evaluate the contracts
        if self._pipelineIR.contracts is not None:
            label_to_config = {x.label: x.config for x in taskDefs}
            for contract in self._pipelineIR.contracts:
                # execute this in its own line so it can raise a good error
                # message if there were problems with the eval
                success = eval(contract.contract, None, label_to_config)
                if not success:
                    extra_info = f": {contract.msg}" if contract.msg is not None else ""
                    raise pipelineIR.ContractError(f"Contract(s) '{contract.contract}' were not "
                                                   f"satisfied{extra_info}")

        yield from pipeTools.orderPipeline(taskDefs)

    def __len__(self):
        return len(self._pipelineIR.tasks)

    def __eq__(self, other: "Pipeline"):
        if not isinstance(other, Pipeline):
            return False
        return self._pipelineIR == other._pipelineIR


@dataclass(frozen=True)
class TaskDatasetTypes:
    """An immutable struct that extracts and classifies the dataset types used
    by a `PipelineTask`
    """

    initInputs: NamedValueSet[DatasetType]
    """Dataset types that are needed as inputs in order to construct this Task.

    Task-level `initInputs` may be classified as either
    `~PipelineDatasetTypes.initInputs` or
    `~PipelineDatasetTypes.initIntermediates` at the Pipeline level.
    """

    initOutputs: NamedValueSet[DatasetType]
    """Dataset types that may be written after constructing this Task.

    Task-level `initOutputs` may be classified as either
    `~PipelineDatasetTypes.initOutputs` or
    `~PipelineDatasetTypes.initIntermediates` at the Pipeline level.
    """

    inputs: NamedValueSet[DatasetType]
    """Dataset types that are regular inputs to this Task.

    If an input dataset needed for a Quantum cannot be found in the input
    collection(s) or produced by another Task in the Pipeline, that Quantum
    (and all dependent Quanta) will not be produced.

    Task-level `inputs` may be classified as either
    `~PipelineDatasetTypes.inputs` or `~PipelineDatasetTypes.intermediates`
    at the Pipeline level.
    """

    prerequisites: NamedValueSet[DatasetType]
    """Dataset types that are prerequisite inputs to this Task.

    Prerequisite inputs must exist in the input collection(s) before the
    pipeline is run, but do not constrain the graph - if a prerequisite is
    missing for a Quantum, `PrerequisiteMissingError` is raised.

    Prerequisite inputs are not resolved until the second stage of
    QuantumGraph generation.
    """

    outputs: NamedValueSet[DatasetType]
    """Dataset types that are produced by this Task.

    Task-level `outputs` may be classified as either
    `~PipelineDatasetTypes.outputs` or `~PipelineDatasetTypes.intermediates`
    at the Pipeline level.
    """

    @classmethod
    def fromTaskDef(cls, taskDef: TaskDef, *, registry: Registry) -> TaskDatasetTypes:
        """Extract and classify the dataset types from a single `PipelineTask`.

        Parameters
        ----------
        taskDef: `TaskDef`
            An instance of a `TaskDef` class for a particular `PipelineTask`.
        registry: `Registry`
            Registry used to construct normalized `DatasetType` objects and
            retrieve those that are incomplete.

        Returns
        -------
        types: `TaskDatasetTypes`
            The dataset types used by this task.
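
        Examples
        --------
        A minimal sketch; the butler repository path is hypothetical:

        >>> from lsst.daf.butler import Butler  # doctest: +SKIP
        >>> registry = Butler("/path/to/repo").registry  # doctest: +SKIP
        >>> types = TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)  # doctest: +SKIP
        >>> print([dt.name for dt in types.inputs])  # doctest: +SKIP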
        """
        def makeDatasetTypesSet(connectionType, freeze=True):
            """Constructs a set of true `DatasetType` objects

            Parameters
            ----------
            connectionType : `str`
                Name of the connection type to produce a set for, corresponds
                to an attribute of type `list` on the connection class instance
            freeze : `bool`, optional
                If `True`, call `NamedValueSet.freeze` on the object returned.

            Returns
            -------
            datasetTypes : `NamedValueSet`
                A set of all datasetTypes which correspond to the input
                connection type specified in the connection class of this
                `PipelineTask`

            Notes
            -----
            This function is a closure over the variables ``registry`` and
            ``taskDef``.
            """
            datasetTypes = NamedValueSet()
            for c in iterConnections(taskDef.connections, connectionType):
                dimensions = set(getattr(c, 'dimensions', set()))
                if "skypix" in dimensions:
                    try:
                        datasetType = registry.getDatasetType(c.name)
                    except LookupError as err:
                        raise LookupError(
                            f"DatasetType '{c.name}' referenced by "
                            f"{type(taskDef.connections).__name__} uses 'skypix' as a dimension "
                            f"placeholder, but does not already exist in the registry. "
                            f"Note that reference catalog names are now used as the dataset "
                            f"type name instead of 'ref_cat'."
                        ) from err
                    rest1 = set(registry.dimensions.extract(dimensions - set(["skypix"])).names)
                    rest2 = set(dim.name for dim in datasetType.dimensions
                                if not isinstance(dim, SkyPixDimension))
                    if rest1 != rest2:
                        raise ValueError(f"Non-skypix dimensions for dataset type {c.name} declared in "
                                         f"connections ({rest1}) are inconsistent with those in "
                                         f"registry's version of this dataset ({rest2}).")
                else:
                    # Component dataset types are not explicitly in the
                    # registry. This complicates consistency checks with
                    # registry and requires we work out the composite storage
                    # class.
                    registryDatasetType = None
                    try:
                        registryDatasetType = registry.getDatasetType(c.name)
                    except KeyError:
                        compositeName, componentName = DatasetType.splitDatasetTypeName(c.name)
                        parentStorageClass = DatasetType.PlaceholderParentStorageClass \
                            if componentName else None
                        datasetType = c.makeDatasetType(
                            registry.dimensions,
                            parentStorageClass=parentStorageClass
                        )
                        registryDatasetType = datasetType
                    else:
                        datasetType = c.makeDatasetType(
                            registry.dimensions,
                            parentStorageClass=registryDatasetType.parentStorageClass
                        )

                    if registryDatasetType and datasetType != registryDatasetType:
                        raise ValueError(f"Supplied dataset type ({datasetType}) inconsistent with "
                                         f"registry definition ({registryDatasetType}) "
                                         f"for {taskDef.label}.")
                datasetTypes.add(datasetType)
            if freeze:
                datasetTypes.freeze()
            return datasetTypes

        # optionally add output dataset for metadata
        outputs = makeDatasetTypesSet("outputs", freeze=False)
        if taskDef.metadataDatasetName is not None:
            # Metadata is supposed to be of the PropertySet type, its
            # dimensions correspond to a task quantum
            dimensions = registry.dimensions.extract(taskDef.connections.dimensions)
            outputs |= {DatasetType(taskDef.metadataDatasetName, dimensions, "PropertySet")}
        outputs.freeze()

        return cls(
            initInputs=makeDatasetTypesSet("initInputs"),
            initOutputs=makeDatasetTypesSet("initOutputs"),
            inputs=makeDatasetTypesSet("inputs"),
            prerequisites=makeDatasetTypesSet("prerequisiteInputs"),
            outputs=outputs,
        )


@dataclass(frozen=True)
class PipelineDatasetTypes:
    """An immutable struct that classifies the dataset types used in a
    `Pipeline`.
    """

    initInputs: NamedValueSet[DatasetType]
    """Dataset types that are needed as inputs in order to construct the Tasks
    in this Pipeline.

    This does not include dataset types that are produced when constructing
    other Tasks in the Pipeline (these are classified as `initIntermediates`).
    """

    initOutputs: NamedValueSet[DatasetType]
    """Dataset types that may be written after constructing the Tasks in this
    Pipeline.

    This does not include dataset types that are also used as inputs when
    constructing other Tasks in the Pipeline (these are classified as
    `initIntermediates`).
    """

    initIntermediates: NamedValueSet[DatasetType]
    """Dataset types that are both used when constructing one or more Tasks
    in the Pipeline and produced as a side-effect of constructing another
    Task in the Pipeline.
    """

    inputs: NamedValueSet[DatasetType]
    """Dataset types that are regular inputs for the full pipeline.

    If an input dataset needed for a Quantum cannot be found in the input
    collection(s), that Quantum (and all dependent Quanta) will not be
    produced.
    """

    prerequisites: NamedValueSet[DatasetType]
    """Dataset types that are prerequisite inputs for the full Pipeline.

    Prerequisite inputs must exist in the input collection(s) before the
    pipeline is run, but do not constrain the graph - if a prerequisite is
    missing for a Quantum, `PrerequisiteMissingError` is raised.

    Prerequisite inputs are not resolved until the second stage of
    QuantumGraph generation.
    """

    intermediates: NamedValueSet[DatasetType]
    """Dataset types that are output by one Task in the Pipeline and consumed
    as inputs by one or more other Tasks in the Pipeline.
    """

    outputs: NamedValueSet[DatasetType]
    """Dataset types that are output by a Task in the Pipeline and not consumed
    by any other Task in the Pipeline.
    """

    byTask: Mapping[str, TaskDatasetTypes]
    """Per-Task dataset types, keyed by label in the `Pipeline`.

    This is guaranteed to be zip-iterable with the `Pipeline` itself (assuming
    neither has been modified since the dataset types were extracted, of
    course).
    """

    @classmethod
    def fromPipeline(cls, pipeline, *, registry: Registry) -> PipelineDatasetTypes:
        """Extract and classify the dataset types from all tasks in a
        `Pipeline`.

        Parameters
        ----------
        pipeline: `Pipeline`
            An ordered collection of tasks that can be run together.
        registry: `Registry`
            Registry used to construct normalized `DatasetType` objects and
            retrieve those that are incomplete.

        Returns
        -------
        types: `PipelineDatasetTypes`
            The dataset types used by this `Pipeline`.

        Raises
        ------
        ValueError
            Raised if Tasks are inconsistent about which datasets are marked
            prerequisite. This indicates that the Tasks cannot be run as part
            of the same `Pipeline`.
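
        Examples
        --------
        A minimal sketch; the butler repository path and pipeline file are
        hypothetical:

        >>> from lsst.daf.butler import Butler  # doctest: +SKIP
        >>> registry = Butler("/path/to/repo").registry  # doctest: +SKIP
        >>> pipeline = Pipeline.fromFile("myPipeline.yaml")  # doctest: +SKIP
        >>> types = PipelineDatasetTypes.fromPipeline(pipeline, registry=registry)  # doctest: +SKIP
        >>> print([dt.name for dt in types.outputs])  # doctest: +SKIP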
        """
        allInputs = NamedValueSet()
        allOutputs = NamedValueSet()
        allInitInputs = NamedValueSet()
        allInitOutputs = NamedValueSet()
        prerequisites = NamedValueSet()
        byTask = dict()
        if isinstance(pipeline, Pipeline):
            pipeline = pipeline.toExpandedPipeline()
        for taskDef in pipeline:
            thisTask = TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
            allInitInputs |= thisTask.initInputs
            allInitOutputs |= thisTask.initOutputs
            allInputs |= thisTask.inputs
            prerequisites |= thisTask.prerequisites
            allOutputs |= thisTask.outputs
            byTask[taskDef.label] = thisTask
        if not prerequisites.isdisjoint(allInputs):
            raise ValueError("{} marked as both prerequisites and regular inputs".format(
                {dt.name for dt in allInputs & prerequisites}
            ))
        if not prerequisites.isdisjoint(allOutputs):
            raise ValueError("{} marked as both prerequisites and outputs".format(
                {dt.name for dt in allOutputs & prerequisites}
            ))
        # Make sure that components which are marked as inputs get treated as
        # intermediates if there is an output which produces the composite
        # containing the component
        intermediateComponents = NamedValueSet()
        intermediateComposites = NamedValueSet()
        outputNameMapping = {dsType.name: dsType for dsType in allOutputs}
        for dsType in allInputs:
            # get the name of a possible component
            name, component = dsType.nameAndComponent()
            # if there is a component name, that means this is a component
            # DatasetType, if there is an output which produces the parent of
            # this component, treat this input as an intermediate
            if component is not None:
                if name in outputNameMapping:
                    if outputNameMapping[name].dimensions != dsType.dimensions:
                        raise ValueError(f"Component dataset type {dsType.name} has different "
                                         f"dimensions ({dsType.dimensions}) than its parent "
                                         f"({outputNameMapping[name].dimensions}).")
                    composite = DatasetType(name, dsType.dimensions, outputNameMapping[name].storageClass,
                                            universe=registry.dimensions)
                    intermediateComponents.add(dsType)
                    intermediateComposites.add(composite)

        def checkConsistency(a: NamedValueSet, b: NamedValueSet):
            common = a.names & b.names
            for name in common:
                if a[name] != b[name]:
                    raise ValueError(f"Conflicting definitions for dataset type: {a[name]} != {b[name]}.")

        checkConsistency(allInitInputs, allInitOutputs)
        checkConsistency(allInputs, allOutputs)
        checkConsistency(allInputs, intermediateComposites)
        checkConsistency(allOutputs, intermediateComposites)

        def frozen(s: NamedValueSet) -> NamedValueSet:
            s.freeze()
            return s

        return cls(
            initInputs=frozen(allInitInputs - allInitOutputs),
            initIntermediates=frozen(allInitInputs & allInitOutputs),
            initOutputs=frozen(allInitOutputs - allInitInputs),
            inputs=frozen(allInputs - allOutputs - intermediateComponents),
            intermediates=frozen(allInputs & allOutputs | intermediateComponents),
            outputs=frozen(allOutputs - allInputs - intermediateComposites),
            prerequisites=frozen(prerequisites),
            byTask=MappingProxyType(byTask),  # MappingProxyType -> frozen view of dict for immutability
        )