Coverage for python/lsst/pipe/base/pipeline.py : 19%

1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
23"""Module defining Pipeline class and related methods.
24"""
26__all__ = ["Pipeline", "TaskDef", "TaskDatasetTypes", "PipelineDatasetTypes", "LabelSpecifier"]
28# -------------------------------
29# Imports of standard modules --
30# -------------------------------
31from dataclasses import dataclass
32from types import MappingProxyType
33from typing import (ClassVar, Dict, Iterable, Iterator, Mapping, Set, Union,
34 Generator, TYPE_CHECKING, Optional, Tuple)
36import copy
37import re
38import os
39import urllib.parse
40import warnings
42# -----------------------------
43# Imports for other modules --
44from lsst.daf.butler import DatasetType, NamedValueSet, Registry, SkyPixDimension, ButlerURI
45from lsst.utils import doImport
46from .configOverrides import ConfigOverrides
47from .connections import iterConnections
48from .pipelineTask import PipelineTask
50from . import pipelineIR
51from . import pipeTools
53if TYPE_CHECKING:  # Imports needed only for type annotations; may be circular.
54 from lsst.obs.base import Instrument
56# ----------------------------------
57# Local non-exported definitions --
58# ----------------------------------
60# ------------------------
61# Exported definitions --
62# ------------------------
65@dataclass
66class LabelSpecifier:
67 """A structure to specify a subset of labels to load
69 This structure may contain a set of labels to be used in subsetting a
70 pipeline, or a beginning and end point. Beginning or end may be empty,
71 in which case the range will be a half open interval. Unlike python
72 iteration bounds, end bounds are *INCLUDED*. Note that range based
73 selection is not well defined for pipelines that are not linear in nature,
74 and correct behavior is not guaranteed, or may vary from run to run.
75 """
76 labels: Optional[Set[str]] = None
77 begin: Optional[str] = None
78 end: Optional[str] = None
80 def __post_init__(self):
81 if self.labels is not None and (self.begin or self.end):
82 raise ValueError("This struct can only be initialized with a labels set or "
83 "a begin (and/or) end specifier")
86class TaskDef:
87 """TaskDef is a collection of information about task needed by Pipeline.
89 The information includes task name, configuration object and optional
90 task class. This class is just a collection of attributes and it exposes
91 all of them so that attributes could potentially be modified in place
92 (e.g. if configuration needs extra overrides).
94 Attributes
95 ----------
96 taskName : `str`
97 `PipelineTask` class name, currently it is not specified whether this
98 is a fully-qualified name or partial name (e.g. ``module.TaskClass``).
99 Framework should be prepared to handle all cases.
100 config : `lsst.pex.config.Config`
101 Instance of the configuration class corresponding to this task class,
102 usually with all overrides applied. This config will be frozen.
103 taskClass : `type` or ``None``
104 `PipelineTask` class object, can be ``None``. If ``None`` then
105 framework will have to locate and load class.
106 label : `str`, optional
107 Task label, usually a short string unique in a pipeline.
108 """
109 def __init__(self, taskName, config, taskClass=None, label=""):
110 self.taskName = taskName
111 config.freeze()
112 self.config = config
113 self.taskClass = taskClass
114 self.label = label
115 self.connections = config.connections.ConnectionsClass(config=config)
117 @property
118 def configDatasetName(self) -> str:
119 """Name of a dataset type for configuration of this task (`str`)
120 """
121 return self.label + "_config"
123 @property
124 def metadataDatasetName(self) -> Optional[str]:
125 """Name of a dataset type for metadata of this task, `None` if
126 metadata is not to be saved (`str`)
127 """
128 if self.config.saveMetadata:
129 return self.label + "_metadata"
130 else:
131 return None
133 @property
134 def logOutputDatasetName(self) -> Optional[str]:
135 """Name of a dataset type for log output from this task, `None` if
136 logs are not to be saved (`str`)
137 """
138 if self.config.saveLogOutput:
139 return self.label + "_log"
140 else:
141 return None
143 def __str__(self):
144 rep = "TaskDef(" + self.taskName
145 if self.label:
146 rep += ", label=" + self.label
147 rep += ")"
148 return rep
150 def __eq__(self, other: object) -> bool:
151 if not isinstance(other, TaskDef):
152 return False
153 # This does not consider equality of configs when determining equality
154 # as config equality is a difficult thing to define. Should be updated
155 # after DM-27847
156 return self.taskClass == other.taskClass and self.label == other.label
158 def __hash__(self):
159 return hash((self.taskClass, self.label))
162class Pipeline:
163 """A `Pipeline` is a representation of a series of tasks to run, and the
164 configuration for those tasks.
166 Parameters
167 ----------
168 description : `str`
169 A description of what this pipeline does.
170 """
171 def __init__(self, description: str):
172 pipeline_dict = {"description": description, "tasks": {}}
173 self._pipelineIR = pipelineIR.PipelineIR(pipeline_dict)
175 @classmethod
176 def fromFile(cls, filename: str) -> Pipeline:
177 """Load a pipeline defined in a pipeline yaml file.
179 Parameters
180 ----------
181 filename: `str`
182 A path that points to a pipeline defined in yaml format. This
183 filename may also supply additional labels to be used in
184 subsetting the loaded Pipeline. These labels are separated from
185 the path by a \\#, and may be specified as a comma separated
186 list, or a range denoted as beginning..end. Beginning or end may
187 be empty, in which case the range will be a half open interval.
188 Unlike python iteration bounds, end bounds are *INCLUDED*. Note
189 that range based selection is not well defined for pipelines that
190 are not linear in nature, and correct behavior is not guaranteed,
191 or may vary from run to run.
193 Returns
194 -------
195 pipeline: `Pipeline`
196 The pipeline loaded from the specified location with appropriate
197 (if any) subsetting.
199 Notes
200 -----
201 This method attempts to prune any contracts that contain labels which
202 are not in the declared subset of labels. This pruning is done using a
203 string based matching due to the nature of contracts and may prune more
204 than it should.
205 """
206 return cls.from_uri(filename)
208 @classmethod
209 def from_uri(cls, uri: Union[str, ButlerURI]) -> Pipeline:
210 """Load a pipeline defined in a pipeline yaml file at a location
211 specified by a URI.
213 Parameters
214 ----------
215 uri: `str` or `ButlerURI`
216 If a string is supplied this should be a URI path that points to a
217 pipeline defined in yaml format. This uri may also supply
218 additional labels to be used in subsetting the loaded Pipeline.
219 These labels are separated from the path by a \\#, and may be
220 specified as a comma separated list, or a range denoted as
221 beginning..end. Beginning or end may be empty, in which case the
222 range will be a half open interval. Unlike python iteration
223 bounds, end bounds are *INCLUDED*. Note that range based selection
224 is not well defined for pipelines that are not linear in nature,
225 and correct behavior is not guaranteed, or may vary from run to
226 run. The same specifiers can be used with a ButlerURI object, by
227 being the sole contents of the fragment attribute.
229 Returns
230 -------
231 pipeline: `Pipeline`
232 The pipeline loaded from the specified location with appropriate
233 (if any) subsetting.
235 Notes
236 -----
237 This method attempts to prune any contracts that contain labels which
238 are not in the declared subset of labels. This pruning is done using a
239 string based matching due to the nature of contracts and may prune more
240 than it should.
241 """
242 # Split up the uri and any labels that were supplied
243 uri, label_specifier = cls._parse_file_specifier(uri)
244 pipeline: Pipeline = cls.fromIR(pipelineIR.PipelineIR.from_uri(uri))
246 # If there are labels supplied, only keep those
247 if label_specifier is not None:
248 pipeline = pipeline.subsetFromLabels(label_specifier)
249 return pipeline
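# Illustrative sketch (added for documentation; not part of the original
# module): the fragment after '#' selects a label subset when loading.
# File names and labels below are hypothetical placeholders.
def _example_from_uri() -> None:
    # explicit, comma separated list of labels
    subset = Pipeline.from_uri("pipeline.yaml#isr,characterizeImage")
    # end-inclusive range; an empty bound leaves that side of the range open
    head = Pipeline.from_uri("pipeline.yaml#..calibrate")
    del subset, head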
251 def subsetFromLabels(self, labelSpecifier: LabelSpecifier) -> Pipeline:
252 """Subset a pipeline to contain only labels specified in labelSpecifier
254 Parameters
255 ----------
256 labelSpecifier : `LabelSpecifier`
257 Object describing how to subset a pipeline by label.
259 Returns
260 -------
261 pipeline : `Pipeline`
262 A new pipeline object that is a subset of the old pipeline
264 Raises
265 ------
266 ValueError
267 Raised if there is an issue with specified labels
269 Notes
270 -----
271 This method attempts to prune any contracts that contain labels which
272 are not in the declared subset of labels. This pruning is done using a
273 string based matching due to the nature of contracts and may prune more
274 than it should.
275 """
276 # Labels supplied as a set
277 if labelSpecifier.labels:
278 labelSet = labelSpecifier.labels
279 # Labels supplied as a range, first create a list of all the labels
280 # in the pipeline sorted according to task dependency. Then only
281 # keep labels that lie between the supplied bounds
282 else:
283 # Create a copy of the pipeline to use when assessing the label
284 # ordering. Use a dict for fast searching while preserving order.
285 # Remove contracts so they do not fail in the expansion step. This
286 # is needed because a user may only configure the tasks they intend
287 # to run, which may cause some contracts to fail if they will later
288 # be dropped
289 pipeline = copy.deepcopy(self)
290 pipeline._pipelineIR.contracts = []
291 labels = {taskdef.label: True for taskdef in pipeline.toExpandedPipeline()}
293 # Verify the bounds are in the labels
294 if labelSpecifier.begin is not None:
295 if labelSpecifier.begin not in labels:
296 raise ValueError(f"Beginning of range subset, {labelSpecifier.begin}, not found in "
297 "pipeline definition")
298 if labelSpecifier.end is not None:
299 if labelSpecifier.end not in labels:
300 raise ValueError(f"End of range subset, {labelSpecifier.end}, not found in pipeline "
301 "definition")
303 labelSet = set()
304 for label in labels:
305 if labelSpecifier.begin is not None:
306 if label != labelSpecifier.begin:
307 continue
308 else:
309 labelSpecifier.begin = None
310 labelSet.add(label)
311 if labelSpecifier.end is not None and label == labelSpecifier.end:
312 break
313 return Pipeline.fromIR(self._pipelineIR.subset_from_labels(labelSet))
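# Illustrative sketch (added for documentation; not part of the original
# module): programmatic equivalent of the '#' fragment syntax above, using a
# hypothetical pipeline file and label.
def _example_subset_from_labels() -> None:
    pipeline = Pipeline.fromFile("pipeline.yaml")
    tail = pipeline.subsetFromLabels(LabelSpecifier(begin="calibrate"))
    # the result is a new, smaller Pipeline; the original is unchanged
    assert isinstance(tail, Pipeline)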
315 @staticmethod
316 def _parse_file_specifier(uri: Union[str, ButlerURI]
317 ) -> Tuple[ButlerURI, Optional[LabelSpecifier]]:
318 """Split appart a uri and any possible label subsets
319 """
320 if isinstance(uri, str):
321 # This is to support legacy pipelines during transition
322 uri, num_replace = re.subn("[:](?!\\/\\/)", "#", uri)
323 if num_replace:
324 warnings.warn(f"The pipeline file {uri} seems to use the legacy : to separate "
325 "labels, this is deprecated and will be removed after June 2021, please use "
326 "# instead.",
327 category=FutureWarning)
328 if uri.count("#") > 1:
329 raise ValueError("Only one set of labels is allowed when specifying a pipeline to load")
330 uri = ButlerURI(uri)
331 label_subset = uri.fragment or None
333 specifier: Optional[LabelSpecifier]
334 if label_subset is not None:
335 label_subset = urllib.parse.unquote(label_subset)
336 args: Dict[str, Union[Set[str], str, None]]
337 # labels supplied as a list
338 if ',' in label_subset:
339 if '..' in label_subset:
340 raise ValueError("Can only specify a list of labels or a range"
341 "when loading a Pipline not both")
342 args = {"labels": set(label_subset.split(","))}
343 # labels supplied as a range
344 elif '..' in label_subset:
345 # Try to de-structure the labelSubset, this will fail if more
346 # than one range is specified
347 begin, end, *rest = label_subset.split("..")
348 if rest:
349 raise ValueError("Only one range can be specified when loading a pipeline")
350 args = {"begin": begin if begin else None, "end": end if end else None}
351 # Assume anything else is a single label
352 else:
353 args = {"labels": {label_subset}}
355 specifier = LabelSpecifier(**args)
356 else:
357 specifier = None
359 return uri, specifier
361 @classmethod
362 def fromString(cls, pipeline_string: str) -> Pipeline:
363 """Create a pipeline from string formatted as a pipeline document.
365 Parameters
366 ----------
367 pipeline_string : `str`
368 A string formatted like a pipeline document.
370 Returns
371 -------
372 pipeline: `Pipeline`
373 """
374 pipeline = cls.fromIR(pipelineIR.PipelineIR.from_string(pipeline_string))
375 return pipeline
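# Illustrative sketch (added for documentation; not part of the original
# module): a minimal pipeline document. The exact schema is defined by
# pipelineIR; the keys below mirror those used in Pipeline.__init__ and the
# task class path is hypothetical.
def _example_from_string() -> None:
    doc = (
        "description: A single-task example pipeline\n"
        "tasks:\n"
        "  exampleTask:\n"
        "    class: mypackage.tasks.ExampleTask\n"
    )
    pipeline = Pipeline.fromString(doc)
    assert len(pipeline) == 1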
377 @classmethod
378 def fromIR(cls, deserialized_pipeline: pipelineIR.PipelineIR) -> Pipeline:
379 """Create a pipeline from an already created `PipelineIR` object.
381 Parameters
382 ----------
383 deserialized_pipeline: `PipelineIR`
384 An already created pipeline intermediate representation object
386 Returns
387 -------
388 pipeline: `Pipeline`
389 """
390 pipeline = cls.__new__(cls)
391 pipeline._pipelineIR = deserialized_pipeline
392 return pipeline
394 @classmethod
395 def fromPipeline(cls, pipeline: Pipeline) -> Pipeline:
396 """Create a new pipeline by copying an already existing `Pipeline`.
398 Parameters
399 ----------
400 pipeline: `Pipeline`
401 The already existing `Pipeline` object to copy.
403 Returns
404 -------
405 pipeline: `Pipeline`
406 """
407 return cls.fromIR(copy.deepcopy(pipeline._pipelineIR))
409 def __str__(self) -> str:
410 return str(self._pipelineIR)
412 def addInstrument(self, instrument: Union[Instrument, str]) -> None:
413 """Add an instrument to the pipeline, or replace an instrument that is
414 already defined.
416 Parameters
417 ----------
418 instrument : `~lsst.obs.base.Instrument` or `str`
419 Either a subclass of `~lsst.obs.base.Instrument` or a string
420 corresponding to the fully qualified name of such a
421 class.
422 """
423 if isinstance(instrument, str):
424 pass
425 else:
426 # TODO: assume that this is a subclass of Instrument, no type
427 # checking
428 instrument = f"{instrument.__module__}.{instrument.__qualname__}"
429 self._pipelineIR.instrument = instrument
431 def getInstrument(self) -> Instrument:
432 """Get the instrument from the pipeline.
434 Returns
435 -------
436 instrument : `~lsst.obs.base.Instrument`, `str`, or None
437 A subclass of `~lsst.obs.base.Instrument`, a string corresponding
438 to a fully qualified `~lsst.obs.base.Instrument` name, or None if
439 the pipeline does not have an instrument.
440 """
441 return self._pipelineIR.instrument
443 def addTask(self, task: Union[PipelineTask, str], label: str) -> None:
444 """Add a new task to the pipeline, or replace a task that is already
445 associated with the supplied label.
447 Parameters
448 ----------
449 task: `PipelineTask` or `str`
450 Either a subclass of `PipelineTask` or a string corresponding to
451 a fully qualified `PipelineTask` name.
452 label: `str`
453 A label that is used to identify the `PipelineTask` being added
454 """
455 if isinstance(task, str):
456 taskName = task
457 elif issubclass(task, PipelineTask):
458 taskName = f"{task.__module__}.{task.__qualname__}"
459 else:
460 raise ValueError("task must be either a child class of PipelineTask or a string containing"
461 " a fully qualified name to one")
462 if not label:
463 # in some cases (with command line-generated pipeline) tasks can
464 # be defined without label which is not acceptable, use task
465 # _DefaultName in that case
466 if isinstance(task, str):
467 task = doImport(task)
468 label = task._DefaultName
469 self._pipelineIR.tasks[label] = pipelineIR.TaskIR(label, taskName)
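# Illustrative sketch (added for documentation; not part of the original
# module): building a pipeline programmatically. The task class path and
# label are hypothetical; with a non-empty label the class is not imported.
def _example_add_task() -> None:
    pipeline = Pipeline("programmatically built pipeline")
    pipeline.addTask("mypackage.tasks.ExampleTask", "exampleTask")
    assert len(pipeline) == 1
    pipeline.removeTask("exampleTask")
    assert len(pipeline) == 0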
471 def removeTask(self, label: str) -> None:
472 """Remove a task from the pipeline.
474 Parameters
475 ----------
476 label : `str`
477 The label used to identify the task that is to be removed
479 Raises
480 ------
481 KeyError
482 If no task with that label exists in the pipeline
484 """
485 self._pipelineIR.tasks.pop(label)
487 def addConfigOverride(self, label: str, key: str, value: object) -> None:
488 """Apply single config override.
490 Parameters
491 ----------
492 label : `str`
493 Label of the task.
494 key: `str`
495 Fully-qualified field name.
496 value : object
497 Value to be given to a field.
498 """
499 self._addConfigImpl(label, pipelineIR.ConfigIR(rest={key: value}))
501 def addConfigFile(self, label: str, filename: str) -> None:
502 """Add overrides from a specified file.
504 Parameters
505 ----------
506 label : `str`
507 The label used to identify the task associated with config to
508 modify
509 filename : `str`
510 Path to the override file.
511 """
512 self._addConfigImpl(label, pipelineIR.ConfigIR(file=[filename]))
514 def addConfigPython(self, label: str, pythonString: str) -> None:
515 """Add Overrides by running a snippet of python code against a config.
517 Parameters
518 ----------
519 label : `str`
520 The label used to identify the task associated with config to
521 modify.
522 pythonString: `str`
523 A string which is valid python code to be executed. This is done
524 with config as the only local accessible value.
525 """
526 self._addConfigImpl(label, pipelineIR.ConfigIR(python=pythonString))
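# Illustrative sketch (added for documentation; not part of the original
# module): the three flavours of config override, recorded in the pipeline
# IR and applied later by toExpandedPipeline. Label, field name and file
# path are hypothetical.
def _example_config_overrides() -> None:
    pipeline = Pipeline("pipeline with overrides")
    pipeline.addTask("mypackage.tasks.ExampleTask", "exampleTask")
    pipeline.addConfigOverride("exampleTask", "doWrite", False)   # single field
    pipeline.addConfigFile("exampleTask", "exampleOverrides.py")  # override file
    pipeline.addConfigPython("exampleTask", "config.doWrite = False")  # python snippet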
528 def _addConfigImpl(self, label: str, newConfig: pipelineIR.ConfigIR) -> None:
529 if label == "parameters":
530 if newConfig.rest.keys() - self._pipelineIR.parameters.mapping.keys():
531 raise ValueError("Cannot override parameters that are not defined in pipeline")
532 self._pipelineIR.parameters.mapping.update(newConfig.rest)
533 if newConfig.file:
534 raise ValueError("Setting parameters section with config file is not supported")
535 if newConfig.python:
536 raise ValueError("Setting parameters section using python block in unsupported")
537 return
538 if label not in self._pipelineIR.tasks:
539 raise LookupError(f"There are no tasks labeled '{label}' in the pipeline")
540 self._pipelineIR.tasks[label].add_or_update_config(newConfig)
542 def toFile(self, filename: str) -> None:
543 self._pipelineIR.to_file(filename)
545 def write_to_uri(self, uri: Union[str, ButlerURI]) -> None:
546 self._pipelineIR.write_to_uri(uri)
548 def toExpandedPipeline(self) -> Generator[TaskDef, None, None]:
549 """Returns a generator of TaskDefs which can be used to create quantum
550 graphs.
552 Returns
553 -------
554 generator : generator of `TaskDef`
555 The generator returned will be the sorted iterator of tasks which
556 are to be used in constructing a quantum graph.
558 Raises
559 ------
560 NotImplementedError
561 If a dataId is supplied in a config block. This is in place for
562 future use
563 """
564 taskDefs = []
565 for label, taskIR in self._pipelineIR.tasks.items():
566 taskClass = doImport(taskIR.klass)
567 taskName = taskClass.__qualname__
568 config = taskClass.ConfigClass()
569 overrides = ConfigOverrides()
570 if self._pipelineIR.instrument is not None:
571 overrides.addInstrumentOverride(self._pipelineIR.instrument, taskClass._DefaultName)
572 if taskIR.config is not None:
573 for configIR in (configIr.formatted(self._pipelineIR.parameters)
574 for configIr in taskIR.config):
575 if configIR.dataId is not None:
576 raise NotImplementedError("Specializing a config on a partial data id is not yet "
577 "supported in Pipeline definition")
578 # only apply override if it applies to everything
579 if configIR.dataId is None:
580 if configIR.file:
581 for configFile in configIR.file:
582 overrides.addFileOverride(os.path.expandvars(configFile))
583 if configIR.python is not None:
584 overrides.addPythonOverride(configIR.python)
585 for key, value in configIR.rest.items():
586 overrides.addValueOverride(key, value)
587 overrides.applyTo(config)
588 # This may need to be revisited
589 config.validate()
590 taskDefs.append(TaskDef(taskName=taskName, config=config, taskClass=taskClass, label=label))
592 # let's evaluate the contracts
593 if self._pipelineIR.contracts is not None:
594 label_to_config = {x.label: x.config for x in taskDefs}
595 for contract in self._pipelineIR.contracts:
596 # execute this in its own line so it can raise a good error
597 # message if there were problems with the eval
598 success = eval(contract.contract, None, label_to_config)
599 if not success:
600 extra_info = f": {contract.msg}" if contract.msg is not None else ""
601 raise pipelineIR.ContractError(f"Contract(s) '{contract.contract}' were not "
602 f"satisfied{extra_info}")
604 yield from pipeTools.orderPipeline(taskDefs)
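# Illustrative sketch (added for documentation; not part of the original
# module): expanding a pipeline imports each task class, applies overrides
# and contracts, and yields dependency-ordered TaskDefs. The file name is a
# hypothetical placeholder.
def _example_expand() -> None:
    pipeline = Pipeline.fromFile("pipeline.yaml")
    for taskDef in pipeline.toExpandedPipeline():
        print(taskDef.label, taskDef.taskName, taskDef.configDatasetName)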
606 def __len__(self):
607 return len(self._pipelineIR.tasks)
609 def __eq__(self, other: object):
610 if not isinstance(other, Pipeline):
611 return False
612 return self._pipelineIR == other._pipelineIR
615@dataclass(frozen=True)
616class TaskDatasetTypes:
617 """An immutable struct that extracts and classifies the dataset types used
618 by a `PipelineTask`
619 """
621 initInputs: NamedValueSet[DatasetType]
622 """Dataset types that are needed as inputs in order to construct this Task.
624 Task-level `initInputs` may be classified as either
625 `~PipelineDatasetTypes.initInputs` or
626 `~PipelineDatasetTypes.initIntermediates` at the Pipeline level.
627 """
629 initOutputs: NamedValueSet[DatasetType]
630 """Dataset types that may be written after constructing this Task.
632 Task-level `initOutputs` may be classified as either
633 `~PipelineDatasetTypes.initOutputs` or
634 `~PipelineDatasetTypes.initIntermediates` at the Pipeline level.
635 """
637 inputs: NamedValueSet[DatasetType]
638 """Dataset types that are regular inputs to this Task.
640 If an input dataset needed for a Quantum cannot be found in the input
641 collection(s) or produced by another Task in the Pipeline, that Quantum
642 (and all dependent Quanta) will not be produced.
644 Task-level `inputs` may be classified as either
645 `~PipelineDatasetTypes.inputs` or `~PipelineDatasetTypes.intermediates`
646 at the Pipeline level.
647 """
649 prerequisites: NamedValueSet[DatasetType]
650 """Dataset types that are prerequisite inputs to this Task.
652 Prerequisite inputs must exist in the input collection(s) before the
653 pipeline is run, but do not constrain the graph - if a prerequisite is
654 missing for a Quantum, `PrerequisiteMissingError` is raised.
656 Prerequisite inputs are not resolved until the second stage of
657 QuantumGraph generation.
658 """
660 outputs: NamedValueSet[DatasetType]
661 """Dataset types that are produced by this Task.
663 Task-level `outputs` may be classified as either
664 `~PipelineDatasetTypes.outputs` or `~PipelineDatasetTypes.intermediates`
665 at the Pipeline level.
666 """
668 @classmethod
669 def fromTaskDef(
670 cls,
671 taskDef: TaskDef,
672 *,
673 registry: Registry,
674 include_configs: bool = True,
675 ) -> TaskDatasetTypes:
676 """Extract and classify the dataset types from a single `PipelineTask`.
678 Parameters
679 ----------
680 taskDef: `TaskDef`
681 An instance of a `TaskDef` class for a particular `PipelineTask`.
682 registry: `Registry`
683 Registry used to construct normalized `DatasetType` objects and
684 retrieve those that are incomplete.
685 include_configs : `bool`, optional
686 If `True` (default) include config dataset types as
687 ``initOutputs``.
689 Returns
690 -------
691 types: `TaskDatasetTypes`
692 The dataset types used by this task.
693 """
694 def makeDatasetTypesSet(connectionType: str, freeze: bool = True) -> NamedValueSet[DatasetType]:
695 """Constructs a set of true `DatasetType` objects
697 Parameters
698 ----------
699 connectionType : `str`
700 Name of the connection type to produce a set for, corresponds
701 to an attribute of type `list` on the connection class instance
702 freeze : `bool`, optional
703 If `True`, call `NamedValueSet.freeze` on the object returned.
705 Returns
706 -------
707 datasetTypes : `NamedValueSet`
708 A set of all datasetTypes which correspond to the input
709 connection type specified in the connection class of this
710 `PipelineTask`
712 Notes
713 -----
714 This function is a closure over the variables ``registry`` and
715 ``taskDef``.
716 """
717 datasetTypes = NamedValueSet()
718 for c in iterConnections(taskDef.connections, connectionType):
719 dimensions = set(getattr(c, 'dimensions', set()))
720 if "skypix" in dimensions:
721 try:
722 datasetType = registry.getDatasetType(c.name)
723 except LookupError as err:
724 raise LookupError(
725 f"DatasetType '{c.name}' referenced by "
726 f"{type(taskDef.connections).__name__} uses 'skypix' as a dimension "
727 f"placeholder, but does not already exist in the registry. "
728 f"Note that reference catalog names are now used as the dataset "
729 f"type name instead of 'ref_cat'."
730 ) from err
731 rest1 = set(registry.dimensions.extract(dimensions - set(["skypix"])).names)
732 rest2 = set(dim.name for dim in datasetType.dimensions
733 if not isinstance(dim, SkyPixDimension))
734 if rest1 != rest2:
735 raise ValueError(f"Non-skypix dimensions for dataset type {c.name} declared in "
736 f"connections ({rest1}) are inconsistent with those in "
737 f"registry's version of this dataset ({rest2}).")
738 else:
739 # Component dataset types are not explicitly in the
740 # registry. This complicates consistency checks with
741 # registry and requires we work out the composite storage
742 # class.
743 registryDatasetType = None
744 try:
745 registryDatasetType = registry.getDatasetType(c.name)
746 except KeyError:
747 compositeName, componentName = DatasetType.splitDatasetTypeName(c.name)
748 parentStorageClass = DatasetType.PlaceholderParentStorageClass \
749 if componentName else None
750 datasetType = c.makeDatasetType(
751 registry.dimensions,
752 parentStorageClass=parentStorageClass
753 )
754 registryDatasetType = datasetType
755 else:
756 datasetType = c.makeDatasetType(
757 registry.dimensions,
758 parentStorageClass=registryDatasetType.parentStorageClass
759 )
761 if registryDatasetType and datasetType != registryDatasetType:
762 raise ValueError(f"Supplied dataset type ({datasetType}) inconsistent with "
763 f"registry definition ({registryDatasetType}) "
764 f"for {taskDef.label}.")
765 datasetTypes.add(datasetType)
766 if freeze:
767 datasetTypes.freeze()
768 return datasetTypes
770 # optionally add initOutput dataset for config
771 initOutputs = makeDatasetTypesSet("initOutputs", freeze=False)
772 if include_configs:
773 initOutputs.add(
774 DatasetType(
775 taskDef.configDatasetName,
776 registry.dimensions.empty,
777 storageClass="Config",
778 )
779 )
780 initOutputs.freeze()
782 # optionally add output dataset for metadata
783 outputs = makeDatasetTypesSet("outputs", freeze=False)
784 if taskDef.metadataDatasetName is not None:
785 # Metadata is supposed to be of the PropertySet type, its
786 # dimensions correspond to a task quantum
787 dimensions = registry.dimensions.extract(taskDef.connections.dimensions)
788 outputs |= {DatasetType(taskDef.metadataDatasetName, dimensions, "PropertySet")}
789 if taskDef.logOutputDatasetName is not None:
790 # Log output dimensions correspond to a task quantum.
791 dimensions = registry.dimensions.extract(taskDef.connections.dimensions)
792 outputs |= {DatasetType(taskDef.logOutputDatasetName, dimensions, "ButlerLogRecords")}
794 outputs.freeze()
796 return cls(
797 initInputs=makeDatasetTypesSet("initInputs"),
798 initOutputs=initOutputs,
799 inputs=makeDatasetTypesSet("inputs"),
800 prerequisites=makeDatasetTypesSet("prerequisiteInputs"),
801 outputs=outputs,
802 )
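# Illustrative sketch (added for documentation; not part of the original
# module): classifying the dataset types of a single task requires a butler
# registry. The repository path and pipeline file below are hypothetical.
def _example_task_dataset_types() -> None:
    from lsst.daf.butler import Butler

    butler = Butler("/path/to/repo")
    pipeline = Pipeline.fromFile("pipeline.yaml")
    taskDef = next(iter(pipeline.toExpandedPipeline()))
    types = TaskDatasetTypes.fromTaskDef(taskDef, registry=butler.registry)
    print("inputs:", sorted(dt.name for dt in types.inputs))
    print("outputs:", sorted(dt.name for dt in types.outputs))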
805@dataclass(frozen=True)
806class PipelineDatasetTypes:
807 """An immutable struct that classifies the dataset types used in a
808 `Pipeline`.
809 """
811 packagesDatasetName: ClassVar[str] = "packages"
812 """Name of a dataset type used to save package versions.
813 """
815 initInputs: NamedValueSet[DatasetType]
816 """Dataset types that are needed as inputs in order to construct the Tasks
817 in this Pipeline.
819 This does not include dataset types that are produced when constructing
820 other Tasks in the Pipeline (these are classified as `initIntermediates`).
821 """
823 initOutputs: NamedValueSet[DatasetType]
824 """Dataset types that may be written after constructing the Tasks in this
825 Pipeline.
827 This does not include dataset types that are also used as inputs when
828 constructing other Tasks in the Pipeline (these are classified as
829 `initIntermediates`).
830 """
832 initIntermediates: NamedValueSet[DatasetType]
833 """Dataset types that are both used when constructing one or more Tasks
834 in the Pipeline and produced as a side-effect of constructing another
835 Task in the Pipeline.
836 """
838 inputs: NamedValueSet[DatasetType]
839 """Dataset types that are regular inputs for the full pipeline.
841 If an input dataset needed for a Quantum cannot be found in the input
842 collection(s), that Quantum (and all dependent Quanta) will not be
843 produced.
844 """
846 prerequisites: NamedValueSet[DatasetType]
847 """Dataset types that are prerequisite inputs for the full Pipeline.
849 Prerequisite inputs must exist in the input collection(s) before the
850 pipeline is run, but do not constrain the graph - if a prerequisite is
851 missing for a Quantum, `PrerequisiteMissingError` is raised.
853 Prerequisite inputs are not resolved until the second stage of
854 QuantumGraph generation.
855 """
857 intermediates: NamedValueSet[DatasetType]
858 """Dataset types that are output by one Task in the Pipeline and consumed
859 as inputs by one or more other Tasks in the Pipeline.
860 """
862 outputs: NamedValueSet[DatasetType]
863 """Dataset types that are output by a Task in the Pipeline and not consumed
864 by any other Task in the Pipeline.
865 """
867 byTask: Mapping[str, TaskDatasetTypes]
868 """Per-Task dataset types, keyed by label in the `Pipeline`.
870 This is guaranteed to be zip-iterable with the `Pipeline` itself (assuming
871 neither has been modified since the dataset types were extracted, of
872 course).
873 """
875 @classmethod
876 def fromPipeline(
877 cls,
878 pipeline: Union[Pipeline, Iterable[TaskDef]],
879 *,
880 registry: Registry,
881 include_configs: bool = True,
882 include_packages: bool = True,
883 ) -> PipelineDatasetTypes:
884 """Extract and classify the dataset types from all tasks in a
885 `Pipeline`.
887 Parameters
888 ----------
889 pipeline: `Pipeline` or `Iterable` [ `TaskDef` ]
890 A dependency-ordered collection of tasks that can be run
891 together.
892 registry: `Registry`
893 Registry used to construct normalized `DatasetType` objects and
894 retrieve those that are incomplete.
895 include_configs : `bool`, optional
896 If `True` (default) include config dataset types as
897 ``initOutputs``.
898 include_packages : `bool`, optional
899 If `True` (default) include the dataset type for software package
900 versions in ``initOutputs``.
902 Returns
903 -------
904 types: `PipelineDatasetTypes`
905 The dataset types used by this `Pipeline`.
907 Raises
908 ------
909 ValueError
910 Raised if Tasks are inconsistent about which datasets are marked
911 prerequisite. This indicates that the Tasks cannot be run as part
912 of the same `Pipeline`.
913 """
914 allInputs = NamedValueSet()
915 allOutputs = NamedValueSet()
916 allInitInputs = NamedValueSet()
917 allInitOutputs = NamedValueSet()
918 prerequisites = NamedValueSet()
919 byTask = dict()
920 if include_packages:
921 allInitOutputs.add(
922 DatasetType(
923 cls.packagesDatasetName,
924 registry.dimensions.empty,
925 storageClass="Packages",
926 )
927 )
928 if isinstance(pipeline, Pipeline):
929 pipeline = pipeline.toExpandedPipeline()
930 for taskDef in pipeline:
931 thisTask = TaskDatasetTypes.fromTaskDef(
932 taskDef,
933 registry=registry,
934 include_configs=include_configs,
935 )
936 allInitInputs |= thisTask.initInputs
937 allInitOutputs |= thisTask.initOutputs
938 allInputs |= thisTask.inputs
939 prerequisites |= thisTask.prerequisites
940 allOutputs |= thisTask.outputs
941 byTask[taskDef.label] = thisTask
942 if not prerequisites.isdisjoint(allInputs):
943 raise ValueError("{} marked as both prerequisites and regular inputs".format(
944 {dt.name for dt in allInputs & prerequisites}
945 ))
946 if not prerequisites.isdisjoint(allOutputs):
947 raise ValueError("{} marked as both prerequisites and outputs".format(
948 {dt.name for dt in allOutputs & prerequisites}
949 ))
950 # Make sure that components which are marked as inputs get treated as
951 # intermediates if there is an output which produces the composite
952 # containing the component
953 intermediateComponents = NamedValueSet()
954 intermediateComposites = NamedValueSet()
955 outputNameMapping = {dsType.name: dsType for dsType in allOutputs}
956 for dsType in allInputs:
957 # get the name of a possible component
958 name, component = dsType.nameAndComponent()
959 # if there is a component name, that means this is a component
960 # DatasetType, if there is an output which produces the parent of
961 # this component, treat this input as an intermediate
962 if component is not None:
963 if name in outputNameMapping:
964 if outputNameMapping[name].dimensions != dsType.dimensions:
965 raise ValueError(f"Component dataset type {dsType.name} has different "
966 f"dimensions ({dsType.dimensions}) than its parent "
967 f"({outputNameMapping[name].dimensions}).")
968 composite = DatasetType(name, dsType.dimensions, outputNameMapping[name].storageClass,
969 universe=registry.dimensions)
970 intermediateComponents.add(dsType)
971 intermediateComposites.add(composite)
973 def checkConsistency(a: NamedValueSet, b: NamedValueSet):
974 common = a.names & b.names
975 for name in common:
976 if a[name] != b[name]:
977 raise ValueError(f"Conflicting definitions for dataset type: {a[name]} != {b[name]}.")
979 checkConsistency(allInitInputs, allInitOutputs)
980 checkConsistency(allInputs, allOutputs)
981 checkConsistency(allInputs, intermediateComposites)
982 checkConsistency(allOutputs, intermediateComposites)
984 def frozen(s: NamedValueSet) -> NamedValueSet:
985 s.freeze()
986 return s
988 return cls(
989 initInputs=frozen(allInitInputs - allInitOutputs),
990 initIntermediates=frozen(allInitInputs & allInitOutputs),
991 initOutputs=frozen(allInitOutputs - allInitInputs),
992 inputs=frozen(allInputs - allOutputs - intermediateComponents),
993 intermediates=frozen(allInputs & allOutputs | intermediateComponents),
994 outputs=frozen(allOutputs - allInputs - intermediateComposites),
995 prerequisites=frozen(prerequisites),
996 byTask=MappingProxyType(byTask), # MappingProxyType -> frozen view of dict for immutability
997 )
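# Illustrative sketch (added for documentation; not part of the original
# module): the pipeline-level classification distinguishes overall inputs
# from intermediates produced and consumed within the pipeline. Repository
# path and pipeline file are hypothetical.
def _example_pipeline_dataset_types() -> None:
    from lsst.daf.butler import Butler

    butler = Butler("/path/to/repo")
    pipeline = Pipeline.fromFile("pipeline.yaml")
    dsTypes = PipelineDatasetTypes.fromPipeline(pipeline, registry=butler.registry)
    print("inputs:", sorted(dt.name for dt in dsTypes.inputs))
    print("intermediates:", sorted(dt.name for dt in dsTypes.intermediates))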
999 @classmethod
1000 def initOutputNames(cls, pipeline: Union[Pipeline, Iterable[TaskDef]], *,
1001 include_configs: bool = True, include_packages: bool = True) -> Iterator[str]:
1002 """Return the names of dataset types ot task initOutputs, Configs,
1003 and package versions for a pipeline.
1005 Parameters
1006 ----------
1007 pipeline: `Pipeline` or `Iterable` [ `TaskDef` ]
1008 A `Pipeline` instance or collection of `TaskDef` instances.
1009 include_configs : `bool`, optional
1010 If `True` (default) include config dataset types.
1011 include_packages : `bool`, optional
1012 If `True` (default) include the dataset type for package versions.
1014 Yields
1015 ------
1016 datasetTypeName : `str`
1017 Name of the dataset type.
1018 """
1019 if include_packages:
1020 # Package versions dataset type
1021 yield cls.packagesDatasetName
1023 if isinstance(pipeline, Pipeline):
1024 pipeline = pipeline.toExpandedPipeline()
1026 for taskDef in pipeline:
1028 # all task InitOutputs
1029 for name in taskDef.connections.initOutputs:
1030 attribute = getattr(taskDef.connections, name)
1031 yield attribute.name
1033 # config dataset name
1034 if include_configs:
1035 yield taskDef.configDatasetName
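# Illustrative sketch (added for documentation; not part of the original
# module): initOutputNames needs no registry, so it can cheaply list the
# init-output, config and package-versions dataset type names. The pipeline
# file name is a hypothetical placeholder.
def _example_init_output_names() -> None:
    pipeline = Pipeline.fromFile("pipeline.yaml")
    for name in PipelineDatasetTypes.initOutputNames(pipeline):
        print(name)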