# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

"""Module defining Pipeline class and related methods.
"""

__all__ = ["Pipeline", "TaskDef", "TaskDatasetTypes", "PipelineDatasetTypes", "LabelSpecifier"]

# -------------------------------
# Imports of standard modules --
# -------------------------------
from dataclasses import dataclass
from types import MappingProxyType
from typing import Mapping, Set, Union, Generator, TYPE_CHECKING, Optional

import copy
import os
import re

# -----------------------------
# Imports for other modules --
from lsst.daf.butler import DatasetType, NamedValueSet, Registry, SkyPixDimension
from lsst.utils import doImport
from .configOverrides import ConfigOverrides
from .connections import iterConnections
from .pipelineTask import PipelineTask

from . import pipelineIR
from . import pipeTools

if TYPE_CHECKING:  # Imports needed only for type annotations; may be circular.
    from lsst.obs.base.instrument import Instrument

# ----------------------------------
# Local non-exported definitions --
# ----------------------------------

# ------------------------
# Exported definitions --
# ------------------------


@dataclass
class LabelSpecifier:
    """A structure to specify a subset of labels to load

    This structure may contain a set of labels to be used in subsetting a
    pipeline, or a beginning and end point.  Beginning or end may be empty,
    in which case the range will be a half open interval.  Unlike python
    iteration bounds, end bounds are *INCLUDED*.  Note that range based
    selection is not well defined for pipelines that are not linear in
    nature, and correct behavior is not guaranteed, or may vary from run to
    run.
    """
    labels: Optional[Set[str]] = None
    begin: Optional[str] = None
    end: Optional[str] = None

    def __post_init__(self):
        if self.labels is not None and (self.begin or self.end):
            raise ValueError("This struct can only be initialized with a labels set or "
                             "a begin (and/or) end specifier")
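
# A minimal usage sketch of LabelSpecifier (the task labels below are
# hypothetical, not defined by this module):
#
#     spec = LabelSpecifier(labels={"isr", "characterizeImage"})
#     spec = LabelSpecifier(begin="isr", end="calibrate")    # inclusive range
#     LabelSpecifier(labels={"isr"}, begin="isr")            # raises ValueError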


class TaskDef:
    """TaskDef is a collection of information about a task needed by a
    Pipeline.

    The information includes task name, configuration object and optional
    task class. This class is just a collection of attributes and it exposes
    all of them so that attributes could potentially be modified in place
    (e.g. if configuration needs extra overrides).

    Attributes
    ----------
    taskName : `str`
        `PipelineTask` class name, currently it is not specified whether this
        is a fully-qualified name or partial name (e.g. ``module.TaskClass``).
        Framework should be prepared to handle all cases.
    config : `lsst.pex.config.Config`
        Instance of the configuration class corresponding to this task class,
        usually with all overrides applied. This config will be frozen.
    taskClass : `type` or ``None``
        `PipelineTask` class object, can be ``None``. If ``None`` then
        framework will have to locate and load class.
    label : `str`, optional
        Task label, usually a short string unique in a pipeline.
    """
    def __init__(self, taskName, config, taskClass=None, label=""):
        self.taskName = taskName
        config.freeze()
        self.config = config
        self.taskClass = taskClass
        self.label = label
        self.connections = config.connections.ConnectionsClass(config=config)

    @property
    def configDatasetName(self):
        """Name of a dataset type for configuration of this task (`str`)
        """
        return self.label + "_config"

    @property
    def metadataDatasetName(self):
        """Name of a dataset type for metadata of this task, `None` if
        metadata is not to be saved (`str`)
        """
        if self.config.saveMetadata:
            return self.label + "_metadata"
        else:
            return None

    def __str__(self):
        rep = "TaskDef(" + self.taskName
        if self.label:
            rep += ", label=" + self.label
        rep += ")"
        return rep

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, TaskDef):
            return False
        return self.config == other.config and\
            self.taskClass == other.taskClass and\
            self.label == other.label

    def __hash__(self):
        return hash((self.taskClass, self.label))


class Pipeline:
    """A `Pipeline` is a representation of a series of tasks to run, and the
    configuration for those tasks.

    Parameters
    ----------
    description : `str`
        A description of what this pipeline does.
    """
    def __init__(self, description: str):
        pipeline_dict = {"description": description, "tasks": {}}
        self._pipelineIR = pipelineIR.PipelineIR(pipeline_dict)

    @classmethod
    def fromFile(cls, filename: str) -> Pipeline:
        """Load a pipeline defined in a pipeline yaml file.

        Parameters
        ----------
        filename : `str`
            A path that points to a pipeline defined in yaml format. This
            filename may also supply additional labels to be used in
            subsetting the loaded Pipeline. These labels are separated from
            the path by a colon, and may be specified as a comma separated
            list, or a range denoted as beginning..end. Beginning or end may
            be empty, in which case the range will be a half open interval.
            Unlike python iteration bounds, end bounds are *INCLUDED*. Note
            that range based selection is not well defined for pipelines that
            are not linear in nature, and correct behavior is not guaranteed,
            or may vary from run to run.

        Returns
        -------
        pipeline : `Pipeline`
            The pipeline loaded from the specified location, with appropriate
            (if any) subsetting.

        Notes
        -----
        This method attempts to prune any contracts that contain labels which
        are not in the declared subset of labels. This pruning is done using
        string based matching due to the nature of contracts and may prune
        more than it should.
        """
        # Split up the filename and any labels that were supplied
        filename, labelSpecifier = cls._parseFileSpecifier(filename)
        pipeline: Pipeline = cls.fromIR(pipelineIR.PipelineIR.from_file(filename))

        # If there are labels supplied, only keep those
        if labelSpecifier is not None:
            pipeline = pipeline.subsetFromLabels(labelSpecifier)
        return pipeline
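
    # A minimal sketch of the file-specifier syntax documented above
    # (hypothetical path and label names):
    #
    #     Pipeline.fromFile("pipeline.yaml")                  # whole pipeline
    #     Pipeline.fromFile("pipeline.yaml:isr,calibrate")    # explicit label set
    #     Pipeline.fromFile("pipeline.yaml:isr..calibrate")   # inclusive label range
    #     Pipeline.fromFile("pipeline.yaml:..calibrate")      # half open range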

    def subsetFromLabels(self, labelSpecifier: LabelSpecifier) -> Pipeline:
        """Subset a pipeline to contain only labels specified in
        labelSpecifier.

        Parameters
        ----------
        labelSpecifier : `LabelSpecifier`
            Object containing labels that describes how to subset a pipeline.

        Returns
        -------
        pipeline : `Pipeline`
            A new pipeline object that is a subset of the old pipeline.

        Raises
        ------
        ValueError
            Raised if there is an issue with the specified labels.

        Notes
        -----
        This method attempts to prune any contracts that contain labels which
        are not in the declared subset of labels. This pruning is done using
        string based matching due to the nature of contracts and may prune
        more than it should.
        """
        pipeline = copy.deepcopy(self)

        def remove_contracts(label: str):
            """Remove any contracts that contain the given label

            String comparison used in this way is not the most elegant and may
            have issues, but it is the only feasible way when users can
            specify contracts with generic strings.
            """
            new_contracts = []
            for contract in pipeline._pipelineIR.contracts:
                # match a label that is not preceded by an ASCII identifier,
                # or is the start of a line, and is followed by a dot
                if re.match(f".*([^A-Za-z0-9_]|^){label}[.]", contract.contract):
                    continue
                new_contracts.append(contract)
            pipeline._pipelineIR.contracts = new_contracts

        # Labels supplied as a set; explicitly remove any that are not in
        # that set
        if labelSpecifier.labels:
            # verify all the labels are in the pipeline
            if not labelSpecifier.labels.issubset(pipeline._pipelineIR.tasks.keys()):
                difference = labelSpecifier.labels.difference(pipeline._pipelineIR.tasks.keys())
                raise ValueError("Not all supplied labels are in the pipeline definition, extra labels: "
                                 f"{difference}")
            # copy needed so as to not modify while iterating
            pipeline_labels = list(pipeline._pipelineIR.tasks.keys())
            for label in pipeline_labels:
                if label not in labelSpecifier.labels:
                    pipeline.removeTask(label)
                    remove_contracts(label)
        # Labels supplied as a range: first create a list of all the labels
        # in the pipeline sorted according to task dependency, then only
        # keep labels that lie between the supplied bounds
        else:
            # Use a dict for fast searching while preserving order.
            # Save contracts and remove them so they do not fail in the
            # expansion step; they will be restored afterwards. This is needed
            # because a user may only configure the tasks they intend to run,
            # which may cause some contracts to fail if they will later be
            # dropped
            contractSave = pipeline._pipelineIR.contracts
            pipeline._pipelineIR.contracts = []
            labels = {taskdef.label: True for taskdef in pipeline.toExpandedPipeline()}
            pipeline._pipelineIR.contracts = contractSave

            # Verify the bounds are in the labels
            if labelSpecifier.begin is not None:
                if labelSpecifier.begin not in labels:
                    raise ValueError(f"Beginning of range subset, {labelSpecifier.begin}, not found in "
                                     "pipeline definition")
            if labelSpecifier.end is not None:
                if labelSpecifier.end not in labels:
                    raise ValueError(f"End of range subset, {labelSpecifier.end}, not found in pipeline "
                                     "definition")

            closed = False
            for label in labels:
                # if there is a begin label, delete all labels until it is
                # reached
                if labelSpecifier.begin:
                    if label != labelSpecifier.begin:
                        pipeline.removeTask(label)
                        remove_contracts(label)
                        continue
                    else:
                        labelSpecifier.begin = None
                # if there is an end specifier, keep all tasks until the
                # end specifier is reached, afterwards delete the labels
                if labelSpecifier.end:
                    if label != labelSpecifier.end:
                        if closed:
                            pipeline.removeTask(label)
                            remove_contracts(label)
                            continue
                    else:
                        closed = True
        return pipeline

    @staticmethod
    def _parseFileSpecifier(fileSpecifier):
        """Split apart a filename path from label subsets
        """
        split = fileSpecifier.split(':')
        # There is only a filename, return just that
        if len(split) == 1:
            return fileSpecifier, None
        # More than one specifier provided, bail out
        if len(split) > 2:
            raise ValueError("Only one : is allowed when specifying a pipeline to load")
        else:
            labelSubset: str
            filename: str
            filename, labelSubset = split[0], split[1]
            # labels supplied as a list
            if ',' in labelSubset:
                if '..' in labelSubset:
                    raise ValueError("Can only specify a list of labels or a range "
                                     "when loading a Pipeline, not both")
                labels = set(labelSubset.split(","))
                specifier = LabelSpecifier(labels=labels)
            # labels supplied as a range
            elif '..' in labelSubset:
                # Try to destructure the labelSubset; this will fail if more
                # than one range is specified
                try:
                    begin, end = labelSubset.split("..")
                except ValueError:
                    raise ValueError("Only one range can be specified when loading a pipeline")
                specifier = LabelSpecifier(begin=begin if begin else None, end=end if end else None)
            # Assume anything else is a single label
            else:
                labels = {labelSubset}
                specifier = LabelSpecifier(labels=labels)

            return filename, specifier

    @classmethod
    def fromString(cls, pipeline_string: str) -> Pipeline:
        """Create a pipeline from a string formatted as a pipeline document.

        Parameters
        ----------
        pipeline_string : `str`
            A string that is formatted like a pipeline document.

        Returns
        -------
        pipeline : `Pipeline`
        """
        pipeline = cls.fromIR(pipelineIR.PipelineIR.from_string(pipeline_string))
        return pipeline

    @classmethod
    def fromIR(cls, deserialized_pipeline: pipelineIR.PipelineIR) -> Pipeline:
        """Create a pipeline from an already created `PipelineIR` object.

        Parameters
        ----------
        deserialized_pipeline : `PipelineIR`
            An already created pipeline intermediate representation object.

        Returns
        -------
        pipeline : `Pipeline`
        """
        pipeline = cls.__new__(cls)
        pipeline._pipelineIR = deserialized_pipeline
        return pipeline

    @classmethod
    def fromPipeline(cls, pipeline: Pipeline) -> Pipeline:
        """Create a new pipeline by copying an already existing `Pipeline`.

        Parameters
        ----------
        pipeline : `Pipeline`
            An already created pipeline to copy.

        Returns
        -------
        pipeline : `Pipeline`
        """
        return cls.fromIR(copy.deepcopy(pipeline._pipelineIR))

    def __str__(self) -> str:
        return str(self._pipelineIR)

    def addInstrument(self, instrument: Union[Instrument, str]):
        """Add an instrument to the pipeline, or replace an instrument that is
        already defined.

        Parameters
        ----------
        instrument : `~lsst.obs.base.instrument.Instrument` or `str`
            Either an `Instrument` subclass object or a string corresponding
            to a fully qualified `Instrument` class name.
        """
        if isinstance(instrument, str):
            pass
        else:
            # TODO: assume that this is a subclass of Instrument, no type
            # checking
            instrument = f"{instrument.__module__}.{instrument.__qualname__}"
        self._pipelineIR.instrument = instrument

    def getInstrument(self):
        """Get the instrument from the pipeline.

        Returns
        -------
        instrument : `~lsst.obs.base.instrument.Instrument`, `str`, or `None`
            An `Instrument` subclass object, a string corresponding to a
            fully qualified `Instrument` class name, or `None` if the
            pipeline does not have an instrument.
        """
        return self._pipelineIR.instrument

    def addTask(self, task: Union[PipelineTask, str], label: str):
        """Add a new task to the pipeline, or replace a task that is already
        associated with the supplied label.

        Parameters
        ----------
        task : `PipelineTask` or `str`
            Either a derived class object of a `PipelineTask` or a string
            corresponding to a fully qualified `PipelineTask` name.
        label : `str`
            A label that is used to identify the `PipelineTask` being added.
        """
        if isinstance(task, str):
            taskName = task
        elif issubclass(task, PipelineTask):
            taskName = f"{task.__module__}.{task.__qualname__}"
        else:
            raise ValueError("task must be either a child class of PipelineTask or a string containing"
                             " a fully qualified name to one")
        if not label:
            # in some cases (with command line-generated pipeline) tasks can
            # be defined without label which is not acceptable, use task
            # _DefaultName in that case
            if isinstance(task, str):
                task = doImport(task)
            label = task._DefaultName
        self._pipelineIR.tasks[label] = pipelineIR.TaskIR(label, taskName)
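
    # A minimal sketch of building a pipeline programmatically (the task and
    # instrument names below are hypothetical, not provided by this module):
    #
    #     pipeline = Pipeline("A short example pipeline")
    #     pipeline.addInstrument("lsst.obs.example.ExampleInstrument")
    #     pipeline.addTask("lsst.example.tasks.ExampleTask", "example")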

    def removeTask(self, label: str):
        """Remove a task from the pipeline.

        Parameters
        ----------
        label : `str`
            The label used to identify the task that is to be removed.

        Raises
        ------
        KeyError
            If no task with that label exists in the pipeline.
        """
        self._pipelineIR.tasks.pop(label)

    def addConfigOverride(self, label: str, key: str, value: object):
        """Apply a single config override.

        Parameters
        ----------
        label : `str`
            Label of the task.
        key : `str`
            Fully-qualified field name.
        value : object
            Value to be given to a field.
        """
        self._addConfigImpl(label, pipelineIR.ConfigIR(rest={key: value}))

    def addConfigFile(self, label: str, filename: str):
        """Add overrides from a specified file.

        Parameters
        ----------
        label : `str`
            The label used to identify the task associated with the config to
            modify.
        filename : `str`
            Path to the override file.
        """
        self._addConfigImpl(label, pipelineIR.ConfigIR(file=[filename]))

    def addConfigPython(self, label: str, pythonString: str):
        """Add overrides by running a snippet of python code against a config.

        Parameters
        ----------
        label : `str`
            The label used to identify the task associated with the config to
            modify.
        pythonString : `str`
            A string which is valid python code to be executed. This is done
            with config as the only local accessible value.
        """
        self._addConfigImpl(label, pipelineIR.ConfigIR(python=pythonString))
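
    # A minimal sketch of the three override mechanisms above (the label,
    # field name, file path and python snippet are hypothetical):
    #
    #     pipeline.addConfigOverride("example", "doWrite", False)
    #     pipeline.addConfigFile("example", "overrides/example.py")
    #     pipeline.addConfigPython("example", "config.doWrite = False")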

    def _addConfigImpl(self, label: str, newConfig: pipelineIR.ConfigIR):
        if label not in self._pipelineIR.tasks:
            raise LookupError(f"There are no tasks labeled '{label}' in the pipeline")
        self._pipelineIR.tasks[label].add_or_update_config(newConfig)

    def toFile(self, filename: str):
        """Write the pipeline to a file in yaml format.
        """
        self._pipelineIR.to_file(filename)

    def toExpandedPipeline(self) -> Generator[TaskDef, None, None]:
        """Return a generator of TaskDefs which can be used to create quantum
        graphs.

        Returns
        -------
        generator : generator of `TaskDef`
            The generator returned will be the sorted iterator of tasks which
            are to be used in constructing a quantum graph.

        Raises
        ------
        NotImplementedError
            If a dataId is supplied in a config block. This is in place for
            future use.
        """
        taskDefs = []
        for label, taskIR in self._pipelineIR.tasks.items():
            taskClass = doImport(taskIR.klass)
            taskName = taskClass.__qualname__
            config = taskClass.ConfigClass()
            overrides = ConfigOverrides()
            if self._pipelineIR.instrument is not None:
                overrides.addInstrumentOverride(self._pipelineIR.instrument, taskClass._DefaultName)
            if taskIR.config is not None:
                for configIR in taskIR.config:
                    if configIR.dataId is not None:
                        raise NotImplementedError("Specializing a config on a partial data id is not yet "
                                                  "supported in Pipeline definition")
                    # only apply override if it applies to everything
                    if configIR.dataId is None:
                        if configIR.file:
                            for configFile in configIR.file:
                                overrides.addFileOverride(os.path.expandvars(configFile))
                        if configIR.python is not None:
                            overrides.addPythonOverride(configIR.python)
                        for key, value in configIR.rest.items():
                            overrides.addValueOverride(key, value)
            overrides.applyTo(config)
            # This may need to be revisited
            config.validate()
            taskDefs.append(TaskDef(taskName=taskName, config=config, taskClass=taskClass, label=label))

        # lets evaluate the contracts
        if self._pipelineIR.contracts is not None:
            label_to_config = {x.label: x.config for x in taskDefs}
            for contract in self._pipelineIR.contracts:
                # execute this on its own line so it can raise a good error
                # message if there were problems with the eval
                success = eval(contract.contract, None, label_to_config)
                if not success:
                    extra_info = f": {contract.msg}" if contract.msg is not None else ""
                    raise pipelineIR.ContractError(f"Contract(s) '{contract.contract}' were not "
                                                   f"satisfied{extra_info}")

        yield from pipeTools.orderPipeline(taskDefs)
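
    # A minimal sketch of expanding a pipeline into ordered TaskDefs (assumes
    # the pipeline's task classes are importable in the current environment):
    #
    #     for taskDef in pipeline.toExpandedPipeline():
    #         print(taskDef.label, taskDef.taskName)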

    def __len__(self):
        return len(self._pipelineIR.tasks)

    def __eq__(self, other: "Pipeline"):
        if not isinstance(other, Pipeline):
            return False
        return self._pipelineIR == other._pipelineIR


@dataclass(frozen=True)
class TaskDatasetTypes:
    """An immutable struct that extracts and classifies the dataset types used
    by a `PipelineTask`
    """

    initInputs: NamedValueSet[DatasetType]
    """Dataset types that are needed as inputs in order to construct this Task.

    Task-level `initInputs` may be classified as either
    `~PipelineDatasetTypes.initInputs` or
    `~PipelineDatasetTypes.initIntermediates` at the Pipeline level.
    """

    initOutputs: NamedValueSet[DatasetType]
    """Dataset types that may be written after constructing this Task.

    Task-level `initOutputs` may be classified as either
    `~PipelineDatasetTypes.initOutputs` or
    `~PipelineDatasetTypes.initIntermediates` at the Pipeline level.
    """

    inputs: NamedValueSet[DatasetType]
    """Dataset types that are regular inputs to this Task.

    If an input dataset needed for a Quantum cannot be found in the input
    collection(s) or produced by another Task in the Pipeline, that Quantum
    (and all dependent Quanta) will not be produced.

    Task-level `inputs` may be classified as either
    `~PipelineDatasetTypes.inputs` or `~PipelineDatasetTypes.intermediates`
    at the Pipeline level.
    """

    prerequisites: NamedValueSet[DatasetType]
    """Dataset types that are prerequisite inputs to this Task.

    Prerequisite inputs must exist in the input collection(s) before the
    pipeline is run, but do not constrain the graph - if a prerequisite is
    missing for a Quantum, `PrerequisiteMissingError` is raised.

    Prerequisite inputs are not resolved until the second stage of
    QuantumGraph generation.
    """

    outputs: NamedValueSet[DatasetType]
    """Dataset types that are produced by this Task.

    Task-level `outputs` may be classified as either
    `~PipelineDatasetTypes.outputs` or `~PipelineDatasetTypes.intermediates`
    at the Pipeline level.
    """

    @classmethod
    def fromTaskDef(cls, taskDef: TaskDef, *, registry: Registry) -> TaskDatasetTypes:
        """Extract and classify the dataset types from a single `PipelineTask`.

        Parameters
        ----------
        taskDef : `TaskDef`
            An instance of a `TaskDef` class for a particular `PipelineTask`.
        registry : `Registry`
            Registry used to construct normalized `DatasetType` objects and
            retrieve those that are incomplete.

        Returns
        -------
        types : `TaskDatasetTypes`
            The dataset types used by this task.
        """
        def makeDatasetTypesSet(connectionType, freeze=True):
            """Construct a set of true `DatasetType` objects

            Parameters
            ----------
            connectionType : `str`
                Name of the connection type to produce a set for; corresponds
                to an attribute of type `list` on the connection class
                instance.
            freeze : `bool`, optional
                If `True`, call `NamedValueSet.freeze` on the object returned.

            Returns
            -------
            datasetTypes : `NamedValueSet`
                A set of all datasetTypes which correspond to the input
                connection type specified in the connection class of this
                `PipelineTask`

            Notes
            -----
            This function is a closure over the variables ``registry`` and
            ``taskDef``.
            """
            datasetTypes = NamedValueSet()
            for c in iterConnections(taskDef.connections, connectionType):
                dimensions = set(getattr(c, 'dimensions', set()))
                if "skypix" in dimensions:
                    try:
                        datasetType = registry.getDatasetType(c.name)
                    except LookupError as err:
                        raise LookupError(
                            f"DatasetType '{c.name}' referenced by "
                            f"{type(taskDef.connections).__name__} uses 'skypix' as a dimension "
                            f"placeholder, but does not already exist in the registry.  "
                            f"Note that reference catalog names are now used as the dataset "
                            f"type name instead of 'ref_cat'."
                        ) from err
                    rest1 = set(registry.dimensions.extract(dimensions - set(["skypix"])).names)
                    rest2 = set(dim.name for dim in datasetType.dimensions
                                if not isinstance(dim, SkyPixDimension))
                    if rest1 != rest2:
                        raise ValueError(f"Non-skypix dimensions for dataset type {c.name} declared in "
                                         f"connections ({rest1}) are inconsistent with those in "
                                         f"registry's version of this dataset ({rest2}).")
                else:
                    # Component dataset types are not explicitly in the
                    # registry.  This complicates consistency checks with
                    # registry and requires we work out the composite storage
                    # class.
                    registryDatasetType = None
                    try:
                        registryDatasetType = registry.getDatasetType(c.name)
                    except KeyError:
                        compositeName, componentName = DatasetType.splitDatasetTypeName(c.name)
                        parentStorageClass = DatasetType.PlaceholderParentStorageClass \
                            if componentName else None
                        datasetType = c.makeDatasetType(
                            registry.dimensions,
                            parentStorageClass=parentStorageClass
                        )
                        registryDatasetType = datasetType
                    else:
                        datasetType = c.makeDatasetType(
                            registry.dimensions,
                            parentStorageClass=registryDatasetType.parentStorageClass
                        )

                    if registryDatasetType and datasetType != registryDatasetType:
                        raise ValueError(f"Supplied dataset type ({datasetType}) inconsistent with "
                                         f"registry definition ({registryDatasetType}) "
                                         f"for {taskDef.label}.")
                datasetTypes.add(datasetType)
            if freeze:
                datasetTypes.freeze()
            return datasetTypes

        # optionally add output dataset for metadata
        outputs = makeDatasetTypesSet("outputs", freeze=False)
        if taskDef.metadataDatasetName is not None:
            # Metadata is supposed to be of the PropertySet type, its
            # dimensions correspond to a task quantum
            dimensions = registry.dimensions.extract(taskDef.connections.dimensions)
            outputs |= {DatasetType(taskDef.metadataDatasetName, dimensions, "PropertySet")}
        outputs.freeze()

        return cls(
            initInputs=makeDatasetTypesSet("initInputs"),
            initOutputs=makeDatasetTypesSet("initOutputs"),
            inputs=makeDatasetTypesSet("inputs"),
            prerequisites=makeDatasetTypesSet("prerequisiteInputs"),
            outputs=outputs,
        )
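
# A minimal sketch of classifying one task's dataset types (assumes a Butler
# registry and a previously constructed TaskDef are available):
#
#     types = TaskDatasetTypes.fromTaskDef(taskDef, registry=butler.registry)
#     print(types.inputs.names, types.outputs.names)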


@dataclass(frozen=True)
class PipelineDatasetTypes:
    """An immutable struct that classifies the dataset types used in a
    `Pipeline`.
    """

    initInputs: NamedValueSet[DatasetType]
    """Dataset types that are needed as inputs in order to construct the Tasks
    in this Pipeline.

    This does not include dataset types that are produced when constructing
    other Tasks in the Pipeline (these are classified as `initIntermediates`).
    """

    initOutputs: NamedValueSet[DatasetType]
    """Dataset types that may be written after constructing the Tasks in this
    Pipeline.

    This does not include dataset types that are also used as inputs when
    constructing other Tasks in the Pipeline (these are classified as
    `initIntermediates`).
    """

    initIntermediates: NamedValueSet[DatasetType]
    """Dataset types that are both used when constructing one or more Tasks
    in the Pipeline and produced as a side-effect of constructing another
    Task in the Pipeline.
    """

    inputs: NamedValueSet[DatasetType]
    """Dataset types that are regular inputs for the full pipeline.

    If an input dataset needed for a Quantum cannot be found in the input
    collection(s), that Quantum (and all dependent Quanta) will not be
    produced.
    """

    prerequisites: NamedValueSet[DatasetType]
    """Dataset types that are prerequisite inputs for the full Pipeline.

    Prerequisite inputs must exist in the input collection(s) before the
    pipeline is run, but do not constrain the graph - if a prerequisite is
    missing for a Quantum, `PrerequisiteMissingError` is raised.

    Prerequisite inputs are not resolved until the second stage of
    QuantumGraph generation.
    """

    intermediates: NamedValueSet[DatasetType]
    """Dataset types that are output by one Task in the Pipeline and consumed
    as inputs by one or more other Tasks in the Pipeline.
    """

    outputs: NamedValueSet[DatasetType]
    """Dataset types that are output by a Task in the Pipeline and not
    consumed by any other Task in the Pipeline.
    """

    byTask: Mapping[str, TaskDatasetTypes]
    """Per-Task dataset types, keyed by label in the `Pipeline`.

    This is guaranteed to be zip-iterable with the `Pipeline` itself (assuming
    neither has been modified since the dataset types were extracted, of
    course).
    """

    @classmethod
    def fromPipeline(cls, pipeline, *, registry: Registry) -> PipelineDatasetTypes:
        """Extract and classify the dataset types from all tasks in a
        `Pipeline`.

        Parameters
        ----------
        pipeline : `Pipeline`
            An ordered collection of tasks that can be run together.
        registry : `Registry`
            Registry used to construct normalized `DatasetType` objects and
            retrieve those that are incomplete.

        Returns
        -------
        types : `PipelineDatasetTypes`
            The dataset types used by this `Pipeline`.

        Raises
        ------
        ValueError
            Raised if Tasks are inconsistent about which datasets are marked
            prerequisite. This indicates that the Tasks cannot be run as part
            of the same `Pipeline`.
        """
        allInputs = NamedValueSet()
        allOutputs = NamedValueSet()
        allInitInputs = NamedValueSet()
        allInitOutputs = NamedValueSet()
        prerequisites = NamedValueSet()
        byTask = dict()
        if isinstance(pipeline, Pipeline):
            pipeline = pipeline.toExpandedPipeline()
        for taskDef in pipeline:
            thisTask = TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
            allInitInputs |= thisTask.initInputs
            allInitOutputs |= thisTask.initOutputs
            allInputs |= thisTask.inputs
            prerequisites |= thisTask.prerequisites
            allOutputs |= thisTask.outputs
            byTask[taskDef.label] = thisTask
        if not prerequisites.isdisjoint(allInputs):
            raise ValueError("{} marked as both prerequisites and regular inputs".format(
                {dt.name for dt in allInputs & prerequisites}
            ))
        if not prerequisites.isdisjoint(allOutputs):
            raise ValueError("{} marked as both prerequisites and outputs".format(
                {dt.name for dt in allOutputs & prerequisites}
            ))
        # Make sure that components which are marked as inputs get treated as
        # intermediates if there is an output which produces the composite
        # containing the component
        intermediateComponents = NamedValueSet()
        intermediateComposites = NamedValueSet()
        outputNameMapping = {dsType.name: dsType for dsType in allOutputs}
        for dsType in allInputs:
            # get the name of a possible component
            name, component = dsType.nameAndComponent()
            # if there is a component name, that means this is a component
            # DatasetType; if there is an output which produces the parent of
            # this component, treat this input as an intermediate
            if component is not None:
                if name in outputNameMapping:
                    if outputNameMapping[name].dimensions != dsType.dimensions:
                        raise ValueError(f"Component dataset type {dsType.name} has different "
                                         f"dimensions ({dsType.dimensions}) than its parent "
                                         f"({outputNameMapping[name].dimensions}).")
                    composite = DatasetType(name, dsType.dimensions, outputNameMapping[name].storageClass,
                                            universe=registry.dimensions)
                    intermediateComponents.add(dsType)
                    intermediateComposites.add(composite)

        def checkConsistency(a: NamedValueSet, b: NamedValueSet):
            common = a.names & b.names
            for name in common:
                if a[name] != b[name]:
                    raise ValueError(f"Conflicting definitions for dataset type: {a[name]} != {b[name]}.")

        checkConsistency(allInitInputs, allInitOutputs)
        checkConsistency(allInputs, allOutputs)
        checkConsistency(allInputs, intermediateComposites)
        checkConsistency(allOutputs, intermediateComposites)

        def frozen(s: NamedValueSet) -> NamedValueSet:
            s.freeze()
            return s

        return cls(
            initInputs=frozen(allInitInputs - allInitOutputs),
            initIntermediates=frozen(allInitInputs & allInitOutputs),
            initOutputs=frozen(allInitOutputs - allInitInputs),
            inputs=frozen(allInputs - allOutputs - intermediateComponents),
            intermediates=frozen(allInputs & allOutputs | intermediateComponents),
            outputs=frozen(allOutputs - allInputs - intermediateComposites),
            prerequisites=frozen(prerequisites),
            byTask=MappingProxyType(byTask),  # MappingProxyType -> frozen view of dict for immutability
        )
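
# A minimal sketch of classifying dataset types across a whole pipeline
# (assumes a Pipeline instance and a Butler registry are available):
#
#     datasetTypes = PipelineDatasetTypes.fromPipeline(pipeline, registry=butler.registry)
#     print(datasetTypes.inputs.names)         # pure inputs for the run
#     print(datasetTypes.intermediates.names)  # produced and consumed internally
#     print(datasetTypes.outputs.names)        # final outputs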