Coverage for python/lsst/pipe/base/pipeline.py : 19%

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

"""Module defining Pipeline class and related methods.
"""

__all__ = ["Pipeline", "TaskDef", "TaskDatasetTypes", "PipelineDatasetTypes", "LabelSpecifier"]

# -------------------------------
# Imports of standard modules --
# -------------------------------
from dataclasses import dataclass
from types import MappingProxyType
from typing import Dict, Iterable, Mapping, Set, Union, Generator, TYPE_CHECKING, Optional, Tuple

import copy
import re
import os
import urllib.parse
import warnings

# -----------------------------
# Imports for other modules --
from lsst.daf.butler import DatasetType, NamedValueSet, Registry, SkyPixDimension, ButlerURI
from lsst.utils import doImport
from .configOverrides import ConfigOverrides
from .connections import iterConnections
from .pipelineTask import PipelineTask

from . import pipelineIR
from . import pipeTools

if TYPE_CHECKING:  # Imports needed only for type annotations; may be circular.
    from lsst.obs.base import Instrument

# ----------------------------------
# Local non-exported definitions --
# ----------------------------------

# ------------------------
# Exported definitions --
# ------------------------


@dataclass
class LabelSpecifier:
    """A structure to specify a subset of labels to load

    This structure may contain a set of labels to be used in subsetting a
    pipeline, or a beginning and end point. Beginning or end may be empty,
    in which case the range will be a half open interval. Unlike python
    iteration bounds, end bounds are *INCLUDED*. Note that range based
    selection is not well defined for pipelines that are not linear in nature,
    and correct behavior is not guaranteed, or may vary from run to run.
    """
    labels: Optional[Set[str]] = None
    begin: Optional[str] = None
    end: Optional[str] = None

    def __post_init__(self):
        if self.labels is not None and (self.begin or self.end):
            raise ValueError("This struct can only be initialized with a labels set or "
                             "a begin (and/or) end specifier")


class TaskDef:
    """TaskDef is a collection of information about a task needed by Pipeline.

    The information includes task name, configuration object and optional
    task class. This class is just a collection of attributes and it exposes
    all of them so that attributes could potentially be modified in place
    (e.g. if configuration needs extra overrides).

    Attributes
    ----------
    taskName : `str`
        `PipelineTask` class name, currently it is not specified whether this
        is a fully-qualified name or partial name (e.g. ``module.TaskClass``).
        Framework should be prepared to handle all cases.
    config : `lsst.pex.config.Config`
        Instance of the configuration class corresponding to this task class,
        usually with all overrides applied. This config will be frozen.
    taskClass : `type` or ``None``
        `PipelineTask` class object, can be ``None``. If ``None`` then
        framework will have to locate and load class.
    label : `str`, optional
        Task label, usually a short string unique in a pipeline.
    """
    def __init__(self, taskName, config, taskClass=None, label=""):
        self.taskName = taskName
        config.freeze()
        self.config = config
        self.taskClass = taskClass
        self.label = label
        self.connections = config.connections.ConnectionsClass(config=config)

    @property
    def configDatasetName(self) -> str:
        """Name of a dataset type for configuration of this task (`str`)
        """
        return self.label + "_config"

    @property
    def metadataDatasetName(self) -> Optional[str]:
        """Name of a dataset type for metadata of this task, `None` if
        metadata is not to be saved (`str`)
        """
        if self.config.saveMetadata:
            return self.label + "_metadata"
        else:
            return None

    def __str__(self):
        rep = "TaskDef(" + self.taskName
        if self.label:
            rep += ", label=" + self.label
        rep += ")"
        return rep

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, TaskDef):
            return False
        # This does not consider equality of configs when determining equality
        # as config equality is a difficult thing to define. Should be updated
        # after DM-27847
        return self.taskClass == other.taskClass and self.label == other.label

    def __hash__(self):
        return hash((self.taskClass, self.label))
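

# Example: an illustrative sketch (not in the original module) of how TaskDef
# derives per-task dataset type names from the label. ``MyTask`` and its
# module path are hypothetical placeholders for a real PipelineTask subclass.
def _example_task_def():
    from my_package.my_module import MyTask  # hypothetical PipelineTask
    config = MyTask.ConfigClass()
    taskDef = TaskDef("my_package.my_module.MyTask", config,
                      taskClass=MyTask, label="myTask")
    assert taskDef.configDatasetName == "myTask_config"
    # metadataDatasetName is "myTask_metadata" when config.saveMetadata is
    # True, otherwise None.
    return taskDef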


class Pipeline:
    """A `Pipeline` is a representation of a series of tasks to run, and the
    configuration for those tasks.

    Parameters
    ----------
    description : `str`
        A description of what this pipeline does.
    """
    def __init__(self, description: str):
        pipeline_dict = {"description": description, "tasks": {}}
        self._pipelineIR = pipelineIR.PipelineIR(pipeline_dict)

    @classmethod
    def fromFile(cls, filename: str) -> Pipeline:
        """Load a pipeline defined in a pipeline yaml file.

        Parameters
        ----------
        filename: `str`
            A path that points to a pipeline defined in yaml format. This
            filename may also supply additional labels to be used in
            subsetting the loaded Pipeline. These labels are separated from
            the path by a \\#, and may be specified as a comma separated
            list, or a range denoted as beginning..end. Beginning or end may
            be empty, in which case the range will be a half open interval.
            Unlike python iteration bounds, end bounds are *INCLUDED*. Note
            that range based selection is not well defined for pipelines that
            are not linear in nature, and correct behavior is not guaranteed,
            or may vary from run to run.

        Returns
        -------
        pipeline: `Pipeline`
            The pipeline loaded from specified location with appropriate (if
            any) subsetting

        Notes
        -----
        This method attempts to prune any contracts that contain labels which
        are not in the declared subset of labels. This pruning is done using a
        string based matching due to the nature of contracts and may prune more
        than it should.
        """
        return cls.from_uri(filename)

    @classmethod
    def from_uri(cls, uri: Union[str, ButlerURI]) -> Pipeline:
        """Load a pipeline defined in a pipeline yaml file at a location
        specified by a URI.

        Parameters
        ----------
        uri: `str` or `ButlerURI`
            If a string is supplied this should be a URI path that points to a
            pipeline defined in yaml format. This uri may also supply
            additional labels to be used in subsetting the loaded Pipeline.
            These labels are separated from the path by a \\#, and may be
            specified as a comma separated list, or a range denoted as
            beginning..end. Beginning or end may be empty, in which case the
            range will be a half open interval. Unlike python iteration
            bounds, end bounds are *INCLUDED*. Note that range based selection
            is not well defined for pipelines that are not linear in nature,
            and correct behavior is not guaranteed, or may vary from run to
            run. The same specifiers can be used with a ButlerURI object, by
            being the sole contents in the fragments attribute.

        Returns
        -------
        pipeline: `Pipeline`
            The pipeline loaded from specified location with appropriate (if
            any) subsetting

        Notes
        -----
        This method attempts to prune any contracts that contain labels which
        are not in the declared subset of labels. This pruning is done using a
        string based matching due to the nature of contracts and may prune more
        than it should.
        """
        # Split up the uri and any labels that were supplied
        uri, label_specifier = cls._parse_file_specifier(uri)
        pipeline: Pipeline = cls.fromIR(pipelineIR.PipelineIR.from_uri(uri))

        # If there are labels supplied, only keep those
        if label_specifier is not None:
            pipeline = pipeline.subsetFromLabels(label_specifier)
        return pipeline

    def subsetFromLabels(self, labelSpecifier: LabelSpecifier) -> Pipeline:
        """Subset a pipeline to contain only labels specified in labelSpecifier

        Parameters
        ----------
        labelSpecifier : `LabelSpecifier`
            Object containing labels that describes how to subset a pipeline.

        Returns
        -------
        pipeline : `Pipeline`
            A new pipeline object that is a subset of the old pipeline

        Raises
        ------
        ValueError
            Raised if there is an issue with specified labels

        Notes
        -----
        This method attempts to prune any contracts that contain labels which
        are not in the declared subset of labels. This pruning is done using a
        string based matching due to the nature of contracts and may prune more
        than it should.
        """
        # Labels supplied as a set
        if labelSpecifier.labels:
            labelSet = labelSpecifier.labels
        # Labels supplied as a range, first create a list of all the labels
        # in the pipeline sorted according to task dependency. Then only
        # keep labels that lie between the supplied bounds
        else:
            # Create a copy of the pipeline to use when assessing the label
            # ordering. Use a dict for fast searching while preserving order.
            # Remove contracts so they do not fail in the expansion step. This
            # is needed because a user may only configure the tasks they intend
            # to run, which may cause some contracts to fail if they will later
            # be dropped
            pipeline = copy.deepcopy(self)
            pipeline._pipelineIR.contracts = []
            labels = {taskdef.label: True for taskdef in pipeline.toExpandedPipeline()}

            # Verify the bounds are in the labels
            if labelSpecifier.begin is not None:
                if labelSpecifier.begin not in labels:
                    raise ValueError(f"Beginning of range subset, {labelSpecifier.begin}, not found in "
                                     "pipeline definition")
            if labelSpecifier.end is not None:
                if labelSpecifier.end not in labels:
                    raise ValueError(f"End of range subset, {labelSpecifier.end}, not found in pipeline "
                                     "definition")

            labelSet = set()
            for label in labels:
                if labelSpecifier.begin is not None:
                    if label != labelSpecifier.begin:
                        continue
                    else:
                        labelSpecifier.begin = None
                labelSet.add(label)
                if labelSpecifier.end is not None and label == labelSpecifier.end:
                    break
        return Pipeline.fromIR(self._pipelineIR.subset_from_labels(labelSet))

    @staticmethod
    def _parse_file_specifier(uri: Union[str, ButlerURI]
                              ) -> Tuple[ButlerURI, Optional[LabelSpecifier]]:
        """Split apart a URI and any possible label subsets
        """
        if isinstance(uri, str):
            # This is to support legacy pipelines during transition
            uri, num_replace = re.subn("[:](?!\\/\\/)", "#", uri)
            if num_replace:
                warnings.warn(f"The pipeline file {uri} seems to use the legacy : to separate "
                              "labels, this is deprecated and will be removed after June 2021, please use "
                              "# instead.",
                              category=FutureWarning)
            if uri.count("#") > 1:
                raise ValueError("Only one set of labels is allowed when specifying a pipeline to load")
        uri = ButlerURI(uri)
        label_subset = uri.fragment or None

        specifier: Optional[LabelSpecifier]
        if label_subset is not None:
            label_subset = urllib.parse.unquote(label_subset)
            args: Dict[str, Union[Set[str], str, None]]
            # labels supplied as a list
            if ',' in label_subset:
                if '..' in label_subset:
                    raise ValueError("Can only specify a list of labels or a range "
                                     "when loading a Pipeline, not both")
                args = {"labels": set(label_subset.split(","))}
            # labels supplied as a range
            elif '..' in label_subset:
                # Try to de-structure the labelSubset, this will fail if more
                # than one range is specified
                begin, end, *rest = label_subset.split("..")
                if rest:
                    raise ValueError("Only one range can be specified when loading a pipeline")
                args = {"begin": begin if begin else None, "end": end if end else None}
            # Assume anything else is a single label
            else:
                args = {"labels": {label_subset}}

            specifier = LabelSpecifier(**args)
        else:
            specifier = None

        return uri, specifier

    @classmethod
    def fromString(cls, pipeline_string: str) -> Pipeline:
        """Create a pipeline from string formatted as a pipeline document.

        Parameters
        ----------
        pipeline_string : `str`
            A string that is formatted like a pipeline document

        Returns
        -------
        pipeline: `Pipeline`
        """
        pipeline = cls.fromIR(pipelineIR.PipelineIR.from_string(pipeline_string))
        return pipeline

    @classmethod
    def fromIR(cls, deserialized_pipeline: pipelineIR.PipelineIR) -> Pipeline:
        """Create a pipeline from an already created `PipelineIR` object.

        Parameters
        ----------
        deserialized_pipeline: `PipelineIR`
            An already created pipeline intermediate representation object

        Returns
        -------
        pipeline: `Pipeline`
        """
        pipeline = cls.__new__(cls)
        pipeline._pipelineIR = deserialized_pipeline
        return pipeline

    @classmethod
    def fromPipeline(cls, pipeline: Pipeline) -> Pipeline:
        """Create a new pipeline by copying an already existing `Pipeline`.

        Parameters
        ----------
        pipeline: `Pipeline`
            An already created `Pipeline` object to copy

        Returns
        -------
        pipeline: `Pipeline`
        """
        return cls.fromIR(copy.deepcopy(pipeline._pipelineIR))

    def __str__(self) -> str:
        return str(self._pipelineIR)

    def addInstrument(self, instrument: Union[Instrument, str]) -> None:
        """Add an instrument to the pipeline, or replace an instrument that is
        already defined.

        Parameters
        ----------
        instrument : `~lsst.obs.base.Instrument` or `str`
            Either a derived class object of a `lsst.obs.base.Instrument` or
            a string corresponding to a fully qualified
            `lsst.obs.base.Instrument` name.
        """
        if isinstance(instrument, str):
            pass
        else:
            # TODO: assume that this is a subclass of Instrument, no type
            # checking
            instrument = f"{instrument.__module__}.{instrument.__qualname__}"
        self._pipelineIR.instrument = instrument

    def getInstrument(self) -> Instrument:
        """Get the instrument from the pipeline.

        Returns
        -------
        instrument : `~lsst.obs.base.Instrument`, `str`, or None
            A derived class object of a `lsst.obs.base.Instrument`, a string
            corresponding to a fully qualified `lsst.obs.base.Instrument`
            name, or None if the pipeline does not have an instrument.
        """
        return self._pipelineIR.instrument

    def addTask(self, task: Union[PipelineTask, str], label: str) -> None:
        """Add a new task to the pipeline, or replace a task that is already
        associated with the supplied label.

        Parameters
        ----------
        task: `PipelineTask` or `str`
            Either a derived class object of a `PipelineTask` or a string
            corresponding to a fully qualified `PipelineTask` name.
        label: `str`
            A label that is used to identify the `PipelineTask` being added
        """
        if isinstance(task, str):
            taskName = task
        elif issubclass(task, PipelineTask):
            taskName = f"{task.__module__}.{task.__qualname__}"
        else:
            raise ValueError("task must be either a child class of PipelineTask or a string containing"
                             " a fully qualified name to one")
        if not label:
            # in some cases (with command line-generated pipeline) tasks can
            # be defined without label which is not acceptable, use task
            # _DefaultName in that case
            if isinstance(task, str):
                task = doImport(task)
            label = task._DefaultName
        self._pipelineIR.tasks[label] = pipelineIR.TaskIR(label, taskName)

    def removeTask(self, label: str) -> None:
        """Remove a task from the pipeline.

        Parameters
        ----------
        label : `str`
            The label used to identify the task that is to be removed

        Raises
        ------
        KeyError
            If no task with that label exists in the pipeline

        """
        self._pipelineIR.tasks.pop(label)

    def addConfigOverride(self, label: str, key: str, value: object) -> None:
        """Apply single config override.

        Parameters
        ----------
        label : `str`
            Label of the task.
        key: `str`
            Fully-qualified field name.
        value : object
            Value to be given to a field.
        """
        self._addConfigImpl(label, pipelineIR.ConfigIR(rest={key: value}))

    def addConfigFile(self, label: str, filename: str) -> None:
        """Add overrides from a specified file.

        Parameters
        ----------
        label : `str`
            The label used to identify the task associated with config to
            modify
        filename : `str`
            Path to the override file.
        """
        self._addConfigImpl(label, pipelineIR.ConfigIR(file=[filename]))

    def addConfigPython(self, label: str, pythonString: str) -> None:
        """Add overrides by running a snippet of python code against a config.

        Parameters
        ----------
        label : `str`
            The label used to identify the task associated with config to
            modify.
        pythonString: `str`
            A string which is valid python code to be executed. This is done
            with config as the only local accessible value.
        """
        self._addConfigImpl(label, pipelineIR.ConfigIR(python=pythonString))

    def _addConfigImpl(self, label: str, newConfig: pipelineIR.ConfigIR) -> None:
        if label == "parameters":
            if newConfig.rest.keys() - self._pipelineIR.parameters.mapping.keys():
                raise ValueError("Cannot override parameters that are not defined in pipeline")
            self._pipelineIR.parameters.mapping.update(newConfig.rest)
            if newConfig.file:
                raise ValueError("Setting parameters section with config file is not supported")
            if newConfig.python:
                raise ValueError("Setting parameters section using python block is not supported")
            return
        if label not in self._pipelineIR.tasks:
            raise LookupError(f"There are no tasks labeled '{label}' in the pipeline")
        self._pipelineIR.tasks[label].add_or_update_config(newConfig)

    def toFile(self, filename: str) -> None:
        self._pipelineIR.to_file(filename)

    def write_to_uri(self, uri: Union[str, ButlerURI]) -> None:
        self._pipelineIR.write_to_uri(uri)

    def toExpandedPipeline(self) -> Generator[TaskDef, None, None]:
        """Returns a generator of TaskDefs which can be used to create quantum
        graphs.

        Returns
        -------
        generator : generator of `TaskDef`
            The generator returned will be the sorted iterator of tasks which
            are to be used in constructing a quantum graph.

        Raises
        ------
        NotImplementedError
            If a dataId is supplied in a config block. This is in place for
            future use
        """
        taskDefs = []
        for label, taskIR in self._pipelineIR.tasks.items():
            taskClass = doImport(taskIR.klass)
            taskName = taskClass.__qualname__
            config = taskClass.ConfigClass()
            overrides = ConfigOverrides()
            if self._pipelineIR.instrument is not None:
                overrides.addInstrumentOverride(self._pipelineIR.instrument, taskClass._DefaultName)
            if taskIR.config is not None:
                for configIR in (configIr.formatted(self._pipelineIR.parameters)
                                 for configIr in taskIR.config):
                    if configIR.dataId is not None:
                        raise NotImplementedError("Specializing a config on a partial data id is not yet "
                                                  "supported in Pipeline definition")
                    # only apply override if it applies to everything
                    if configIR.dataId is None:
                        if configIR.file:
                            for configFile in configIR.file:
                                overrides.addFileOverride(os.path.expandvars(configFile))
                        if configIR.python is not None:
                            overrides.addPythonOverride(configIR.python)
                        for key, value in configIR.rest.items():
                            overrides.addValueOverride(key, value)
            overrides.applyTo(config)
            # This may need to be revisited
            config.validate()
            taskDefs.append(TaskDef(taskName=taskName, config=config, taskClass=taskClass, label=label))

        # let's evaluate the contracts
        if self._pipelineIR.contracts is not None:
            label_to_config = {x.label: x.config for x in taskDefs}
            for contract in self._pipelineIR.contracts:
                # execute this in its own line so it can raise a good error
                # message if there were problems with the eval
                success = eval(contract.contract, None, label_to_config)
                if not success:
                    extra_info = f": {contract.msg}" if contract.msg is not None else ""
                    raise pipelineIR.ContractError(f"Contract(s) '{contract.contract}' were not "
                                                   f"satisfied{extra_info}")

        yield from pipeTools.orderPipeline(taskDefs)

    def __len__(self):
        return len(self._pipelineIR.tasks)

    def __eq__(self, other: object):
        if not isinstance(other, Pipeline):
            return False
        return self._pipelineIR == other._pipelineIR
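

# Example: a usage sketch (not part of the original module) of the Pipeline
# API defined above. The instrument, task class path, label, config field and
# file names are hypothetical; expanding the pipeline requires those task
# classes to actually be importable.
def _example_build_and_load_pipeline():
    # Build a pipeline programmatically.
    pipeline = Pipeline("demo pipeline")
    pipeline.addInstrument("lsst.obs.subaru.HyperSuprimeCam")  # hypothetical instrument
    pipeline.addTask("my_package.my_module.MyTask", "myTask")  # hypothetical task
    pipeline.addConfigOverride("myTask", "someField", 42)      # hypothetical config field
    pipeline.write_to_uri("demo_pipeline.yaml")

    # Load a pipeline from a URI, keeping only an inclusive range of labels;
    # a comma-separated list ("#a,b,c") or a single label would also work.
    subset = Pipeline.from_uri("demo_pipeline.yaml#..myTask")

    # Expand into dependency-ordered TaskDefs (this imports the task classes
    # and applies all config overrides and contracts).
    for taskDef in subset.toExpandedPipeline():
        print(taskDef.label, taskDef.taskName)
    return subset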


@dataclass(frozen=True)
class TaskDatasetTypes:
    """An immutable struct that extracts and classifies the dataset types used
    by a `PipelineTask`
    """

    initInputs: NamedValueSet[DatasetType]
    """Dataset types that are needed as inputs in order to construct this Task.

    Task-level `initInputs` may be classified as either
    `~PipelineDatasetTypes.initInputs` or
    `~PipelineDatasetTypes.initIntermediates` at the Pipeline level.
    """

    initOutputs: NamedValueSet[DatasetType]
    """Dataset types that may be written after constructing this Task.

    Task-level `initOutputs` may be classified as either
    `~PipelineDatasetTypes.initOutputs` or
    `~PipelineDatasetTypes.initIntermediates` at the Pipeline level.
    """

    inputs: NamedValueSet[DatasetType]
    """Dataset types that are regular inputs to this Task.

    If an input dataset needed for a Quantum cannot be found in the input
    collection(s) or produced by another Task in the Pipeline, that Quantum
    (and all dependent Quanta) will not be produced.

    Task-level `inputs` may be classified as either
    `~PipelineDatasetTypes.inputs` or `~PipelineDatasetTypes.intermediates`
    at the Pipeline level.
    """

    prerequisites: NamedValueSet[DatasetType]
    """Dataset types that are prerequisite inputs to this Task.

    Prerequisite inputs must exist in the input collection(s) before the
    pipeline is run, but do not constrain the graph - if a prerequisite is
    missing for a Quantum, `PrerequisiteMissingError` is raised.

    Prerequisite inputs are not resolved until the second stage of
    QuantumGraph generation.
    """

    outputs: NamedValueSet[DatasetType]
    """Dataset types that are produced by this Task.

    Task-level `outputs` may be classified as either
    `~PipelineDatasetTypes.outputs` or `~PipelineDatasetTypes.intermediates`
    at the Pipeline level.
    """

    @classmethod
    def fromTaskDef(
        cls,
        taskDef: TaskDef,
        *,
        registry: Registry,
        include_configs: bool = True,
    ) -> TaskDatasetTypes:
        """Extract and classify the dataset types from a single `PipelineTask`.

        Parameters
        ----------
        taskDef: `TaskDef`
            An instance of a `TaskDef` class for a particular `PipelineTask`.
        registry: `Registry`
            Registry used to construct normalized `DatasetType` objects and
            retrieve those that are incomplete.
        include_configs : `bool`, optional
            If `True` (default) include config dataset types as
            ``initOutputs``.

        Returns
        -------
        types: `TaskDatasetTypes`
            The dataset types used by this task.
        """
        def makeDatasetTypesSet(connectionType: str, freeze: bool = True) -> NamedValueSet[DatasetType]:
            """Constructs a set of true `DatasetType` objects

            Parameters
            ----------
            connectionType : `str`
                Name of the connection type to produce a set for, corresponds
                to an attribute of type `list` on the connection class instance
            freeze : `bool`, optional
                If `True`, call `NamedValueSet.freeze` on the object returned.

            Returns
            -------
            datasetTypes : `NamedValueSet`
                A set of all datasetTypes which correspond to the input
                connection type specified in the connection class of this
                `PipelineTask`

            Notes
            -----
            This function is a closure over the variables ``registry`` and
            ``taskDef``.
            """
            datasetTypes = NamedValueSet()
            for c in iterConnections(taskDef.connections, connectionType):
                dimensions = set(getattr(c, 'dimensions', set()))
                if "skypix" in dimensions:
                    try:
                        datasetType = registry.getDatasetType(c.name)
                    except LookupError as err:
                        raise LookupError(
                            f"DatasetType '{c.name}' referenced by "
                            f"{type(taskDef.connections).__name__} uses 'skypix' as a dimension "
                            f"placeholder, but does not already exist in the registry. "
                            f"Note that reference catalog names are now used as the dataset "
                            f"type name instead of 'ref_cat'."
                        ) from err
                    rest1 = set(registry.dimensions.extract(dimensions - set(["skypix"])).names)
                    rest2 = set(dim.name for dim in datasetType.dimensions
                                if not isinstance(dim, SkyPixDimension))
                    if rest1 != rest2:
                        raise ValueError(f"Non-skypix dimensions for dataset type {c.name} declared in "
                                         f"connections ({rest1}) are inconsistent with those in "
                                         f"registry's version of this dataset ({rest2}).")
                else:
                    # Component dataset types are not explicitly in the
                    # registry. This complicates consistency checks with
                    # registry and requires we work out the composite storage
                    # class.
                    registryDatasetType = None
                    try:
                        registryDatasetType = registry.getDatasetType(c.name)
                    except KeyError:
                        compositeName, componentName = DatasetType.splitDatasetTypeName(c.name)
                        parentStorageClass = DatasetType.PlaceholderParentStorageClass \
                            if componentName else None
                        datasetType = c.makeDatasetType(
                            registry.dimensions,
                            parentStorageClass=parentStorageClass
                        )
                        registryDatasetType = datasetType
                    else:
                        datasetType = c.makeDatasetType(
                            registry.dimensions,
                            parentStorageClass=registryDatasetType.parentStorageClass
                        )

                    if registryDatasetType and datasetType != registryDatasetType:
                        raise ValueError(f"Supplied dataset type ({datasetType}) inconsistent with "
                                         f"registry definition ({registryDatasetType}) "
                                         f"for {taskDef.label}.")
                datasetTypes.add(datasetType)
            if freeze:
                datasetTypes.freeze()
            return datasetTypes

        # optionally add initOutput dataset for config
        initOutputs = makeDatasetTypesSet("initOutputs", freeze=False)
        if include_configs:
            initOutputs.add(
                DatasetType(
                    taskDef.configDatasetName,
                    registry.dimensions.empty,
                    storageClass="Config",
                )
            )
        initOutputs.freeze()

        # optionally add output dataset for metadata
        outputs = makeDatasetTypesSet("outputs", freeze=False)
        if taskDef.metadataDatasetName is not None:
            # Metadata is supposed to be of the PropertySet type, its
            # dimensions correspond to a task quantum
            dimensions = registry.dimensions.extract(taskDef.connections.dimensions)
            outputs |= {DatasetType(taskDef.metadataDatasetName, dimensions, "PropertySet")}
        outputs.freeze()

        return cls(
            initInputs=makeDatasetTypesSet("initInputs"),
            initOutputs=initOutputs,
            inputs=makeDatasetTypesSet("inputs"),
            prerequisites=makeDatasetTypesSet("prerequisiteInputs"),
            outputs=outputs,
        )
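

# Example: an illustrative sketch (not in the original module) of classifying
# one task's dataset types. Obtaining a Registry (e.g. from a Butler) and a
# concrete TaskDef is assumed and left to the caller.
def _example_task_dataset_types(taskDef: TaskDef, registry: Registry) -> TaskDatasetTypes:
    types = TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
    # The per-task classification distinguishes construction-time datasets
    # (initInputs/initOutputs) from runtime ones (inputs/outputs), plus
    # prerequisites that must already exist in the input collections.
    print(sorted(dt.name for dt in types.inputs))
    print(sorted(dt.name for dt in types.outputs))
    return types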


@dataclass(frozen=True)
class PipelineDatasetTypes:
    """An immutable struct that classifies the dataset types used in a
    `Pipeline`.
    """

    initInputs: NamedValueSet[DatasetType]
    """Dataset types that are needed as inputs in order to construct the Tasks
    in this Pipeline.

    This does not include dataset types that are produced when constructing
    other Tasks in the Pipeline (these are classified as `initIntermediates`).
    """

    initOutputs: NamedValueSet[DatasetType]
    """Dataset types that may be written after constructing the Tasks in this
    Pipeline.

    This does not include dataset types that are also used as inputs when
    constructing other Tasks in the Pipeline (these are classified as
    `initIntermediates`).
    """

    initIntermediates: NamedValueSet[DatasetType]
    """Dataset types that are both used when constructing one or more Tasks
    in the Pipeline and produced as a side-effect of constructing another
    Task in the Pipeline.
    """

    inputs: NamedValueSet[DatasetType]
    """Dataset types that are regular inputs for the full pipeline.

    If an input dataset needed for a Quantum cannot be found in the input
    collection(s), that Quantum (and all dependent Quanta) will not be
    produced.
    """

    prerequisites: NamedValueSet[DatasetType]
    """Dataset types that are prerequisite inputs for the full Pipeline.

    Prerequisite inputs must exist in the input collection(s) before the
    pipeline is run, but do not constrain the graph - if a prerequisite is
    missing for a Quantum, `PrerequisiteMissingError` is raised.

    Prerequisite inputs are not resolved until the second stage of
    QuantumGraph generation.
    """

    intermediates: NamedValueSet[DatasetType]
    """Dataset types that are output by one Task in the Pipeline and consumed
    as inputs by one or more other Tasks in the Pipeline.
    """

    outputs: NamedValueSet[DatasetType]
    """Dataset types that are output by a Task in the Pipeline and not consumed
    by any other Task in the Pipeline.
    """

    byTask: Mapping[str, TaskDatasetTypes]
    """Per-Task dataset types, keyed by label in the `Pipeline`.

    This is guaranteed to be zip-iterable with the `Pipeline` itself (assuming
    neither has been modified since the dataset types were extracted, of
    course).
    """

    @classmethod
    def fromPipeline(
        cls,
        pipeline: Union[Pipeline, Iterable[TaskDef]],
        *,
        registry: Registry,
        include_configs: bool = True,
        include_packages: bool = True,
    ) -> PipelineDatasetTypes:
        """Extract and classify the dataset types from all tasks in a
        `Pipeline`.

        Parameters
        ----------
        pipeline: `Pipeline` or `Iterable` [ `TaskDef` ]
            A dependency-ordered collection of tasks that can be run
            together.
        registry: `Registry`
            Registry used to construct normalized `DatasetType` objects and
            retrieve those that are incomplete.
        include_configs : `bool`, optional
            If `True` (default) include config dataset types as
            ``initOutputs``.
        include_packages : `bool`, optional
            If `True` (default) include the dataset type for software package
            versions in ``initOutputs``.

        Returns
        -------
        types: `PipelineDatasetTypes`
            The dataset types used by this `Pipeline`.

        Raises
        ------
        ValueError
            Raised if Tasks are inconsistent about which datasets are marked
            prerequisite. This indicates that the Tasks cannot be run as part
            of the same `Pipeline`.
        """
        allInputs = NamedValueSet()
        allOutputs = NamedValueSet()
        allInitInputs = NamedValueSet()
        allInitOutputs = NamedValueSet()
        prerequisites = NamedValueSet()
        byTask = dict()
        if include_packages:
            allInitOutputs.add(
                DatasetType(
                    "packages",
                    registry.dimensions.empty,
                    storageClass="Packages",
                )
            )
        if isinstance(pipeline, Pipeline):
            pipeline = pipeline.toExpandedPipeline()
        for taskDef in pipeline:
            thisTask = TaskDatasetTypes.fromTaskDef(
                taskDef,
                registry=registry,
                include_configs=include_configs,
            )
            allInitInputs |= thisTask.initInputs
            allInitOutputs |= thisTask.initOutputs
            allInputs |= thisTask.inputs
            prerequisites |= thisTask.prerequisites
            allOutputs |= thisTask.outputs
            byTask[taskDef.label] = thisTask
        if not prerequisites.isdisjoint(allInputs):
            raise ValueError("{} marked as both prerequisites and regular inputs".format(
                {dt.name for dt in allInputs & prerequisites}
            ))
        if not prerequisites.isdisjoint(allOutputs):
            raise ValueError("{} marked as both prerequisites and outputs".format(
                {dt.name for dt in allOutputs & prerequisites}
            ))
        # Make sure that components which are marked as inputs get treated as
        # intermediates if there is an output which produces the composite
        # containing the component
        intermediateComponents = NamedValueSet()
        intermediateComposites = NamedValueSet()
        outputNameMapping = {dsType.name: dsType for dsType in allOutputs}
        for dsType in allInputs:
            # get the name of a possible component
            name, component = dsType.nameAndComponent()
            # if there is a component name, that means this is a component
            # DatasetType, if there is an output which produces the parent of
            # this component, treat this input as an intermediate
            if component is not None:
                if name in outputNameMapping:
                    if outputNameMapping[name].dimensions != dsType.dimensions:
                        raise ValueError(f"Component dataset type {dsType.name} has different "
                                         f"dimensions ({dsType.dimensions}) than its parent "
                                         f"({outputNameMapping[name].dimensions}).")
                    composite = DatasetType(name, dsType.dimensions, outputNameMapping[name].storageClass,
                                            universe=registry.dimensions)
                    intermediateComponents.add(dsType)
                    intermediateComposites.add(composite)

        def checkConsistency(a: NamedValueSet, b: NamedValueSet):
            common = a.names & b.names
            for name in common:
                if a[name] != b[name]:
                    raise ValueError(f"Conflicting definitions for dataset type: {a[name]} != {b[name]}.")

        checkConsistency(allInitInputs, allInitOutputs)
        checkConsistency(allInputs, allOutputs)
        checkConsistency(allInputs, intermediateComposites)
        checkConsistency(allOutputs, intermediateComposites)

        def frozen(s: NamedValueSet) -> NamedValueSet:
            s.freeze()
            return s

        return cls(
            initInputs=frozen(allInitInputs - allInitOutputs),
            initIntermediates=frozen(allInitInputs & allInitOutputs),
            initOutputs=frozen(allInitOutputs - allInitInputs),
            inputs=frozen(allInputs - allOutputs - intermediateComponents),
            intermediates=frozen(allInputs & allOutputs | intermediateComponents),
            outputs=frozen(allOutputs - allInputs - intermediateComposites),
            prerequisites=frozen(prerequisites),
            byTask=MappingProxyType(byTask),  # MappingProxyType -> frozen view of dict for immutability
        )
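

# Example: an illustrative sketch (not in the original module) of the
# pipeline-level classification. A Registry instance is assumed to be
# available (e.g. from a Butler); the pipeline argument may be a Pipeline
# or an iterable of TaskDefs.
def _example_pipeline_dataset_types(pipeline: Pipeline,
                                    registry: Registry) -> PipelineDatasetTypes:
    dataset_types = PipelineDatasetTypes.fromPipeline(pipeline, registry=registry)
    # "intermediates" are produced by one task and consumed by another, so
    # they appear in neither the overall inputs nor the overall outputs.
    assert dataset_types.inputs.isdisjoint(dataset_types.outputs)
    for label, task_types in dataset_types.byTask.items():
        print(label, sorted(dt.name for dt in task_types.outputs))
    return dataset_types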