# python/lsst/pipe/base/pipeline.py

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

"""Module defining Pipeline class and related methods.
"""

__all__ = ["Pipeline", "TaskDef", "TaskDatasetTypes", "PipelineDatasetTypes", "LabelSpecifier"]

# -------------------------------
# Imports of standard modules --
# -------------------------------
from dataclasses import dataclass
from types import MappingProxyType
from typing import Mapping, Set, Union, Generator, TYPE_CHECKING, Optional

import copy
import os

# -----------------------------
# Imports for other modules --
from lsst.daf.butler import DatasetType, NamedValueSet, Registry, SkyPixDimension
from lsst.utils import doImport
from .configOverrides import ConfigOverrides
from .connections import iterConnections
from .pipelineTask import PipelineTask

from . import pipelineIR
from . import pipeTools

if TYPE_CHECKING:  # Imports needed only for type annotations; may be circular.
    from lsst.obs.base.instrument import Instrument

# ----------------------------------
# Local non-exported definitions --
# ----------------------------------

# ------------------------
# Exported definitions --
# ------------------------


@dataclass
class LabelSpecifier:
63 """A structure to specify a subset of labels to load
65 This structure may contain a set of labels to be used in subsetting a
66 pipeline, or a beginning and end point. Beginning or end may be empty,
67 in which case the range will be a half open interval. Unlike python
68 iteration bounds, end bounds are *INCLUDED*. Note that range based
69 selection is not well defined for pipelines that are not linear in nature,
70 and correct behavior is not guaranteed, or may vary from run to run.
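
    Examples
    --------
    A minimal illustrative sketch; the task labels used here are hypothetical
    and would need to match labels defined in an actual pipeline:

    >>> by_set = LabelSpecifier(labels={"isr", "characterizeImage"})
    >>> by_range = LabelSpecifier(begin="isr", end="calibrate")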
71 """
72 labels: Optional[Set[str]] = None
73 begin: Optional[str] = None
74 end: Optional[str] = None
76 def __post_init__(self):
77 if self.labels is not None and (self.begin or self.end):
78 raise ValueError("This struct can only be initialized with a labels set or "
79 "a begin (and/or) end specifier")


class TaskDef:
    """TaskDef is a collection of information about a task needed by a Pipeline.

    The information includes task name, configuration object and optional
    task class. This class is just a collection of attributes and it exposes
    all of them so that attributes could potentially be modified in place
    (e.g. if configuration needs extra overrides).

    Attributes
    ----------
    taskName : `str`
        `PipelineTask` class name, currently it is not specified whether this
        is a fully-qualified name or partial name (e.g. ``module.TaskClass``).
        Framework should be prepared to handle all cases.
    config : `lsst.pex.config.Config`
        Instance of the configuration class corresponding to this task class,
        usually with all overrides applied. This config will be frozen.
    taskClass : `type` or ``None``
        `PipelineTask` class object, can be ``None``. If ``None`` then
        framework will have to locate and load class.
    label : `str`, optional
        Task label, usually a short string unique in a pipeline.
    """
    def __init__(self, taskName, config, taskClass=None, label=""):
        self.taskName = taskName
        config.freeze()
        self.config = config
        self.taskClass = taskClass
        self.label = label
        self.connections = config.connections.ConnectionsClass(config=config)

    @property
    def configDatasetName(self):
        """Name of a dataset type for configuration of this task (`str`)
        """
        return self.label + "_config"

    @property
    def metadataDatasetName(self):
        """Name of a dataset type for metadata of this task, `None` if
        metadata is not to be saved (`str`)
        """
        if self.config.saveMetadata:
            return self.label + "_metadata"
        else:
            return None

    def __str__(self):
        rep = "TaskDef(" + self.taskName
        if self.label:
            rep += ", label=" + self.label
        rep += ")"
        return rep

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, TaskDef):
            return False
        # This does not consider equality of configs when determining equality
        # as config equality is a difficult thing to define. Should be updated
        # after DM-27847
        return self.taskClass == other.taskClass and self.label == other.label

    def __hash__(self):
        return hash((self.taskClass, self.label))


class Pipeline:
    """A `Pipeline` is a representation of a series of tasks to run, and the
    configuration for those tasks.

    Parameters
    ----------
    description : `str`
        A description of what this pipeline does.
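
    Examples
    --------
    An illustrative sketch of building a pipeline in memory; the task label
    and class used here are hypothetical and assume ``SomeTask`` is an
    importable `PipelineTask` subclass:

    >>> pipeline = Pipeline("A short demonstration pipeline")
    >>> pipeline.addTask("mypackage.mymodule.SomeTask", label="someTask")
    >>> pipeline.addConfigOverride("someTask", "someField", 42)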
156 """
157 def __init__(self, description: str):
158 pipeline_dict = {"description": description, "tasks": {}}
159 self._pipelineIR = pipelineIR.PipelineIR(pipeline_dict)
161 @classmethod
162 def fromFile(cls, filename: str) -> Pipeline:
163 """Load a pipeline defined in a pipeline yaml file.

        Parameters
        ----------
        filename: `str`
            A path that points to a pipeline defined in yaml format. This
            filename may also supply additional labels to be used in
            subsetting the loaded Pipeline. These labels are separated from
            the path by a colon, and may be specified as a comma separated
            list, or a range denoted as beginning..end. Beginning or end may
            be empty, in which case the range will be a half open interval.
            Unlike python iteration bounds, end bounds are *INCLUDED*. Note
            that range based selection is not well defined for pipelines that
            are not linear in nature, and correct behavior is not guaranteed,
            or may vary from run to run.

        Returns
        -------
        pipeline: `Pipeline`
            The pipeline loaded from specified location with appropriate (if
            any) subsetting

        Notes
        -----
        This method attempts to prune any contracts that contain labels which
        are not in the declared subset of labels. This pruning is done using a
        string based matching due to the nature of contracts and may prune more
        than it should.
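
        Examples
        --------
        An illustrative sketch; the file path and task labels below are
        hypothetical and assume a pipeline definition exists at that path:

        >>> full = Pipeline.fromFile("pipeline.yaml")
        >>> listed = Pipeline.fromFile("pipeline.yaml:taskA,taskB")
        >>> ranged = Pipeline.fromFile("pipeline.yaml:taskA..taskC")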
191 """
192 # Split up the filename and any labels that were supplied
193 filename, labelSpecifier = cls._parseFileSpecifier(filename)
194 pipeline: Pipeline = cls.fromIR(pipelineIR.PipelineIR.from_file(filename))
196 # If there are labels supplied, only keep those
197 if labelSpecifier is not None:
198 pipeline = pipeline.subsetFromLabels(labelSpecifier)
199 return pipeline

    def subsetFromLabels(self, labelSpecifier: LabelSpecifier) -> Pipeline:
        """Subset a pipeline to contain only labels specified in labelSpecifier

        Parameters
        ----------
        labelSpecifier : `LabelSpecifier`
            Object containing labels that describes how to subset a pipeline.

        Returns
        -------
        pipeline : `Pipeline`
            A new pipeline object that is a subset of the old pipeline

        Raises
        ------
        ValueError
            Raised if there is an issue with specified labels

        Notes
        -----
        This method attempts to prune any contracts that contain labels which
        are not in the declared subset of labels. This pruning is done using a
        string based matching due to the nature of contracts and may prune more
        than it should.
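
        Examples
        --------
        An illustrative sketch; the file path and label are hypothetical and
        must exist in the pipeline being subset:

        >>> pipeline = Pipeline.fromFile("pipeline.yaml")
        >>> subset = pipeline.subsetFromLabels(LabelSpecifier(labels={"isr"}))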
225 """
226 # Labels supplied as a set
227 if labelSpecifier.labels:
228 labelSet = labelSpecifier.labels
229 # Labels supplied as a range, first create a list of all the labels
230 # in the pipeline sorted according to task dependency. Then only
231 # keep labels that lie between the supplied bounds
232 else:
233 # Create a copy of the pipeline to use when assessing the label
234 # ordering. Use a dict for fast searching while preserving order.
235 # Remove contracts so they do not fail in the expansion step. This
236 # is needed because a user may only configure the tasks they intend
237 # to run, which may cause some contracts to fail if they will later
238 # be dropped
239 pipeline = copy.deepcopy(self)
240 pipeline._pipelineIR.contracts = []
241 labels = {taskdef.label: True for taskdef in pipeline.toExpandedPipeline()}
243 # Verify the bounds are in the labels
244 if labelSpecifier.begin is not None:
245 if labelSpecifier.begin not in labels:
246 raise ValueError(f"Beginning of range subset, {labelSpecifier.begin}, not found in "
247 "pipeline definition")
248 if labelSpecifier.end is not None:
249 if labelSpecifier.end not in labels:
250 raise ValueError(f"End of range subset, {labelSpecifier.end}, not found in pipeline "
251 "definition")
253 labelSet = set()
254 for label in labels:
255 if labelSpecifier.begin is not None:
256 if label != labelSpecifier.begin:
257 continue
258 else:
259 labelSpecifier.begin = None
260 labelSet.add(label)
261 if labelSpecifier.end is not None and label == labelSpecifier.end:
262 break
263 return Pipeline.fromIR(self._pipelineIR.subset_from_labels(labelSet))

    @staticmethod
    def _parseFileSpecifier(fileSpecifer):
        """Split apart a filename path from label subsets
268 """
269 split = fileSpecifer.split(':')
270 # There is only a filename, return just that
271 if len(split) == 1:
272 return fileSpecifer, None
273 # More than one specifier provided, bail out
274 if len(split) > 2:
275 raise ValueError("Only one : is allowed when specifying a pipeline to load")
276 else:
277 labelSubset: str
278 filename: str
279 filename, labelSubset = split[0], split[1]
280 # labels supplied as a list
281 if ',' in labelSubset:
282 if '..' in labelSubset:
283 raise ValueError("Can only specify a list of labels or a range"
284 "when loading a Pipline not both")
285 labels = set(labelSubset.split(","))
286 specifier = LabelSpecifier(labels=labels)
287 # labels supplied as a range
288 elif '..' in labelSubset:
289 # Try to destructure the labelSubset, this will fail if more
290 # than one range is specified
291 try:
292 begin, end = labelSubset.split("..")
293 except ValueError:
294 raise ValueError("Only one range can be specified when loading a pipeline")
295 specifier = LabelSpecifier(begin=begin if begin else None, end=end if end else None)
296 # Assume anything else is a single label
297 else:
298 labels = {labelSubset}
299 specifier = LabelSpecifier(labels=labels)
301 return filename, specifier

    @classmethod
    def fromString(cls, pipeline_string: str) -> Pipeline:
        """Create a pipeline from string formatted as a pipeline document.

        Parameters
        ----------
        pipeline_string : `str`
            A string that is formatted like a pipeline document

        Returns
        -------
        pipeline: `Pipeline`
        """
        pipeline = cls.fromIR(pipelineIR.PipelineIR.from_string(pipeline_string))
        return pipeline

    @classmethod
    def fromIR(cls, deserialized_pipeline: pipelineIR.PipelineIR) -> Pipeline:
        """Create a pipeline from an already created `PipelineIR` object.

        Parameters
        ----------
        deserialized_pipeline: `PipelineIR`
            An already created pipeline intermediate representation object

        Returns
        -------
        pipeline: `Pipeline`
        """
        pipeline = cls.__new__(cls)
        pipeline._pipelineIR = deserialized_pipeline
        return pipeline

    @classmethod
    def fromPipeline(cls, pipeline: Pipeline) -> Pipeline:
        """Create a new pipeline by copying an already existing `Pipeline`.

        Parameters
        ----------
        pipeline: `Pipeline`
            An already created pipeline object to copy

        Returns
        -------
        pipeline: `Pipeline`
        """
        return cls.fromIR(copy.deepcopy(pipeline._pipelineIR))

    def __str__(self) -> str:
        return str(self._pipelineIR)

    def addInstrument(self, instrument: Union[Instrument, str]):
        """Add an instrument to the pipeline, or replace an instrument that is
        already defined.

        Parameters
        ----------
        instrument : `~lsst.obs.base.Instrument` or `str`
            Either a derived class object of `lsst.obs.base.Instrument` or
            a string corresponding to a fully qualified
            `lsst.obs.base.Instrument` name.
        """
        if isinstance(instrument, str):
            pass
        else:
            # TODO: assume that this is a subclass of Instrument, no type
            # checking
            instrument = f"{instrument.__module__}.{instrument.__qualname__}"
        self._pipelineIR.instrument = instrument

    def getInstrument(self):
        """Get the instrument from the pipeline.

        Returns
        -------
        instrument : `~lsst.obs.base.Instrument`, `str`, or None
            A derived class object of `lsst.obs.base.Instrument`, a string
            corresponding to a fully qualified `lsst.obs.base.Instrument`
            name, or None if the pipeline does not have an instrument.
        """
        return self._pipelineIR.instrument

    def addTask(self, task: Union[PipelineTask, str], label: str):
        """Add a new task to the pipeline, or replace a task that is already
        associated with the supplied label.

        Parameters
        ----------
        task: `PipelineTask` or `str`
            Either a derived class object of a `PipelineTask` or a string
            corresponding to a fully qualified `PipelineTask` name.
        label: `str`
            A label that is used to identify the `PipelineTask` being added
        """
        if isinstance(task, str):
            taskName = task
        elif issubclass(task, PipelineTask):
            taskName = f"{task.__module__}.{task.__qualname__}"
        else:
            raise ValueError("task must be either a child class of PipelineTask or a string containing"
                             " a fully qualified name to one")
        if not label:
            # in some cases (with command line-generated pipeline) tasks can
            # be defined without label which is not acceptable, use task
            # _DefaultName in that case
            if isinstance(task, str):
                task = doImport(task)
            label = task._DefaultName
        self._pipelineIR.tasks[label] = pipelineIR.TaskIR(label, taskName)

    def removeTask(self, label: str):
        """Remove a task from the pipeline.

        Parameters
        ----------
        label : `str`
            The label used to identify the task that is to be removed

        Raises
        ------
        KeyError
            If no task with that label exists in the pipeline

        """
        self._pipelineIR.tasks.pop(label)

    def addConfigOverride(self, label: str, key: str, value: object):
        """Apply single config override.

        Parameters
        ----------
        label : `str`
            Label of the task.
        key: `str`
            Fully-qualified field name.
        value : object
            Value to be given to a field.
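
        Examples
        --------
        An illustrative sketch; ``pipeline`` is assumed to be an already
        constructed `Pipeline`, and the task label and config field here are
        hypothetical:

        >>> pipeline.addConfigOverride("isr", "doDark", False)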
440 """
441 self._addConfigImpl(label, pipelineIR.ConfigIR(rest={key: value}))
443 def addConfigFile(self, label: str, filename: str):
444 """Add overrides from a specified file.
446 Parameters
447 ----------
448 label : `str`
449 The label used to identify the task associated with config to
450 modify
451 filename : `str`
452 Path to the override file.
453 """
454 self._addConfigImpl(label, pipelineIR.ConfigIR(file=[filename]))

    def addConfigPython(self, label: str, pythonString: str):
        """Add overrides by running a snippet of python code against a config.

        Parameters
        ----------
        label : `str`
            The label used to identify the task associated with config to
            modify.
        pythonString: `str`
            A string which is valid python code to be executed. This is done
            with config as the only local accessible value.
        """
        self._addConfigImpl(label, pipelineIR.ConfigIR(python=pythonString))

    def _addConfigImpl(self, label: str, newConfig: pipelineIR.ConfigIR):
        if label == "parameters":
            if newConfig.rest.keys() - self._pipelineIR.parameters.mapping.keys():
                raise ValueError("Cannot override parameters that are not defined in pipeline")
            self._pipelineIR.parameters.mapping.update(newConfig.rest)
            if newConfig.file:
                raise ValueError("Setting parameters section with config file is not supported")
            if newConfig.python:
                raise ValueError("Setting parameters section using a python block is unsupported")
            return
        if label not in self._pipelineIR.tasks:
            raise LookupError(f"There are no tasks labeled '{label}' in the pipeline")
        self._pipelineIR.tasks[label].add_or_update_config(newConfig)

    def toFile(self, filename: str):
        self._pipelineIR.to_file(filename)

    def toExpandedPipeline(self) -> Generator[TaskDef, None, None]:
        """Returns a generator of TaskDefs which can be used to create quantum
        graphs.

        Returns
        -------
        generator : generator of `TaskDef`
            The generator returned will be the sorted iterator of tasks which
            are to be used in constructing a quantum graph.

        Raises
        ------
        NotImplementedError
            If a dataId is supplied in a config block. This is in place for
            future use
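
        Examples
        --------
        An illustrative sketch of iterating over the expanded pipeline;
        ``pipeline`` is assumed to have been loaded already:

        >>> for taskDef in pipeline.toExpandedPipeline():
        ...     print(taskDef.label, taskDef.taskName)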
502 """
503 taskDefs = []
504 for label, taskIR in self._pipelineIR.tasks.items():
505 taskClass = doImport(taskIR.klass)
506 taskName = taskClass.__qualname__
507 config = taskClass.ConfigClass()
508 overrides = ConfigOverrides()
509 if self._pipelineIR.instrument is not None:
510 overrides.addInstrumentOverride(self._pipelineIR.instrument, taskClass._DefaultName)
511 if taskIR.config is not None:
512 for configIR in (configIr.formatted(self._pipelineIR.parameters)
513 for configIr in taskIR.config):
514 if configIR.dataId is not None:
515 raise NotImplementedError("Specializing a config on a partial data id is not yet "
516 "supported in Pipeline definition")
517 # only apply override if it applies to everything
518 if configIR.dataId is None:
519 if configIR.file:
520 for configFile in configIR.file:
521 overrides.addFileOverride(os.path.expandvars(configFile))
522 if configIR.python is not None:
523 overrides.addPythonOverride(configIR.python)
524 for key, value in configIR.rest.items():
525 overrides.addValueOverride(key, value)
526 overrides.applyTo(config)
527 # This may need to be revisited
528 config.validate()
529 taskDefs.append(TaskDef(taskName=taskName, config=config, taskClass=taskClass, label=label))
531 # lets evaluate the contracts
532 if self._pipelineIR.contracts is not None:
533 label_to_config = {x.label: x.config for x in taskDefs}
534 for contract in self._pipelineIR.contracts:
535 # execute this in its own line so it can raise a good error
536 # message if there was problems with the eval
537 success = eval(contract.contract, None, label_to_config)
538 if not success:
539 extra_info = f": {contract.msg}" if contract.msg is not None else ""
540 raise pipelineIR.ContractError(f"Contract(s) '{contract.contract}' were not "
541 f"satisfied{extra_info}")
543 yield from pipeTools.orderPipeline(taskDefs)

    def __len__(self):
        return len(self._pipelineIR.tasks)

    def __eq__(self, other: "Pipeline"):
        if not isinstance(other, Pipeline):
            return False
        return self._pipelineIR == other._pipelineIR


@dataclass(frozen=True)
class TaskDatasetTypes:
    """An immutable struct that extracts and classifies the dataset types used
    by a `PipelineTask`
    """

    initInputs: NamedValueSet[DatasetType]
    """Dataset types that are needed as inputs in order to construct this Task.

    Task-level `initInputs` may be classified as either
    `~PipelineDatasetTypes.initInputs` or
    `~PipelineDatasetTypes.initIntermediates` at the Pipeline level.
    """

    initOutputs: NamedValueSet[DatasetType]
    """Dataset types that may be written after constructing this Task.

    Task-level `initOutputs` may be classified as either
    `~PipelineDatasetTypes.initOutputs` or
    `~PipelineDatasetTypes.initIntermediates` at the Pipeline level.
    """

    inputs: NamedValueSet[DatasetType]
    """Dataset types that are regular inputs to this Task.

    If an input dataset needed for a Quantum cannot be found in the input
    collection(s) or produced by another Task in the Pipeline, that Quantum
    (and all dependent Quanta) will not be produced.

    Task-level `inputs` may be classified as either
    `~PipelineDatasetTypes.inputs` or `~PipelineDatasetTypes.intermediates`
    at the Pipeline level.
    """

    prerequisites: NamedValueSet[DatasetType]
    """Dataset types that are prerequisite inputs to this Task.

    Prerequisite inputs must exist in the input collection(s) before the
    pipeline is run, but do not constrain the graph - if a prerequisite is
    missing for a Quantum, `PrerequisiteMissingError` is raised.

    Prerequisite inputs are not resolved until the second stage of
    QuantumGraph generation.
    """

    outputs: NamedValueSet[DatasetType]
    """Dataset types that are produced by this Task.

    Task-level `outputs` may be classified as either
    `~PipelineDatasetTypes.outputs` or `~PipelineDatasetTypes.intermediates`
    at the Pipeline level.
    """

    @classmethod
    def fromTaskDef(cls, taskDef: TaskDef, *, registry: Registry) -> TaskDatasetTypes:
        """Extract and classify the dataset types from a single `PipelineTask`.

        Parameters
        ----------
        taskDef: `TaskDef`
            An instance of a `TaskDef` class for a particular `PipelineTask`.
        registry: `Registry`
            Registry used to construct normalized `DatasetType` objects and
            retrieve those that are incomplete.

        Returns
        -------
        types: `TaskDatasetTypes`
            The dataset types used by this task.
        """
        def makeDatasetTypesSet(connectionType, freeze=True):
            """Constructs a set of true `DatasetType` objects

            Parameters
            ----------
            connectionType : `str`
                Name of the connection type to produce a set for, corresponds
                to an attribute of type `list` on the connection class instance
            freeze : `bool`, optional
                If `True`, call `NamedValueSet.freeze` on the object returned.

            Returns
            -------
            datasetTypes : `NamedValueSet`
                A set of all datasetTypes which correspond to the input
                connection type specified in the connection class of this
                `PipelineTask`

            Notes
            -----
            This function is a closure over the variables ``registry`` and
            ``taskDef``.
            """
            datasetTypes = NamedValueSet()
            for c in iterConnections(taskDef.connections, connectionType):
                dimensions = set(getattr(c, 'dimensions', set()))
                if "skypix" in dimensions:
                    try:
                        datasetType = registry.getDatasetType(c.name)
                    except LookupError as err:
                        raise LookupError(
                            f"DatasetType '{c.name}' referenced by "
                            f"{type(taskDef.connections).__name__} uses 'skypix' as a dimension "
                            f"placeholder, but does not already exist in the registry. "
                            f"Note that reference catalog names are now used as the dataset "
                            f"type name instead of 'ref_cat'."
                        ) from err
                    rest1 = set(registry.dimensions.extract(dimensions - set(["skypix"])).names)
                    rest2 = set(dim.name for dim in datasetType.dimensions
                                if not isinstance(dim, SkyPixDimension))
                    if rest1 != rest2:
                        raise ValueError(f"Non-skypix dimensions for dataset type {c.name} declared in "
                                         f"connections ({rest1}) are inconsistent with those in "
                                         f"registry's version of this dataset ({rest2}).")
                else:
                    # Component dataset types are not explicitly in the
                    # registry. This complicates consistency checks with
                    # registry and requires we work out the composite storage
                    # class.
                    registryDatasetType = None
                    try:
                        registryDatasetType = registry.getDatasetType(c.name)
                    except KeyError:
                        compositeName, componentName = DatasetType.splitDatasetTypeName(c.name)
                        parentStorageClass = DatasetType.PlaceholderParentStorageClass \
                            if componentName else None
                        datasetType = c.makeDatasetType(
                            registry.dimensions,
                            parentStorageClass=parentStorageClass
                        )
                        registryDatasetType = datasetType
                    else:
                        datasetType = c.makeDatasetType(
                            registry.dimensions,
                            parentStorageClass=registryDatasetType.parentStorageClass
                        )

                    if registryDatasetType and datasetType != registryDatasetType:
                        raise ValueError(f"Supplied dataset type ({datasetType}) inconsistent with "
                                         f"registry definition ({registryDatasetType}) "
                                         f"for {taskDef.label}.")
                datasetTypes.add(datasetType)
            if freeze:
                datasetTypes.freeze()
            return datasetTypes

        # optionally add output dataset for metadata
        outputs = makeDatasetTypesSet("outputs", freeze=False)
        if taskDef.metadataDatasetName is not None:
            # Metadata is supposed to be of the PropertySet type, its
            # dimensions correspond to a task quantum
            dimensions = registry.dimensions.extract(taskDef.connections.dimensions)
            outputs |= {DatasetType(taskDef.metadataDatasetName, dimensions, "PropertySet")}
        outputs.freeze()

        return cls(
            initInputs=makeDatasetTypesSet("initInputs"),
            initOutputs=makeDatasetTypesSet("initOutputs"),
            inputs=makeDatasetTypesSet("inputs"),
            prerequisites=makeDatasetTypesSet("prerequisiteInputs"),
            outputs=outputs,
        )


@dataclass(frozen=True)
class PipelineDatasetTypes:
    """An immutable struct that classifies the dataset types used in a
    `Pipeline`.
    """

    initInputs: NamedValueSet[DatasetType]
    """Dataset types that are needed as inputs in order to construct the Tasks
    in this Pipeline.

    This does not include dataset types that are produced when constructing
    other Tasks in the Pipeline (these are classified as `initIntermediates`).
    """

    initOutputs: NamedValueSet[DatasetType]
    """Dataset types that may be written after constructing the Tasks in this
    Pipeline.

    This does not include dataset types that are also used as inputs when
    constructing other Tasks in the Pipeline (these are classified as
    `initIntermediates`).
    """

    initIntermediates: NamedValueSet[DatasetType]
    """Dataset types that are both used when constructing one or more Tasks
    in the Pipeline and produced as a side-effect of constructing another
    Task in the Pipeline.
    """

    inputs: NamedValueSet[DatasetType]
    """Dataset types that are regular inputs for the full pipeline.

    If an input dataset needed for a Quantum cannot be found in the input
    collection(s), that Quantum (and all dependent Quanta) will not be
    produced.
    """

    prerequisites: NamedValueSet[DatasetType]
    """Dataset types that are prerequisite inputs for the full Pipeline.

    Prerequisite inputs must exist in the input collection(s) before the
    pipeline is run, but do not constrain the graph - if a prerequisite is
    missing for a Quantum, `PrerequisiteMissingError` is raised.

    Prerequisite inputs are not resolved until the second stage of
    QuantumGraph generation.
    """

    intermediates: NamedValueSet[DatasetType]
    """Dataset types that are output by one Task in the Pipeline and consumed
    as inputs by one or more other Tasks in the Pipeline.
    """

    outputs: NamedValueSet[DatasetType]
    """Dataset types that are output by a Task in the Pipeline and not consumed
    by any other Task in the Pipeline.
    """

    byTask: Mapping[str, TaskDatasetTypes]
    """Per-Task dataset types, keyed by label in the `Pipeline`.

    This is guaranteed to be zip-iterable with the `Pipeline` itself (assuming
    neither has been modified since the dataset types were extracted, of
    course).
    """

    @classmethod
    def fromPipeline(cls, pipeline, *, registry: Registry) -> PipelineDatasetTypes:
        """Extract and classify the dataset types from all tasks in a
        `Pipeline`.

        Parameters
        ----------
        pipeline: `Pipeline`
            An ordered collection of tasks that can be run together.
        registry: `Registry`
            Registry used to construct normalized `DatasetType` objects and
            retrieve those that are incomplete.

        Returns
        -------
        types: `PipelineDatasetTypes`
            The dataset types used by this `Pipeline`.

        Raises
        ------
        ValueError
            Raised if Tasks are inconsistent about which datasets are marked
            prerequisite. This indicates that the Tasks cannot be run as part
            of the same `Pipeline`.
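
        Examples
        --------
        An illustrative sketch; it assumes ``pipeline`` is an already loaded
        `Pipeline` and ``registry`` is a butler `Registry` for the repository
        the pipeline will run against:

        >>> dsTypes = PipelineDatasetTypes.fromPipeline(pipeline, registry=registry)
        >>> overallInputs = dsTypes.inputs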
808 """
809 allInputs = NamedValueSet()
810 allOutputs = NamedValueSet()
811 allInitInputs = NamedValueSet()
812 allInitOutputs = NamedValueSet()
813 prerequisites = NamedValueSet()
814 byTask = dict()
815 if isinstance(pipeline, Pipeline):
816 pipeline = pipeline.toExpandedPipeline()
817 for taskDef in pipeline:
818 thisTask = TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
819 allInitInputs |= thisTask.initInputs
820 allInitOutputs |= thisTask.initOutputs
821 allInputs |= thisTask.inputs
822 prerequisites |= thisTask.prerequisites
823 allOutputs |= thisTask.outputs
824 byTask[taskDef.label] = thisTask
825 if not prerequisites.isdisjoint(allInputs):
826 raise ValueError("{} marked as both prerequisites and regular inputs".format(
827 {dt.name for dt in allInputs & prerequisites}
828 ))
829 if not prerequisites.isdisjoint(allOutputs):
830 raise ValueError("{} marked as both prerequisites and outputs".format(
831 {dt.name for dt in allOutputs & prerequisites}
832 ))
        # Make sure that components which are marked as inputs get treated as
        # intermediates if there is an output which produces the composite
        # containing the component
        intermediateComponents = NamedValueSet()
        intermediateComposites = NamedValueSet()
        outputNameMapping = {dsType.name: dsType for dsType in allOutputs}
        for dsType in allInputs:
            # get the name of a possible component
            name, component = dsType.nameAndComponent()
            # if there is a component name, that means this is a component
            # DatasetType, if there is an output which produces the parent of
            # this component, treat this input as an intermediate
            if component is not None:
                if name in outputNameMapping:
                    if outputNameMapping[name].dimensions != dsType.dimensions:
                        raise ValueError(f"Component dataset type {dsType.name} has different "
                                         f"dimensions ({dsType.dimensions}) than its parent "
                                         f"({outputNameMapping[name].dimensions}).")
                    composite = DatasetType(name, dsType.dimensions, outputNameMapping[name].storageClass,
                                            universe=registry.dimensions)
                    intermediateComponents.add(dsType)
                    intermediateComposites.add(composite)

        def checkConsistency(a: NamedValueSet, b: NamedValueSet):
            common = a.names & b.names
            for name in common:
                if a[name] != b[name]:
                    raise ValueError(f"Conflicting definitions for dataset type: {a[name]} != {b[name]}.")

        checkConsistency(allInitInputs, allInitOutputs)
        checkConsistency(allInputs, allOutputs)
        checkConsistency(allInputs, intermediateComposites)
        checkConsistency(allOutputs, intermediateComposites)

        def frozen(s: NamedValueSet) -> NamedValueSet:
            s.freeze()
            return s

        return cls(
            initInputs=frozen(allInitInputs - allInitOutputs),
            initIntermediates=frozen(allInitInputs & allInitOutputs),
            initOutputs=frozen(allInitOutputs - allInitInputs),
            inputs=frozen(allInputs - allOutputs - intermediateComponents),
            intermediates=frozen(allInputs & allOutputs | intermediateComponents),
            outputs=frozen(allOutputs - allInputs - intermediateComposites),
            prerequisites=frozen(prerequisites),
            byTask=MappingProxyType(byTask),  # MappingProxyType -> frozen view of dict for immutability
        )