Coverage for python/lsst/pipe/base/pipeline.py : 20%

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

"""Module defining Pipeline class and related methods.
"""

__all__ = ["Pipeline", "TaskDef", "TaskDatasetTypes", "PipelineDatasetTypes", "LabelSpecifier"]

# -------------------------------
# Imports of standard modules --
# -------------------------------
from dataclasses import dataclass
from types import MappingProxyType
from typing import Dict, Mapping, Set, Union, Generator, TYPE_CHECKING, Optional, Tuple

import copy
import re
import os
import urllib.parse
import warnings

# -----------------------------
# Imports for other modules --
from lsst.daf.butler import DatasetType, NamedValueSet, Registry, SkyPixDimension, ButlerURI
from lsst.utils import doImport
from .configOverrides import ConfigOverrides
from .connections import iterConnections
from .pipelineTask import PipelineTask

from . import pipelineIR
from . import pipeTools

if TYPE_CHECKING:  # Imports needed only for type annotations; may be circular.
    from lsst.obs.base import Instrument

# ----------------------------------
# Local non-exported definitions --
# ----------------------------------

# ------------------------
# Exported definitions --
# ------------------------


@dataclass
class LabelSpecifier:
    """A structure to specify a subset of labels to load

    This structure may contain a set of labels to be used in subsetting a
    pipeline, or a beginning and end point. Beginning or end may be empty,
    in which case the range will be a half open interval. Unlike python
    iteration bounds, end bounds are *INCLUDED*. Note that range based
    selection is not well defined for pipelines that are not linear in nature,
    and correct behavior is not guaranteed, or may vary from run to run.
    """
    labels: Optional[Set[str]] = None
    begin: Optional[str] = None
    end: Optional[str] = None

    def __post_init__(self):
        if self.labels is not None and (self.begin or self.end):
            raise ValueError("This struct can only be initialized with a labels set or "
                             "a begin (and/or) end specifier")
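

# Illustrative use of LabelSpecifier (a hedged sketch, not executed as part of
# this module; the task labels shown are hypothetical):
#
#     spec_by_set = LabelSpecifier(labels={"isr", "characterizeImage"})
#     spec_by_range = LabelSpecifier(begin="isr", end="calibrate")
#     # Mixing the two forms raises ValueError in __post_init__:
#     # LabelSpecifier(labels={"isr"}, begin="isr")  -> ValueError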


class TaskDef:
    """TaskDef is a collection of information about a task needed by a
    Pipeline.

    The information includes task name, configuration object and optional
    task class. This class is just a collection of attributes and it exposes
    all of them so that attributes could potentially be modified in place
    (e.g. if configuration needs extra overrides).

    Attributes
    ----------
    taskName : `str`
        `PipelineTask` class name, currently it is not specified whether this
        is a fully-qualified name or partial name (e.g. ``module.TaskClass``).
        Framework should be prepared to handle all cases.
    config : `lsst.pex.config.Config`
        Instance of the configuration class corresponding to this task class,
        usually with all overrides applied. This config will be frozen.
    taskClass : `type` or ``None``
        `PipelineTask` class object, can be ``None``. If ``None`` then
        framework will have to locate and load class.
    label : `str`, optional
        Task label, usually a short string unique in a pipeline.
    """
    def __init__(self, taskName, config, taskClass=None, label=""):
        self.taskName = taskName
        config.freeze()
        self.config = config
        self.taskClass = taskClass
        self.label = label
        self.connections = config.connections.ConnectionsClass(config=config)

    @property
    def configDatasetName(self) -> str:
        """Name of a dataset type for configuration of this task (`str`)
        """
        return self.label + "_config"

    @property
    def metadataDatasetName(self) -> Optional[str]:
        """Name of a dataset type for metadata of this task, `None` if
        metadata is not to be saved (`str`)
        """
        if self.config.saveMetadata:
            return self.label + "_metadata"
        else:
            return None

    def __str__(self):
        rep = "TaskDef(" + self.taskName
        if self.label:
            rep += ", label=" + self.label
        rep += ")"
        return rep

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, TaskDef):
            return False
        # This does not consider equality of configs when determining equality
        # as config equality is a difficult thing to define. Should be updated
        # after DM-27847
        return self.taskClass == other.taskClass and self.label == other.label

    def __hash__(self):
        return hash((self.taskClass, self.label))


class Pipeline:
    """A `Pipeline` is a representation of a series of tasks to run, and the
    configuration for those tasks.

    Parameters
    ----------
    description : `str`
        A description of what this pipeline does.
    """
    def __init__(self, description: str):
        pipeline_dict = {"description": description, "tasks": {}}
        self._pipelineIR = pipelineIR.PipelineIR(pipeline_dict)

    @classmethod
    def fromFile(cls, filename: str) -> Pipeline:
        """Load a pipeline defined in a pipeline yaml file.

        Parameters
        ----------
        filename: `str`
            A path that points to a pipeline defined in yaml format. This
            filename may also supply additional labels to be used in
            subsetting the loaded Pipeline. These labels are separated from
            the path by a \\#, and may be specified as a comma separated
            list, or a range denoted as beginning..end. Beginning or end may
            be empty, in which case the range will be a half open interval.
            Unlike python iteration bounds, end bounds are *INCLUDED*. Note
            that range based selection is not well defined for pipelines that
            are not linear in nature, and correct behavior is not guaranteed,
            or may vary from run to run.

        Returns
        -------
        pipeline: `Pipeline`
            The pipeline loaded from the specified location with appropriate
            (if any) subsetting.

        Notes
        -----
        This method attempts to prune any contracts that contain labels which
        are not in the declared subset of labels. This pruning is done using a
        string based matching due to the nature of contracts and may prune more
        than it should.
        """
        return cls.from_uri(filename)

    @classmethod
    def from_uri(cls, uri: Union[str, ButlerURI]) -> Pipeline:
        """Load a pipeline defined in a pipeline yaml file at a location
        specified by a URI.

        Parameters
        ----------
        uri: `str` or `ButlerURI`
            If a string is supplied this should be a URI path that points to a
            pipeline defined in yaml format. This uri may also supply
            additional labels to be used in subsetting the loaded Pipeline.
            These labels are separated from the path by a \\#, and may be
            specified as a comma separated list, or a range denoted as
            beginning..end. Beginning or end may be empty, in which case the
            range will be a half open interval. Unlike python iteration
            bounds, end bounds are *INCLUDED*. Note that range based selection
            is not well defined for pipelines that are not linear in nature,
            and correct behavior is not guaranteed, or may vary from run to
            run. The same specifiers can be used with a ButlerURI object, by
            being the sole contents in the fragments attribute.

        Returns
        -------
        pipeline: `Pipeline`
            The pipeline loaded from the specified location with appropriate
            (if any) subsetting.

        Notes
        -----
        This method attempts to prune any contracts that contain labels which
        are not in the declared subset of labels. This pruning is done using a
        string based matching due to the nature of contracts and may prune more
        than it should.
        """
        # Split up the uri and any labels that were supplied
        uri, label_specifier = cls._parse_file_specifier(uri)
        pipeline: Pipeline = cls.fromIR(pipelineIR.PipelineIR.from_uri(uri))

        # If there are labels supplied, only keep those
        if label_specifier is not None:
            pipeline = pipeline.subsetFromLabels(label_specifier)
        return pipeline
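
    # Illustrative calls (a hedged sketch; the file path and labels are
    # hypothetical). The fragment after "#" selects a subset of task labels,
    # either as a comma separated list or as an inclusive "begin..end" range:
    #
    #     p1 = Pipeline.from_uri("/path/to/DRP.yaml#isr,calibrate")
    #     p2 = Pipeline.from_uri("/path/to/DRP.yaml#isr..characterizeImage")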

    def subsetFromLabels(self, labelSpecifier: LabelSpecifier) -> Pipeline:
        """Subset a pipeline to contain only labels specified in labelSpecifier

        Parameters
        ----------
        labelSpecifier : `LabelSpecifier`
            Object containing labels that describes how to subset a pipeline.

        Returns
        -------
        pipeline : `Pipeline`
            A new pipeline object that is a subset of the old pipeline

        Raises
        ------
        ValueError
            Raised if there is an issue with specified labels

        Notes
        -----
        This method attempts to prune any contracts that contain labels which
        are not in the declared subset of labels. This pruning is done using a
        string based matching due to the nature of contracts and may prune more
        than it should.
        """
        # Labels supplied as a set
        if labelSpecifier.labels:
            labelSet = labelSpecifier.labels
        # Labels supplied as a range, first create a list of all the labels
        # in the pipeline sorted according to task dependency. Then only
        # keep labels that lie between the supplied bounds
        else:
            # Create a copy of the pipeline to use when assessing the label
            # ordering. Use a dict for fast searching while preserving order.
            # Remove contracts so they do not fail in the expansion step. This
            # is needed because a user may only configure the tasks they intend
            # to run, which may cause some contracts to fail if they will later
            # be dropped
            pipeline = copy.deepcopy(self)
            pipeline._pipelineIR.contracts = []
            labels = {taskdef.label: True for taskdef in pipeline.toExpandedPipeline()}

            # Verify the bounds are in the labels
            if labelSpecifier.begin is not None:
                if labelSpecifier.begin not in labels:
                    raise ValueError(f"Beginning of range subset, {labelSpecifier.begin}, not found in "
                                     "pipeline definition")
            if labelSpecifier.end is not None:
                if labelSpecifier.end not in labels:
                    raise ValueError(f"End of range subset, {labelSpecifier.end}, not found in pipeline "
                                     "definition")

            labelSet = set()
            for label in labels:
                if labelSpecifier.begin is not None:
                    if label != labelSpecifier.begin:
                        continue
                    else:
                        labelSpecifier.begin = None
                labelSet.add(label)
                if labelSpecifier.end is not None and label == labelSpecifier.end:
                    break
        return Pipeline.fromIR(self._pipelineIR.subset_from_labels(labelSet))
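
    # Illustrative subsetting (a hedged sketch; the file path and labels are
    # hypothetical):
    #
    #     full = Pipeline.fromFile("/path/to/DRP.yaml")
    #     sub = full.subsetFromLabels(LabelSpecifier(labels={"isr", "calibrate"}))
    #     # or keep everything from the start of the ordered pipeline up to
    #     # and including "characterizeImage":
    #     sub = full.subsetFromLabels(LabelSpecifier(end="characterizeImage"))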

    @staticmethod
    def _parse_file_specifier(uri: Union[str, ButlerURI]
                              ) -> Tuple[ButlerURI, Optional[LabelSpecifier]]:
        """Split apart a uri and any possible label subsets
        """
        if isinstance(uri, str):
            # This is to support legacy pipelines during transition
            uri, num_replace = re.subn("[:](?!\\/\\/)", "#", uri)
            if num_replace:
                warnings.warn(f"The pipeline file {uri} seems to use the legacy : to separate "
                              "labels, this is deprecated and will be removed after June 2021, please use "
                              "# instead.",
                              category=FutureWarning)
            if uri.count("#") > 1:
                raise ValueError("Only one set of labels is allowed when specifying a pipeline to load")
            uri = ButlerURI(uri)
        label_subset = uri.fragment or None

        specifier: Optional[LabelSpecifier]
        if label_subset is not None:
            label_subset = urllib.parse.unquote(label_subset)
            args: Dict[str, Union[Set[str], str, None]]
            # labels supplied as a list
            if ',' in label_subset:
                if '..' in label_subset:
                    raise ValueError("Can only specify a list of labels or a range "
                                     "when loading a Pipeline, not both")
                args = {"labels": set(label_subset.split(","))}
            # labels supplied as a range
            elif '..' in label_subset:
                # Try to de-structure the labelSubset, this will fail if more
                # than one range is specified
                begin, end, *rest = label_subset.split("..")
                if rest:
                    raise ValueError("Only one range can be specified when loading a pipeline")
                args = {"begin": begin if begin else None, "end": end if end else None}
            # Assume anything else is a single label
            else:
                args = {"labels": {label_subset}}

            specifier = LabelSpecifier(**args)
        else:
            specifier = None

        return uri, specifier
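
    # Expected behaviour of the fragment parser above (a hedged sketch; paths
    # and labels are hypothetical):
    #
    #     "pipe.yaml#a,b"    -> LabelSpecifier(labels={"a", "b"})
    #     "pipe.yaml#a..c"   -> LabelSpecifier(begin="a", end="c")
    #     "pipe.yaml#a"      -> LabelSpecifier(labels={"a"})
    #     "pipe.yaml"        -> specifier is None
    #     "pipe.yaml#a,b..c" -> ValueError (list and range cannot be mixed)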

    @classmethod
    def fromString(cls, pipeline_string: str) -> Pipeline:
        """Create a pipeline from string formatted as a pipeline document.

        Parameters
        ----------
        pipeline_string : `str`
            A string that is formatted like a pipeline document

        Returns
        -------
        pipeline: `Pipeline`
        """
        pipeline = cls.fromIR(pipelineIR.PipelineIR.from_string(pipeline_string))
        return pipeline
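
    # Illustrative use of fromString (a hedged sketch; the document layout and
    # the task class path shown are illustrative assumptions, not a schema
    # reference):
    #
    #     pipeline = Pipeline.fromString(
    #         "description: tiny example\n"
    #         "tasks:\n"
    #         "  isr:\n"
    #         "    class: lsst.ip.isr.IsrTask\n"
    #     )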

    @classmethod
    def fromIR(cls, deserialized_pipeline: pipelineIR.PipelineIR) -> Pipeline:
        """Create a pipeline from an already created `PipelineIR` object.

        Parameters
        ----------
        deserialized_pipeline: `PipelineIR`
            An already created pipeline intermediate representation object

        Returns
        -------
        pipeline: `Pipeline`
        """
        pipeline = cls.__new__(cls)
        pipeline._pipelineIR = deserialized_pipeline
        return pipeline

    @classmethod
    def fromPipeline(cls, pipeline: Pipeline) -> Pipeline:
        """Create a new pipeline by copying an already existing `Pipeline`.

        Parameters
        ----------
        pipeline: `Pipeline`
            An already created `Pipeline` object to copy.

        Returns
        -------
        pipeline: `Pipeline`
        """
        return cls.fromIR(copy.deepcopy(pipeline._pipelineIR))

    def __str__(self) -> str:
        return str(self._pipelineIR)

    def addInstrument(self, instrument: Union[Instrument, str]) -> None:
        """Add an instrument to the pipeline, or replace an instrument that is
        already defined.

        Parameters
        ----------
        instrument : `~lsst.daf.butler.instrument.Instrument` or `str`
            Either a derived class object of a `lsst.daf.butler.instrument` or
            a string corresponding to a fully qualified
            `lsst.daf.butler.instrument` name.
        """
        if isinstance(instrument, str):
            pass
        else:
            # TODO: assume that this is a subclass of Instrument, no type
            # checking
            instrument = f"{instrument.__module__}.{instrument.__qualname__}"
        self._pipelineIR.instrument = instrument

    def getInstrument(self) -> Instrument:
        """Get the instrument from the pipeline.

        Returns
        -------
        instrument : `~lsst.daf.butler.instrument.Instrument`, `str`, or None
            A derived class object of a `lsst.daf.butler.instrument`, a string
            corresponding to a fully qualified `lsst.daf.butler.instrument`
            name, or None if the pipeline does not have an instrument.
        """
        return self._pipelineIR.instrument

    def addTask(self, task: Union[PipelineTask, str], label: str) -> None:
        """Add a new task to the pipeline, or replace a task that is already
        associated with the supplied label.

        Parameters
        ----------
        task: `PipelineTask` or `str`
            Either a derived class object of a `PipelineTask` or a string
            corresponding to a fully qualified `PipelineTask` name.
        label: `str`
            A label that is used to identify the `PipelineTask` being added
        """
        if isinstance(task, str):
            taskName = task
        elif issubclass(task, PipelineTask):
            taskName = f"{task.__module__}.{task.__qualname__}"
        else:
            raise ValueError("task must be either a child class of PipelineTask or a string containing"
                             " a fully qualified name to one")
        if not label:
            # in some cases (with command line-generated pipeline) tasks can
            # be defined without label which is not acceptable, use task
            # _DefaultName in that case
            if isinstance(task, str):
                task = doImport(task)
            label = task._DefaultName
        self._pipelineIR.tasks[label] = pipelineIR.TaskIR(label, taskName)
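
    # Illustrative programmatic construction (a hedged sketch; the task class
    # path, imported class and labels are hypothetical):
    #
    #     pipeline = Pipeline("example pipeline")
    #     pipeline.addTask("lsst.ip.isr.IsrTask", "isr")
    #     pipeline.addTask(CharacterizeImageTask, "characterizeImage")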

    def removeTask(self, label: str) -> None:
        """Remove a task from the pipeline.

        Parameters
        ----------
        label : `str`
            The label used to identify the task that is to be removed

        Raises
        ------
        KeyError
            If no task with that label exists in the pipeline

        """
        self._pipelineIR.tasks.pop(label)

    def addConfigOverride(self, label: str, key: str, value: object) -> None:
        """Apply single config override.

        Parameters
        ----------
        label : `str`
            Label of the task.
        key: `str`
            Fully-qualified field name.
        value : object
            Value to be given to a field.
        """
        self._addConfigImpl(label, pipelineIR.ConfigIR(rest={key: value}))

    def addConfigFile(self, label: str, filename: str) -> None:
        """Add overrides from a specified file.

        Parameters
        ----------
        label : `str`
            The label used to identify the task associated with config to
            modify
        filename : `str`
            Path to the override file.
        """
        self._addConfigImpl(label, pipelineIR.ConfigIR(file=[filename]))

    def addConfigPython(self, label: str, pythonString: str) -> None:
        """Add overrides by running a snippet of python code against a config.

        Parameters
        ----------
        label : `str`
            The label used to identify the task associated with config to
            modify.
        pythonString: `str`
            A string which is valid python code to be executed. This is done
            with config as the only local accessible value.
        """
        self._addConfigImpl(label, pipelineIR.ConfigIR(python=pythonString))

    def _addConfigImpl(self, label: str, newConfig: pipelineIR.ConfigIR) -> None:
        if label == "parameters":
            if newConfig.rest.keys() - self._pipelineIR.parameters.mapping.keys():
                raise ValueError("Cannot override parameters that are not defined in pipeline")
            self._pipelineIR.parameters.mapping.update(newConfig.rest)
            if newConfig.file:
                raise ValueError("Setting parameters section with config file is not supported")
            if newConfig.python:
                raise ValueError("Setting parameters section using python block is not supported")
            return
        if label not in self._pipelineIR.tasks:
            raise LookupError(f"There are no tasks labeled '{label}' in the pipeline")
        self._pipelineIR.tasks[label].add_or_update_config(newConfig)
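
    # Illustrative config overrides (a hedged sketch; labels, field names and
    # file paths are hypothetical):
    #
    #     pipeline.addConfigOverride("isr", "doDark", False)
    #     pipeline.addConfigFile("isr", "/path/to/isrOverrides.py")
    #     pipeline.addConfigPython("isr", "config.doFlat = False")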

    def toFile(self, filename: str) -> None:
        self._pipelineIR.to_file(filename)

    def write_to_uri(self, uri: Union[str, ButlerURI]) -> None:
        self._pipelineIR.write_to_uri(uri)

    def toExpandedPipeline(self) -> Generator[TaskDef, None, None]:
        """Returns a generator of TaskDefs which can be used to create quantum
        graphs.

        Returns
        -------
        generator : generator of `TaskDef`
            The generator returned will be the sorted iterator of tasks which
            are to be used in constructing a quantum graph.

        Raises
        ------
        NotImplementedError
            If a dataId is supplied in a config block. This is in place for
            future use
        """
        taskDefs = []
        for label, taskIR in self._pipelineIR.tasks.items():
            taskClass = doImport(taskIR.klass)
            taskName = taskClass.__qualname__
            config = taskClass.ConfigClass()
            overrides = ConfigOverrides()
            if self._pipelineIR.instrument is not None:
                overrides.addInstrumentOverride(self._pipelineIR.instrument, taskClass._DefaultName)
            if taskIR.config is not None:
                for configIR in (configIr.formatted(self._pipelineIR.parameters)
                                 for configIr in taskIR.config):
                    if configIR.dataId is not None:
                        raise NotImplementedError("Specializing a config on a partial data id is not yet "
                                                  "supported in Pipeline definition")
                    # only apply override if it applies to everything
                    if configIR.dataId is None:
                        if configIR.file:
                            for configFile in configIR.file:
                                overrides.addFileOverride(os.path.expandvars(configFile))
                        if configIR.python is not None:
                            overrides.addPythonOverride(configIR.python)
                        for key, value in configIR.rest.items():
                            overrides.addValueOverride(key, value)
            overrides.applyTo(config)
            # This may need to be revisited
            config.validate()
            taskDefs.append(TaskDef(taskName=taskName, config=config, taskClass=taskClass, label=label))

        # let's evaluate the contracts
        if self._pipelineIR.contracts is not None:
            label_to_config = {x.label: x.config for x in taskDefs}
            for contract in self._pipelineIR.contracts:
                # execute this in its own line so it can raise a good error
                # message if there were problems with the eval
                success = eval(contract.contract, None, label_to_config)
                if not success:
                    extra_info = f": {contract.msg}" if contract.msg is not None else ""
                    raise pipelineIR.ContractError(f"Contract(s) '{contract.contract}' were not "
                                                   f"satisfied{extra_info}")

        yield from pipeTools.orderPipeline(taskDefs)
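
    # Illustrative expansion (a hedged sketch): iterate the dependency-ordered
    # TaskDefs, with all overrides applied and contracts checked.
    #
    #     for taskDef in pipeline.toExpandedPipeline():
    #         print(taskDef.label, taskDef.taskName)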

    def __len__(self):
        return len(self._pipelineIR.tasks)

    def __eq__(self, other: object):
        if not isinstance(other, Pipeline):
            return False
        return self._pipelineIR == other._pipelineIR


@dataclass(frozen=True)
class TaskDatasetTypes:
    """An immutable struct that extracts and classifies the dataset types used
    by a `PipelineTask`
    """

    initInputs: NamedValueSet[DatasetType]
    """Dataset types that are needed as inputs in order to construct this Task.

    Task-level `initInputs` may be classified as either
    `~PipelineDatasetTypes.initInputs` or
    `~PipelineDatasetTypes.initIntermediates` at the Pipeline level.
    """

    initOutputs: NamedValueSet[DatasetType]
    """Dataset types that may be written after constructing this Task.

    Task-level `initOutputs` may be classified as either
    `~PipelineDatasetTypes.initOutputs` or
    `~PipelineDatasetTypes.initIntermediates` at the Pipeline level.
    """

    inputs: NamedValueSet[DatasetType]
    """Dataset types that are regular inputs to this Task.

    If an input dataset needed for a Quantum cannot be found in the input
    collection(s) or produced by another Task in the Pipeline, that Quantum
    (and all dependent Quanta) will not be produced.

    Task-level `inputs` may be classified as either
    `~PipelineDatasetTypes.inputs` or `~PipelineDatasetTypes.intermediates`
    at the Pipeline level.
    """

    prerequisites: NamedValueSet[DatasetType]
    """Dataset types that are prerequisite inputs to this Task.

    Prerequisite inputs must exist in the input collection(s) before the
    pipeline is run, but do not constrain the graph - if a prerequisite is
    missing for a Quantum, `PrerequisiteMissingError` is raised.

    Prerequisite inputs are not resolved until the second stage of
    QuantumGraph generation.
    """

    outputs: NamedValueSet[DatasetType]
    """Dataset types that are produced by this Task.

    Task-level `outputs` may be classified as either
    `~PipelineDatasetTypes.outputs` or `~PipelineDatasetTypes.intermediates`
    at the Pipeline level.
    """

    @classmethod
    def fromTaskDef(cls, taskDef: TaskDef, *, registry: Registry) -> TaskDatasetTypes:
        """Extract and classify the dataset types from a single `PipelineTask`.

        Parameters
        ----------
        taskDef: `TaskDef`
            An instance of a `TaskDef` class for a particular `PipelineTask`.
        registry: `Registry`
            Registry used to construct normalized `DatasetType` objects and
            retrieve those that are incomplete.

        Returns
        -------
        types: `TaskDatasetTypes`
            The dataset types used by this task.
        """
        def makeDatasetTypesSet(connectionType: str, freeze: bool = True) -> NamedValueSet[DatasetType]:
            """Constructs a set of true `DatasetType` objects

            Parameters
            ----------
            connectionType : `str`
                Name of the connection type to produce a set for, corresponds
                to an attribute of type `list` on the connection class instance
            freeze : `bool`, optional
                If `True`, call `NamedValueSet.freeze` on the object returned.

            Returns
            -------
            datasetTypes : `NamedValueSet`
                A set of all datasetTypes which correspond to the input
                connection type specified in the connection class of this
                `PipelineTask`

            Notes
            -----
            This function is a closure over the variables ``registry`` and
            ``taskDef``.
            """
            datasetTypes = NamedValueSet()
            for c in iterConnections(taskDef.connections, connectionType):
                dimensions = set(getattr(c, 'dimensions', set()))
                if "skypix" in dimensions:
                    try:
                        datasetType = registry.getDatasetType(c.name)
                    except LookupError as err:
                        raise LookupError(
                            f"DatasetType '{c.name}' referenced by "
                            f"{type(taskDef.connections).__name__} uses 'skypix' as a dimension "
                            f"placeholder, but does not already exist in the registry. "
                            f"Note that reference catalog names are now used as the dataset "
                            f"type name instead of 'ref_cat'."
                        ) from err
                    rest1 = set(registry.dimensions.extract(dimensions - set(["skypix"])).names)
                    rest2 = set(dim.name for dim in datasetType.dimensions
                                if not isinstance(dim, SkyPixDimension))
                    if rest1 != rest2:
                        raise ValueError(f"Non-skypix dimensions for dataset type {c.name} declared in "
                                         f"connections ({rest1}) are inconsistent with those in "
                                         f"registry's version of this dataset ({rest2}).")
                else:
                    # Component dataset types are not explicitly in the
                    # registry. This complicates consistency checks with
                    # registry and requires we work out the composite storage
                    # class.
                    registryDatasetType = None
                    try:
                        registryDatasetType = registry.getDatasetType(c.name)
                    except KeyError:
                        compositeName, componentName = DatasetType.splitDatasetTypeName(c.name)
                        parentStorageClass = DatasetType.PlaceholderParentStorageClass \
                            if componentName else None
                        datasetType = c.makeDatasetType(
                            registry.dimensions,
                            parentStorageClass=parentStorageClass
                        )
                        registryDatasetType = datasetType
                    else:
                        datasetType = c.makeDatasetType(
                            registry.dimensions,
                            parentStorageClass=registryDatasetType.parentStorageClass
                        )

                    if registryDatasetType and datasetType != registryDatasetType:
                        raise ValueError(f"Supplied dataset type ({datasetType}) inconsistent with "
                                         f"registry definition ({registryDatasetType}) "
                                         f"for {taskDef.label}.")
                datasetTypes.add(datasetType)
            if freeze:
                datasetTypes.freeze()
            return datasetTypes

        # optionally add output dataset for metadata
        outputs = makeDatasetTypesSet("outputs", freeze=False)
        if taskDef.metadataDatasetName is not None:
            # Metadata is supposed to be of the PropertySet type, its
            # dimensions correspond to a task quantum
            dimensions = registry.dimensions.extract(taskDef.connections.dimensions)
            outputs |= {DatasetType(taskDef.metadataDatasetName, dimensions, "PropertySet")}
        outputs.freeze()

        return cls(
            initInputs=makeDatasetTypesSet("initInputs"),
            initOutputs=makeDatasetTypesSet("initOutputs"),
            inputs=makeDatasetTypesSet("inputs"),
            prerequisites=makeDatasetTypesSet("prerequisiteInputs"),
            outputs=outputs,
        )
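
    # Illustrative classification of a single task (a hedged sketch; obtaining
    # a Butler and an expanded pipeline is outside the scope of this module):
    #
    #     taskDef = next(iter(pipeline.toExpandedPipeline()))
    #     types = TaskDatasetTypes.fromTaskDef(taskDef, registry=butler.registry)
    #     print(types.inputs.names, types.outputs.names)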


@dataclass(frozen=True)
class PipelineDatasetTypes:
    """An immutable struct that classifies the dataset types used in a
    `Pipeline`.
    """

    initInputs: NamedValueSet[DatasetType]
    """Dataset types that are needed as inputs in order to construct the Tasks
    in this Pipeline.

    This does not include dataset types that are produced when constructing
    other Tasks in the Pipeline (these are classified as `initIntermediates`).
    """

    initOutputs: NamedValueSet[DatasetType]
    """Dataset types that may be written after constructing the Tasks in this
    Pipeline.

    This does not include dataset types that are also used as inputs when
    constructing other Tasks in the Pipeline (these are classified as
    `initIntermediates`).
    """

    initIntermediates: NamedValueSet[DatasetType]
    """Dataset types that are both used when constructing one or more Tasks
    in the Pipeline and produced as a side-effect of constructing another
    Task in the Pipeline.
    """

    inputs: NamedValueSet[DatasetType]
    """Dataset types that are regular inputs for the full pipeline.

    If an input dataset needed for a Quantum cannot be found in the input
    collection(s), that Quantum (and all dependent Quanta) will not be
    produced.
    """

    prerequisites: NamedValueSet[DatasetType]
    """Dataset types that are prerequisite inputs for the full Pipeline.

    Prerequisite inputs must exist in the input collection(s) before the
    pipeline is run, but do not constrain the graph - if a prerequisite is
    missing for a Quantum, `PrerequisiteMissingError` is raised.

    Prerequisite inputs are not resolved until the second stage of
    QuantumGraph generation.
    """

    intermediates: NamedValueSet[DatasetType]
    """Dataset types that are output by one Task in the Pipeline and consumed
    as inputs by one or more other Tasks in the Pipeline.
    """

    outputs: NamedValueSet[DatasetType]
    """Dataset types that are output by a Task in the Pipeline and not consumed
    by any other Task in the Pipeline.
    """

    byTask: Mapping[str, TaskDatasetTypes]
    """Per-Task dataset types, keyed by label in the `Pipeline`.

    This is guaranteed to be zip-iterable with the `Pipeline` itself (assuming
    neither has been modified since the dataset types were extracted, of
    course).
    """

    @classmethod
    def fromPipeline(cls, pipeline, *, registry: Registry) -> PipelineDatasetTypes:
        """Extract and classify the dataset types from all tasks in a
        `Pipeline`.

        Parameters
        ----------
        pipeline: `Pipeline`
            An ordered collection of tasks that can be run together.
        registry: `Registry`
            Registry used to construct normalized `DatasetType` objects and
            retrieve those that are incomplete.

        Returns
        -------
        types: `PipelineDatasetTypes`
            The dataset types used by this `Pipeline`.

        Raises
        ------
        ValueError
            Raised if Tasks are inconsistent about which datasets are marked
            prerequisite. This indicates that the Tasks cannot be run as part
            of the same `Pipeline`.
        """
        allInputs = NamedValueSet()
        allOutputs = NamedValueSet()
        allInitInputs = NamedValueSet()
        allInitOutputs = NamedValueSet()
        prerequisites = NamedValueSet()
        byTask = dict()
        if isinstance(pipeline, Pipeline):
            pipeline = pipeline.toExpandedPipeline()
        for taskDef in pipeline:
            thisTask = TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
            allInitInputs |= thisTask.initInputs
            allInitOutputs |= thisTask.initOutputs
            allInputs |= thisTask.inputs
            prerequisites |= thisTask.prerequisites
            allOutputs |= thisTask.outputs
            byTask[taskDef.label] = thisTask
        if not prerequisites.isdisjoint(allInputs):
            raise ValueError("{} marked as both prerequisites and regular inputs".format(
                {dt.name for dt in allInputs & prerequisites}
            ))
        if not prerequisites.isdisjoint(allOutputs):
            raise ValueError("{} marked as both prerequisites and outputs".format(
                {dt.name for dt in allOutputs & prerequisites}
            ))
        # Make sure that components which are marked as inputs get treated as
        # intermediates if there is an output which produces the composite
        # containing the component
        intermediateComponents = NamedValueSet()
        intermediateComposites = NamedValueSet()
        outputNameMapping = {dsType.name: dsType for dsType in allOutputs}
        for dsType in allInputs:
            # get the name of a possible component
            name, component = dsType.nameAndComponent()
            # if there is a component name, that means this is a component
            # DatasetType, if there is an output which produces the parent of
            # this component, treat this input as an intermediate
            if component is not None:
                if name in outputNameMapping:
                    if outputNameMapping[name].dimensions != dsType.dimensions:
                        raise ValueError(f"Component dataset type {dsType.name} has different "
                                         f"dimensions ({dsType.dimensions}) than its parent "
                                         f"({outputNameMapping[name].dimensions}).")
                    composite = DatasetType(name, dsType.dimensions, outputNameMapping[name].storageClass,
                                            universe=registry.dimensions)
                    intermediateComponents.add(dsType)
                    intermediateComposites.add(composite)

        def checkConsistency(a: NamedValueSet, b: NamedValueSet):
            common = a.names & b.names
            for name in common:
                if a[name] != b[name]:
                    raise ValueError(f"Conflicting definitions for dataset type: {a[name]} != {b[name]}.")

        checkConsistency(allInitInputs, allInitOutputs)
        checkConsistency(allInputs, allOutputs)
        checkConsistency(allInputs, intermediateComposites)
        checkConsistency(allOutputs, intermediateComposites)

        def frozen(s: NamedValueSet) -> NamedValueSet:
            s.freeze()
            return s

        return cls(
            initInputs=frozen(allInitInputs - allInitOutputs),
            initIntermediates=frozen(allInitInputs & allInitOutputs),
            initOutputs=frozen(allInitOutputs - allInitInputs),
            inputs=frozen(allInputs - allOutputs - intermediateComponents),
            intermediates=frozen(allInputs & allOutputs | intermediateComponents),
            outputs=frozen(allOutputs - allInputs - intermediateComposites),
            prerequisites=frozen(prerequisites),
            byTask=MappingProxyType(byTask),  # MappingProxyType -> frozen view of dict for immutability
        )
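

# Illustrative classification of a whole pipeline (a hedged sketch; obtaining a
# Butler registry is outside the scope of this module). "inputs" are dataset
# types consumed but never produced by the pipeline, "outputs" are produced but
# never consumed, and "intermediates" are both produced and consumed within it:
#
#     datasetTypes = PipelineDatasetTypes.fromPipeline(pipeline, registry=butler.registry)
#     overall_inputs = datasetTypes.inputs    # must pre-exist in input collections
#     overall_outputs = datasetTypes.outputs  # produced, not consumed downstream
#     per_task = datasetTypes.byTask["isr"].outputs  # per-label breakdown ("isr" is hypothetical)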