Coverage for python/lsst/pipe/base/pipeline.py : 16%

1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
23"""Module defining Pipeline class and related methods.
24"""
26__all__ = ["Pipeline", "TaskDef", "TaskDatasetTypes", "PipelineDatasetTypes", "LabelSpecifier"]
28# -------------------------------
29# Imports of standard modules --
30# -------------------------------
31from dataclasses import dataclass
32from types import MappingProxyType
33from typing import Mapping, Set, Union, Generator, TYPE_CHECKING, Optional
35import copy
36import os
37import re
39# -----------------------------
40# Imports for other modules --
41from lsst.daf.butler import DatasetType, NamedValueSet, Registry, SkyPixDimension
42from lsst.utils import doImport
43from .configOverrides import ConfigOverrides
44from .connections import iterConnections
45from .pipelineTask import PipelineTask
47from . import pipelineIR
48from . import pipeTools
50if TYPE_CHECKING: # Imports needed only for type annotations; may be circular.
51 from lsst.obs.base.instrument import Instrument
53# ----------------------------------
54# Local non-exported definitions --
55# ----------------------------------
57# ------------------------
58# Exported definitions --
59# ------------------------
62@dataclass
63class LabelSpecifier:
64 """A structure to specify a subset of labels to load
66 This structure may contain a set of labels to be used in subsetting a
67 pipeline, or a beginning and end point. Beginning or end may be empty,
68 in which case the range will be a half open interval. Unlike python
69 iteration bounds, end bounds are *INCLUDED*. Note that range based
70 selection is not well defined for pipelines that are not linear in nature,
71 and correct behavior is not guaranteed, or may vary from run to run.
72 """
73 labels: Optional[Set[str]] = None
74 begin: Optional[str] = None
75 end: Optional[str] = None
77 def __post_init__(self):
78 if self.labels is not None and (self.begin or self.end):
79 raise ValueError("This struct can only be initialized with a labels set or "
80 "a begin (and/or) end specifier")
83class TaskDef:
84 """TaskDef is a collection of information about task needed by Pipeline.
86 The information includes task name, configuration object and optional
87 task class. This class is just a collection of attributes and it exposes
88 all of them so that attributes could potentially be modified in place
89 (e.g. if configuration needs extra overrides).
91 Attributes
92 ----------
93 taskName : `str`
94 `PipelineTask` class name; currently it is not specified whether this
95 is a fully-qualified name or partial name (e.g. ``module.TaskClass``).
96 The framework should be prepared to handle all cases.
97 config : `lsst.pex.config.Config`
98 Instance of the configuration class corresponding to this task class,
99 usually with all overrides applied. This config will be frozen.
100 taskClass : `type` or ``None``
101 `PipelineTask` class object; may be ``None``. If ``None`` then
102 the framework will have to locate and load the class.
103 label : `str`, optional
104 Task label, usually a short string unique in a pipeline.
105 """
106 def __init__(self, taskName, config, taskClass=None, label=""):
107 self.taskName = taskName
108 config.freeze()
109 self.config = config
110 self.taskClass = taskClass
111 self.label = label
112 self.connections = config.connections.ConnectionsClass(config=config)
114 @property
115 def configDatasetName(self):
116 """Name of a dataset type for configuration of this task (`str`)
117 """
118 return self.label + "_config"
120 @property
121 def metadataDatasetName(self):
122 """Name of a dataset type for metadata of this task, `None` if
123 metadata is not to be saved (`str`)
124 """
125 if self.config.saveMetadata:
126 return self.label + "_metadata"
127 else:
128 return None
130 def __str__(self):
131 rep = "TaskDef(" + self.taskName
132 if self.label:
133 rep += ", label=" + self.label
134 rep += ")"
135 return rep
137 def __eq__(self, other: object) -> bool:
138 if not isinstance(other, TaskDef):
139 return False
140 return self.config == other.config and\
141 self.taskClass == other.taskClass and\
142 self.label == other.label
144 def __hash__(self):
145 return hash((self.taskClass, self.label))
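# Illustrative sketch (not part of the original module): how the per-task
# dataset type names are derived from a TaskDef. ``ExampleTask`` is a
# hypothetical PipelineTask subclass passed in by the caller; its
# auto-generated ConfigClass is assumed to provide the usual ``connections``
# and ``saveMetadata`` fields.
def _taskDefSketch(ExampleTask):
    config = ExampleTask.ConfigClass()
    taskDef = TaskDef(taskName="mypkg.ExampleTask", config=config,
                      taskClass=ExampleTask, label="example")
    # The config dataset type name follows directly from the label.
    assert taskDef.configDatasetName == "example_config"
    # metadataDatasetName is "example_metadata", or None when
    # config.saveMetadata is disabled.
    return taskDef.configDatasetName, taskDef.metadataDatasetName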
148class Pipeline:
149 """A `Pipeline` is a representation of a series of tasks to run, and the
150 configuration for those tasks.
152 Parameters
153 ----------
154 description : `str`
155 A description of what this pipeline does.
156 """
157 def __init__(self, description: str):
158 pipeline_dict = {"description": description, "tasks": {}}
159 self._pipelineIR = pipelineIR.PipelineIR(pipeline_dict)
161 @classmethod
162 def fromFile(cls, filename: str) -> Pipeline:
163 """Load a pipeline defined in a pipeline yaml file.
165 Parameters
166 ----------
167 filename: `str`
168 A path that points to a pipeline defined in yaml format. This
169 filename may also supply additional labels to be used in
170 subsetting the loaded Pipeline. These labels are separated from
171 the path by a colon, and may be specified as a comma separated
172 list, or a range denoted as beginning..end. Beginning or end may
173 be empty, in which case the range will be a half-open interval.
174 Unlike Python iteration bounds, end bounds are *INCLUDED*. Note
175 that range-based selection is not well defined for pipelines that
176 are not linear in nature; correct behavior is not guaranteed
177 and may vary from run to run.
179 Returns
180 -------
181 pipeline: `Pipeline`
182 The pipeline loaded from the specified location, with appropriate
183 (if any) subsetting.
185 Notes
186 -----
187 This method attempts to prune any contracts that contain labels which
188 are not in the declared subset of labels. This pruning is done using
189 string-based matching due to the nature of contracts and may prune more
190 than it should.
191 """
192 # Split up the filename and any labels that were supplied
193 filename, labelSpecifier = cls._parseFileSpecifier(filename)
194 pipeline: Pipeline = cls.fromIR(pipelineIR.PipelineIR.from_file(filename))
196 # If there are labels supplied, only keep those
197 if labelSpecifier is not None:
198 pipeline = pipeline.subsetFromLabels(labelSpecifier)
199 return pipeline
201 def subsetFromLabels(self, labelSpecifier: LabelSpecifier) -> Pipeline:
202 """Subset a pipeline to contain only labels specified in labelSpecifier
204 Parameters
205 ----------
206 labelSpecifier : `LabelSpecifier`
207 Object containing labels that describe how to subset a pipeline.
209 Returns
210 -------
211 pipeline : `Pipeline`
212 A new pipeline object that is a subset of the old pipeline
214 Raises
215 ------
216 ValueError
217 Raised if there is an issue with specified labels
219 Notes
220 -----
221 This method attempts to prune any contracts that contain labels which
222 are not in the declared subset of labels. This pruning is done using
223 string-based matching due to the nature of contracts and may prune more
224 than it should.
225 """
227 pipeline = copy.deepcopy(self)
229 def remove_contracts(label: str):
230 """Remove any contracts that contain the given label
232 String comparison used in this way is not the most elegant and may
233 have issues, but it is the only feasible way when users can specify
234 contracts with generic strings.
235 """
236 new_contracts = []
237 for contract in pipeline._pipelineIR.contracts:
238 # match a label that is not preceded by an ASCII identifier character
239 # (or is at the start of the string) and is followed by a dot
240 if re.match(f".*([^A-Za-z0-9_]|^){label}[.]", contract.contract):
241 continue
242 new_contracts.append(contract)
243 pipeline._pipelineIR.contracts = new_contracts
245 # Labels were supplied as a set; explicitly remove any tasks whose
246 # labels are not in that set
247 if labelSpecifier.labels:
248 # verify all the labels are in the pipeline
249 if not labelSpecifier.labels.issubset(pipeline._pipelineIR.tasks.keys()):
250 difference = labelSpecifier.labels.difference(pipeline._pipelineIR.tasks.keys())
251 raise ValueError("Not all supplied labels are in the pipeline definition, extra labels:"
252 f"{difference}")
253 # copy needed so as to not modify while iterating
254 pipeline_labels = list(pipeline._pipelineIR.tasks.keys())
255 for label in pipeline_labels:
256 if label not in labelSpecifier.labels:
257 pipeline.removeTask(label)
258 remove_contracts(label)
259 # Labels supplied as a range, first create a list of all the labels
260 # in the pipeline sorted according to task dependency. Then only
261 # keep labels that lie between the supplied bounds
262 else:
263 # Use a dict for fast searching while preserving order.
264 # Save contracts and remove them so they do not fail in the
265 # expansion step; they will be restored afterwards. This is needed because
266 # a user may only configure the tasks they intend to run, which
267 # may cause some contracts to fail if those tasks will later be dropped.
268 contractSave = pipeline._pipelineIR.contracts
269 pipeline._pipelineIR.contracts = []
270 labels = {taskdef.label: True for taskdef in pipeline.toExpandedPipeline()}
271 pipeline._pipelineIR.contracts = contractSave
273 # Verify the bounds are in the labels
274 if labelSpecifier.begin is not None:
275 if labelSpecifier.begin not in labels:
276 raise ValueError(f"Beginning of range subset, {labelSpecifier.begin}, not found in "
277 "pipeline definition")
278 if labelSpecifier.end is not None:
279 if labelSpecifier.end not in labels:
280 raise ValueError(f"End of range subset, {labelSpecifier.end}, not found in pipeline "
281 "definition")
283 closed = False
284 for label in labels:
285 # if there is a begin label delete all labels until it is
286 # reached.
287 if labelSpecifier.begin:
288 if label != labelSpecifier.begin:
289 pipeline.removeTask(label)
290 remove_contracts(label)
291 continue
292 else:
293 labelSpecifier.begin = None
294 # if there is an end specifier, keep all tasks until the
295 # end specifier is reached, afterwards delete the labels
296 if labelSpecifier.end:
297 if label != labelSpecifier.end:
298 if closed:
299 pipeline.removeTask(label)
300 remove_contracts(label)
301 continue
302 else:
303 closed = True
304 return pipeline
306 @staticmethod
307 def _parseFileSpecifier(fileSpecifier):
308 """Split apart a filename path from label subsets
309 """
310 split = fileSpecifier.split(':')
311 # There is only a filename, return just that
312 if len(split) == 1:
313 return fileSpecifier, None
314 # More than one specifier provided, bail out
315 if len(split) > 2:
316 raise ValueError("Only one : is allowed when specifying a pipeline to load")
317 else:
318 labelSubset: str
319 filename: str
320 filename, labelSubset = split[0], split[1]
321 # labels supplied as a list
322 if ',' in labelSubset:
323 if '..' in labelSubset:
324 raise ValueError("Can only specify a list of labels or a range"
325 "when loading a Pipline not both")
326 labels = set(labelSubset.split(","))
327 specifier = LabelSpecifier(labels=labels)
328 # labels supplied as a range
329 elif '..' in labelSubset:
330 # Try to destructure the labelSubset, this will fail if more
331 # than one range is specified
332 try:
333 begin, end = labelSubset.split("..")
334 except ValueError:
335 raise ValueError("Only one range can be specified when loading a pipeline")
336 specifier = LabelSpecifier(begin=begin if begin else None, end=end if end else None)
337 # Assume anything else is a single label
338 else:
339 labels = {labelSubset}
340 specifier = LabelSpecifier(labels=labels)
342 return filename, specifier
344 @classmethod
345 def fromString(cls, pipeline_string: str) -> Pipeline:
346 """Create a pipeline from string formatted as a pipeline document.
348 Parameters
349 ----------
350 pipeline_string : `str`
351 A string that is formatted like a pipeline document.
353 Returns
354 -------
355 pipeline: `Pipeline`
356 """
357 pipeline = cls.fromIR(pipelineIR.PipelineIR.from_string(pipeline_string))
358 return pipeline
360 @classmethod
361 def fromIR(cls, deserialized_pipeline: pipelineIR.PipelineIR) -> Pipeline:
362 """Create a pipeline from an already created `PipelineIR` object.
364 Parameters
365 ----------
366 deserialized_pipeline: `PipelineIR`
367 An already created pipeline intermediate representation object
369 Returns
370 -------
371 pipeline: `Pipeline`
372 """
373 pipeline = cls.__new__(cls)
374 pipeline._pipelineIR = deserialized_pipeline
375 return pipeline
377 @classmethod
378 def fromPipeline(cls, pipeline: Pipeline) -> Pipeline:
379 """Create a new pipeline by copying an already existing `Pipeline`.
381 Parameters
382 ----------
383 pipeline: `Pipeline`
384 An existing `Pipeline` object to copy.
386 Returns
387 -------
388 pipeline: `Pipeline`
389 """
390 return cls.fromIR(copy.deepcopy(pipeline._pipelineIR))
392 def __str__(self) -> str:
393 return str(self._pipelineIR)
395 def addInstrument(self, instrument: Union[Instrument, str]):
396 """Add an instrument to the pipeline, or replace an instrument that is
397 already defined.
399 Parameters
400 ----------
401 instrument : `~lsst.obs.base.Instrument` or `str`
402 Either a derived class object of an `~lsst.obs.base.Instrument` or
403 a string corresponding to a fully qualified
404 `~lsst.obs.base.Instrument` name.
405 """
406 if isinstance(instrument, str):
407 pass
408 else:
409 # TODO: assume that this is a subclass of Instrument, no type
410 # checking
411 instrument = f"{instrument.__module__}.{instrument.__qualname__}"
412 self._pipelineIR.instrument = instrument
414 def getInstrument(self):
415 """Get the instrument from the pipeline.
417 Returns
418 -------
419 instrument : `~lsst.obs.base.Instrument`, `str`, or None
420 A derived class object of an `~lsst.obs.base.Instrument`, a string
421 corresponding to a fully qualified `~lsst.obs.base.Instrument`
422 name, or None if the pipeline does not have an instrument.
423 """
424 return self._pipelineIR.instrument
426 def addTask(self, task: Union[PipelineTask, str], label: str):
427 """Add a new task to the pipeline, or replace a task that is already
428 associated with the supplied label.
430 Parameters
431 ----------
432 task: `PipelineTask` or `str`
433 Either a derived class object of a `PipelineTask` or a string
434 corresponding to a fully qualified `PipelineTask` name.
435 label: `str`
436 A label that is used to identify the `PipelineTask` being added
437 """
438 if isinstance(task, str):
439 taskName = task
440 elif issubclass(task, PipelineTask):
441 taskName = f"{task.__module__}.{task.__qualname__}"
442 else:
443 raise ValueError("task must be either a child class of PipelineTask or a string containing"
444 " a fully qualified name to one")
445 if not label:
446 # in some cases (with a command line-generated pipeline) tasks can
447 # be defined without a label, which is not acceptable; use the task's
448 # _DefaultName in that case
449 if isinstance(task, str):
450 task = doImport(task)
451 label = task._DefaultName
452 self._pipelineIR.tasks[label] = pipelineIR.TaskIR(label, taskName)
454 def removeTask(self, label: str):
455 """Remove a task from the pipeline.
457 Parameters
458 ----------
459 label : `str`
460 The label used to identify the task that is to be removed
462 Raises
463 ------
464 KeyError
465 If no task with that label exists in the pipeline
467 """
468 self._pipelineIR.tasks.pop(label)
470 def addConfigOverride(self, label: str, key: str, value: object):
471 """Apply single config override.
473 Parameters
474 ----------
475 label : `str`
476 Label of the task.
477 key: `str`
478 Fully-qualified field name.
479 value : object
480 Value to be given to a field.
481 """
482 self._addConfigImpl(label, pipelineIR.ConfigIR(rest={key: value}))
484 def addConfigFile(self, label: str, filename: str):
485 """Add overrides from a specified file.
487 Parameters
488 ----------
489 label : `str`
490 The label used to identify the task associated with the config to
491 modify.
492 filename : `str`
493 Path to the override file.
494 """
495 self._addConfigImpl(label, pipelineIR.ConfigIR(file=[filename]))
497 def addConfigPython(self, label: str, pythonString: str):
498 """Add Overrides by running a snippet of python code against a config.
500 Parameters
501 ----------
502 label : `str`
503 The label used to identify the task associated with the config to
504 modify.
505 pythonString: `str`
506 A string which is valid Python code to be executed. This is done
507 with ``config`` as the only locally accessible value.
508 """
509 self._addConfigImpl(label, pipelineIR.ConfigIR(python=pythonString))
511 def _addConfigImpl(self, label: str, newConfig: pipelineIR.ConfigIR):
512 if label not in self._pipelineIR.tasks:
513 raise LookupError(f"There are no tasks labeled '{label}' in the pipeline")
514 self._pipelineIR.tasks[label].add_or_update_config(newConfig)
516 def toFile(self, filename: str):
517 self._pipelineIR.to_file(filename)
519 def toExpandedPipeline(self) -> Generator[TaskDef, None, None]:
520 """Returns a generator of TaskDefs which can be used to create quantum
521 graphs.
523 Returns
524 -------
525 generator : generator of `TaskDef`
526 The generator returned will be the sorted iterator of tasks which
527 are to be used in constructing a quantum graph.
529 Raises
530 ------
531 NotImplementedError
532 Raised if a dataId is supplied in a config block. This is in place
533 for future use.
534 """
535 taskDefs = []
536 for label, taskIR in self._pipelineIR.tasks.items():
537 taskClass = doImport(taskIR.klass)
538 taskName = taskClass.__qualname__
539 config = taskClass.ConfigClass()
540 overrides = ConfigOverrides()
541 if self._pipelineIR.instrument is not None:
542 overrides.addInstrumentOverride(self._pipelineIR.instrument, taskClass._DefaultName)
543 if taskIR.config is not None:
544 for configIR in taskIR.config:
545 if configIR.dataId is not None:
546 raise NotImplementedError("Specializing a config on a partial data id is not yet "
547 "supported in Pipeline definition")
548 # only apply override if it applies to everything
549 if configIR.dataId is None:
550 if configIR.file:
551 for configFile in configIR.file:
552 overrides.addFileOverride(os.path.expandvars(configFile))
553 if configIR.python is not None:
554 overrides.addPythonOverride(configIR.python)
555 for key, value in configIR.rest.items():
556 overrides.addValueOverride(key, value)
557 overrides.applyTo(config)
558 # This may need to be revisited
559 config.validate()
560 taskDefs.append(TaskDef(taskName=taskName, config=config, taskClass=taskClass, label=label))
562 # let's evaluate the contracts
563 if self._pipelineIR.contracts is not None:
564 label_to_config = {x.label: x.config for x in taskDefs}
565 for contract in self._pipelineIR.contracts:
566 # execute this in its own line so it can raise a good error
567 # message if there were problems with the eval
568 success = eval(contract.contract, None, label_to_config)
569 if not success:
570 extra_info = f": {contract.msg}" if contract.msg is not None else ""
571 raise pipelineIR.ContractError(f"Contract(s) '{contract.contract}' were not "
572 f"satisfied{extra_info}")
574 yield from pipeTools.orderPipeline(taskDefs)
576 def __len__(self):
577 return len(self._pipelineIR.tasks)
579 def __eq__(self, other: "Pipeline"):
580 if not isinstance(other, Pipeline):
581 return False
582 return self._pipelineIR == other._pipelineIR
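# Illustrative sketch (not part of the original module): building and
# subsetting a Pipeline programmatically. Task names, labels, config fields,
# and file paths below are hypothetical.
def _pipelineSketch():
    # Build a pipeline in memory, attach a task, and apply a config override.
    pipeline = Pipeline("A demonstration pipeline")
    pipeline.addTask("lsst.example.ExampleTask", label="example")
    pipeline.addConfigOverride("example", "someField", 42)

    # Load from a YAML file; labels after the colon subset the pipeline,
    # either as a comma-separated list or as an inclusive begin..end range.
    full = Pipeline.fromFile("pipeline.yaml")
    subset = Pipeline.fromFile("pipeline.yaml:isr,characterizeImage")
    ranged = Pipeline.fromFile("pipeline.yaml:isr..calibrate")

    # Expanding a pipeline yields dependency-ordered TaskDef objects, with
    # config overrides applied and contracts checked.
    orderedLabels = [taskDef.label for taskDef in full.toExpandedPipeline()]
    return pipeline, subset, ranged, orderedLabels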
585@dataclass(frozen=True)
586class TaskDatasetTypes:
587 """An immutable struct that extracts and classifies the dataset types used
588 by a `PipelineTask`
589 """
591 initInputs: NamedValueSet[DatasetType]
592 """Dataset types that are needed as inputs in order to construct this Task.
594 Task-level `initInputs` may be classified as either
595 `~PipelineDatasetTypes.initInputs` or
596 `~PipelineDatasetTypes.initIntermediates` at the Pipeline level.
597 """
599 initOutputs: NamedValueSet[DatasetType]
600 """Dataset types that may be written after constructing this Task.
602 Task-level `initOutputs` may be classified as either
603 `~PipelineDatasetTypes.initOutputs` or
604 `~PipelineDatasetTypes.initIntermediates` at the Pipeline level.
605 """
607 inputs: NamedValueSet[DatasetType]
608 """Dataset types that are regular inputs to this Task.
610 If an input dataset needed for a Quantum cannot be found in the input
611 collection(s) or produced by another Task in the Pipeline, that Quantum
612 (and all dependent Quanta) will not be produced.
614 Task-level `inputs` may be classified as either
615 `~PipelineDatasetTypes.inputs` or `~PipelineDatasetTypes.intermediates`
616 at the Pipeline level.
617 """
619 prerequisites: NamedValueSet[DatasetType]
620 """Dataset types that are prerequisite inputs to this Task.
622 Prerequisite inputs must exist in the input collection(s) before the
623 pipeline is run, but do not constrain the graph - if a prerequisite is
624 missing for a Quantum, `PrerequisiteMissingError` is raised.
626 Prerequisite inputs are not resolved until the second stage of
627 QuantumGraph generation.
628 """
630 outputs: NamedValueSet[DatasetType]
631 """Dataset types that are produced by this Task.
633 Task-level `outputs` may be classified as either
634 `~PipelineDatasetTypes.outputs` or `~PipelineDatasetTypes.intermediates`
635 at the Pipeline level.
636 """
638 @classmethod
639 def fromTaskDef(cls, taskDef: TaskDef, *, registry: Registry) -> TaskDatasetTypes:
640 """Extract and classify the dataset types from a single `PipelineTask`.
642 Parameters
643 ----------
644 taskDef: `TaskDef`
645 An instance of a `TaskDef` class for a particular `PipelineTask`.
646 registry: `Registry`
647 Registry used to construct normalized `DatasetType` objects and
648 retrieve those that are incomplete.
650 Returns
651 -------
652 types: `TaskDatasetTypes`
653 The dataset types used by this task.
654 """
655 def makeDatasetTypesSet(connectionType, freeze=True):
656 """Constructs a set of true `DatasetType` objects
658 Parameters
659 ----------
660 connectionType : `str`
661 Name of the connection type to produce a set for, corresponds
662 to an attribute of type `list` on the connection class instance
663 freeze : `bool`, optional
664 If `True`, call `NamedValueSet.freeze` on the object returned.
666 Returns
667 -------
668 datasetTypes : `NamedValueSet`
669 A set of all datasetTypes which correspond to the input
670 connection type specified in the connection class of this
671 `PipelineTask`
673 Notes
674 -----
675 This function is a closure over the variables ``registry`` and
676 ``taskDef``.
677 """
678 datasetTypes = NamedValueSet()
679 for c in iterConnections(taskDef.connections, connectionType):
680 dimensions = set(getattr(c, 'dimensions', set()))
681 if "skypix" in dimensions:
682 try:
683 datasetType = registry.getDatasetType(c.name)
684 except LookupError as err:
685 raise LookupError(
686 f"DatasetType '{c.name}' referenced by "
687 f"{type(taskDef.connections).__name__} uses 'skypix' as a dimension "
688 f"placeholder, but does not already exist in the registry. "
689 f"Note that reference catalog names are now used as the dataset "
690 f"type name instead of 'ref_cat'."
691 ) from err
692 rest1 = set(registry.dimensions.extract(dimensions - set(["skypix"])).names)
693 rest2 = set(dim.name for dim in datasetType.dimensions
694 if not isinstance(dim, SkyPixDimension))
695 if rest1 != rest2:
696 raise ValueError(f"Non-skypix dimensions for dataset type {c.name} declared in "
697 f"connections ({rest1}) are inconsistent with those in "
698 f"registry's version of this dataset ({rest2}).")
699 else:
700 # Component dataset types are not explicitly in the
701 # registry. This complicates consistency checks with
702 # registry and requires we work out the composite storage
703 # class.
704 registryDatasetType = None
705 try:
706 registryDatasetType = registry.getDatasetType(c.name)
707 except KeyError:
708 compositeName, componentName = DatasetType.splitDatasetTypeName(c.name)
709 parentStorageClass = DatasetType.PlaceholderParentStorageClass \
710 if componentName else None
711 datasetType = c.makeDatasetType(
712 registry.dimensions,
713 parentStorageClass=parentStorageClass
714 )
715 registryDatasetType = datasetType
716 else:
717 datasetType = c.makeDatasetType(
718 registry.dimensions,
719 parentStorageClass=registryDatasetType.parentStorageClass
720 )
722 if registryDatasetType and datasetType != registryDatasetType:
723 raise ValueError(f"Supplied dataset type ({datasetType}) inconsistent with "
724 f"registry definition ({registryDatasetType}) "
725 f"for {taskDef.label}.")
726 datasetTypes.add(datasetType)
727 if freeze:
728 datasetTypes.freeze()
729 return datasetTypes
731 # optionally add output dataset for metadata
732 outputs = makeDatasetTypesSet("outputs", freeze=False)
733 if taskDef.metadataDatasetName is not None:
734 # Metadata is supposed to be of the PropertySet type; its
735 # dimensions correspond to a task quantum
736 dimensions = registry.dimensions.extract(taskDef.connections.dimensions)
737 outputs |= {DatasetType(taskDef.metadataDatasetName, dimensions, "PropertySet")}
738 outputs.freeze()
740 return cls(
741 initInputs=makeDatasetTypesSet("initInputs"),
742 initOutputs=makeDatasetTypesSet("initOutputs"),
743 inputs=makeDatasetTypesSet("inputs"),
744 prerequisites=makeDatasetTypesSet("prerequisiteInputs"),
745 outputs=outputs,
746 )
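# Illustrative sketch (not part of the original module): extracting and
# inspecting the dataset types used by a single task. ``taskDef`` is a TaskDef
# produced elsewhere (e.g. by Pipeline.toExpandedPipeline) and ``registry`` is
# an lsst.daf.butler Registry; both are assumed to be supplied by the caller.
def _taskDatasetTypesSketch(taskDef, registry):
    types = TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
    # Each attribute is a frozen NamedValueSet of DatasetType objects.
    inputNames = {dsType.name for dsType in types.inputs}
    outputNames = {dsType.name for dsType in types.outputs}
    return inputNames, outputNames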
749@dataclass(frozen=True)
750class PipelineDatasetTypes:
751 """An immutable struct that classifies the dataset types used in a
752 `Pipeline`.
753 """
755 initInputs: NamedValueSet[DatasetType]
756 """Dataset types that are needed as inputs in order to construct the Tasks
757 in this Pipeline.
759 This does not include dataset types that are produced when constructing
760 other Tasks in the Pipeline (these are classified as `initIntermediates`).
761 """
763 initOutputs: NamedValueSet[DatasetType]
764 """Dataset types that may be written after constructing the Tasks in this
765 Pipeline.
767 This does not include dataset types that are also used as inputs when
768 constructing other Tasks in the Pipeline (these are classified as
769 `initIntermediates`).
770 """
772 initIntermediates: NamedValueSet[DatasetType]
773 """Dataset types that are both used when constructing one or more Tasks
774 in the Pipeline and produced as a side-effect of constructing another
775 Task in the Pipeline.
776 """
778 inputs: NamedValueSet[DatasetType]
779 """Dataset types that are regular inputs for the full pipeline.
781 If an input dataset needed for a Quantum cannot be found in the input
782 collection(s), that Quantum (and all dependent Quanta) will not be
783 produced.
784 """
786 prerequisites: NamedValueSet[DatasetType]
787 """Dataset types that are prerequisite inputs for the full Pipeline.
789 Prerequisite inputs must exist in the input collection(s) before the
790 pipeline is run, but do not constrain the graph - if a prerequisite is
791 missing for a Quantum, `PrerequisiteMissingError` is raised.
793 Prerequisite inputs are not resolved until the second stage of
794 QuantumGraph generation.
795 """
797 intermediates: NamedValueSet[DatasetType]
798 """Dataset types that are output by one Task in the Pipeline and consumed
799 as inputs by one or more other Tasks in the Pipeline.
800 """
802 outputs: NamedValueSet[DatasetType]
803 """Dataset types that are output by a Task in the Pipeline and not consumed
804 by any other Task in the Pipeline.
805 """
807 byTask: Mapping[str, TaskDatasetTypes]
808 """Per-Task dataset types, keyed by label in the `Pipeline`.
810 This is guaranteed to be zip-iterable with the `Pipeline` itself (assuming
811 neither has been modified since the dataset types were extracted, of
812 course).
813 """
815 @classmethod
816 def fromPipeline(cls, pipeline, *, registry: Registry) -> PipelineDatasetTypes:
817 """Extract and classify the dataset types from all tasks in a
818 `Pipeline`.
820 Parameters
821 ----------
822 pipeline: `Pipeline`
823 An ordered collection of tasks that can be run together.
824 registry: `Registry`
825 Registry used to construct normalized `DatasetType` objects and
826 retrieve those that are incomplete.
828 Returns
829 -------
830 types: `PipelineDatasetTypes`
831 The dataset types used by this `Pipeline`.
833 Raises
834 ------
835 ValueError
836 Raised if Tasks are inconsistent about which datasets are marked
837 prerequisite. This indicates that the Tasks cannot be run as part
838 of the same `Pipeline`.
839 """
840 allInputs = NamedValueSet()
841 allOutputs = NamedValueSet()
842 allInitInputs = NamedValueSet()
843 allInitOutputs = NamedValueSet()
844 prerequisites = NamedValueSet()
845 byTask = dict()
846 if isinstance(pipeline, Pipeline):
847 pipeline = pipeline.toExpandedPipeline()
848 for taskDef in pipeline:
849 thisTask = TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
850 allInitInputs |= thisTask.initInputs
851 allInitOutputs |= thisTask.initOutputs
852 allInputs |= thisTask.inputs
853 prerequisites |= thisTask.prerequisites
854 allOutputs |= thisTask.outputs
855 byTask[taskDef.label] = thisTask
856 if not prerequisites.isdisjoint(allInputs):
857 raise ValueError("{} marked as both prerequisites and regular inputs".format(
858 {dt.name for dt in allInputs & prerequisites}
859 ))
860 if not prerequisites.isdisjoint(allOutputs):
861 raise ValueError("{} marked as both prerequisites and outputs".format(
862 {dt.name for dt in allOutputs & prerequisites}
863 ))
864 # Make sure that components which are marked as inputs get treated as
865 # intermediates if there is an output which produces the composite
866 # containing the component
867 intermediateComponents = NamedValueSet()
868 intermediateComposites = NamedValueSet()
869 outputNameMapping = {dsType.name: dsType for dsType in allOutputs}
870 for dsType in allInputs:
871 # get the name of a possible component
872 name, component = dsType.nameAndComponent()
873 # if there is a component name, that means this is a component
874 # DatasetType; if there is an output which produces the parent of
875 # this component, treat this input as an intermediate
876 if component is not None:
877 if name in outputNameMapping:
878 if outputNameMapping[name].dimensions != dsType.dimensions:
879 raise ValueError(f"Component dataset type {dsType.name} has different "
880 f"dimensions ({dsType.dimensions}) than its parent "
881 f"({outputNameMapping[name].dimensions}).")
882 composite = DatasetType(name, dsType.dimensions, outputNameMapping[name].storageClass,
883 universe=registry.dimensions)
884 intermediateComponents.add(dsType)
885 intermediateComposites.add(composite)
887 def checkConsistency(a: NamedValueSet, b: NamedValueSet):
888 common = a.names & b.names
889 for name in common:
890 if a[name] != b[name]:
891 raise ValueError(f"Conflicting definitions for dataset type: {a[name]} != {b[name]}.")
893 checkConsistency(allInitInputs, allInitOutputs)
894 checkConsistency(allInputs, allOutputs)
895 checkConsistency(allInputs, intermediateComposites)
896 checkConsistency(allOutputs, intermediateComposites)
898 def frozen(s: NamedValueSet) -> NamedValueSet:
899 s.freeze()
900 return s
902 return cls(
903 initInputs=frozen(allInitInputs - allInitOutputs),
904 initIntermediates=frozen(allInitInputs & allInitOutputs),
905 initOutputs=frozen(allInitOutputs - allInitInputs),
906 inputs=frozen(allInputs - allOutputs - intermediateComponents),
907 intermediates=frozen(allInputs & allOutputs | intermediateComponents),
908 outputs=frozen(allOutputs - allInputs - intermediateComposites),
909 prerequisites=frozen(prerequisites),
910 byTask=MappingProxyType(byTask), # MappingProxyType -> frozen view of dict for immutability
911 )
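# Illustrative sketch (not part of the original module): classifying the
# dataset types of a whole pipeline and inspecting how they partition.
# ``pipeline`` is a Pipeline (or an iterable of TaskDef) and ``registry`` is an
# lsst.daf.butler Registry; both are assumed to be supplied by the caller.
def _pipelineDatasetTypesSketch(pipeline, registry):
    datasetTypes = PipelineDatasetTypes.fromPipeline(pipeline, registry=registry)
    # Overall inputs must come from the input collections; intermediates are
    # both produced and consumed within the pipeline; outputs are only produced.
    summary = {
        "inputs": {dt.name for dt in datasetTypes.inputs},
        "intermediates": {dt.name for dt in datasetTypes.intermediates},
        "outputs": {dt.name for dt in datasetTypes.outputs},
        "prerequisites": {dt.name for dt in datasetTypes.prerequisites},
    }
    # Per-task breakdown, keyed by label and zip-iterable with the pipeline.
    perTask = {label: set(taskTypes.inputs.names)
               for label, taskTypes in datasetTypes.byTask.items()}
    return summary, perTask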