Coverage for python/lsst/pipe/base/pipeline.py : 19%

1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
23"""Module defining Pipeline class and related methods.
24"""
26__all__ = ["Pipeline", "TaskDef", "TaskDatasetTypes", "PipelineDatasetTypes", "LabelSpecifier"]
28# -------------------------------
29# Imports of standard modules --
30# -------------------------------
31from dataclasses import dataclass
32from types import MappingProxyType
33from typing import Dict, Iterable, Mapping, Set, Union, Generator, TYPE_CHECKING, Optional, Tuple
35import copy
36import re
37import os
38import urllib.parse
39import warnings
41# -----------------------------
42# Imports for other modules --
43from lsst.daf.butler import DatasetType, NamedValueSet, Registry, SkyPixDimension, ButlerURI
44from lsst.utils import doImport
45from .configOverrides import ConfigOverrides
46from .connections import iterConnections
47from .pipelineTask import PipelineTask
49from . import pipelineIR
50from . import pipeTools
52if TYPE_CHECKING:  # Imports needed only for type annotations; may be circular.
53 from lsst.obs.base import Instrument
55# ----------------------------------
56# Local non-exported definitions --
57# ----------------------------------
59# ------------------------
60# Exported definitions --
61# ------------------------
64@dataclass
65class LabelSpecifier:
66 """A structure to specify a subset of labels to load
68 This structure may contain a set of labels to be used in subsetting a
69 pipeline, or a beginning and end point. Beginning or end may be empty,
70 in which case the range will be a half open interval. Unlike python
71 iteration bounds, end bounds are *INCLUDED*. Note that range based
72 selection is not well defined for pipelines that are not linear in nature,
73 and correct behavior is not guaranteed, or may vary from run to run.
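Examples
--------
A minimal sketch; the label names below are purely illustrative.
>>> from lsst.pipe.base.pipeline import LabelSpecifier
>>> spec = LabelSpecifier(labels={"isr", "characterizeImage"})
>>> spec = LabelSpecifier(begin="isr", end="calibrate")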
74 """
75 labels: Optional[Set[str]] = None
76 begin: Optional[str] = None
77 end: Optional[str] = None
79 def __post_init__(self):
80 if self.labels is not None and (self.begin or self.end):
81 raise ValueError("This struct can only be initialized with a labels set or "
82 "a begin (and/or) end specifier")
85class TaskDef:
86 """TaskDef is a collection of information about task needed by Pipeline.
88 The information includes task name, configuration object and optional
89 task class. This class is just a collection of attributes and it exposes
90 all of them so that attributes could potentially be modified in place
91 (e.g. if configuration needs extra overrides).
93 Attributes
94 ----------
95 taskName : `str`
96 `PipelineTask` class name, currently it is not specified whether this
97 is a fully-qualified name or partial name (e.g. ``module.TaskClass``).
98 Framework should be prepared to handle all cases.
99 config : `lsst.pex.config.Config`
100 Instance of the configuration class corresponding to this task class,
101 usually with all overrides applied. This config will be frozen.
102 taskClass : `type` or ``None``
103 `PipelineTask` class object, can be ``None``. If ``None`` then
104 the framework will have to locate and load the class.
105 label : `str`, optional
106 Task label, usually a short string unique in a pipeline.
107 """
108 def __init__(self, taskName, config, taskClass=None, label=""):
109 self.taskName = taskName
110 config.freeze()
111 self.config = config
112 self.taskClass = taskClass
113 self.label = label
114 self.connections = config.connections.ConnectionsClass(config=config)
116 @property
117 def configDatasetName(self) -> str:
118 """Name of a dataset type for configuration of this task (`str`)
119 """
120 return self.label + "_config"
122 @property
123 def metadataDatasetName(self) -> Optional[str]:
124 """Name of a dataset type for metadata of this task, `None` if
125 metadata is not to be saved (`str`)
126 """
127 if self.config.saveMetadata:
128 return self.label + "_metadata"
129 else:
130 return None
132 @property
133 def logOutputDatasetName(self) -> Optional[str]:
134 """Name of a dataset type for log output from this task, `None` if
135 logs are not to be saved (`str`)
136 """
137 if self.config.saveLogOutput:
138 return self.label + "_log"
139 else:
140 return None
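# For illustration, the naming convention implemented by the properties
# above, for a hypothetical task labeled "isr" with saveMetadata and
# saveLogOutput enabled in its config:
#
#     taskDef.configDatasetName     == "isr_config"
#     taskDef.metadataDatasetName   == "isr_metadata"
#     taskDef.logOutputDatasetName  == "isr_log"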
142 def __str__(self):
143 rep = "TaskDef(" + self.taskName
144 if self.label:
145 rep += ", label=" + self.label
146 rep += ")"
147 return rep
149 def __eq__(self, other: object) -> bool:
150 if not isinstance(other, TaskDef):
151 return False
152 # This does not consider equality of configs when determining equality
153 # as config equality is a difficult thing to define. Should be updated
154 # after DM-27847
155 return self.taskClass == other.taskClass and self.label == other.label
157 def __hash__(self):
158 return hash((self.taskClass, self.label))
161class Pipeline:
162 """A `Pipeline` is a representation of a series of tasks to run, and the
163 configuration for those tasks.
165 Parameters
166 ----------
167 description : `str`
168 A description of what this pipeline does.
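Examples
--------
A minimal, illustrative sketch; the task name and label are hypothetical.
>>> from lsst.pipe.base.pipeline import Pipeline
>>> pipeline = Pipeline("A short demo pipeline")
>>> pipeline.addTask("mypackage.mymodule.MyTask", "myTask")
>>> pipeline.addConfigOverride("myTask", "someField", 42)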
169 """
170 def __init__(self, description: str):
171 pipeline_dict = {"description": description, "tasks": {}}
172 self._pipelineIR = pipelineIR.PipelineIR(pipeline_dict)
174 @classmethod
175 def fromFile(cls, filename: str) -> Pipeline:
176 """Load a pipeline defined in a pipeline yaml file.
178 Parameters
179 ----------
180 filename: `str`
181 A path that points to a pipeline defined in yaml format. This
182 filename may also supply additional labels to be used in
183 subsetting the loaded Pipeline. These labels are separated from
184 the path by a \\#, and may be specified as a comma separated
185 list, or a range denoted as beginning..end. Beginning or end may
186 be empty, in which case the range will be a half open interval.
187 Unlike python iteration bounds, end bounds are *INCLUDED*. Note
188 that range based selection is not well defined for pipelines that
189 are not linear in nature, and correct behavior is not guaranteed,
190 or may vary from run to run.
192 Returns
193 -------
194 pipeline: `Pipeline`
195 The pipeline loaded from the specified location with appropriate
196 (if any) subsetting.
198 Notes
199 -----
200 This method attempts to prune any contracts that contain labels which
201 are not in the declared subset of labels. This pruning is done using
202 string-based matching due to the nature of contracts and may prune more
203 than it should.
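Examples
--------
Illustrative only; the path and label names below are hypothetical.
>>> p = Pipeline.fromFile("pipelines/example.yaml")
>>> p = Pipeline.fromFile("pipelines/example.yaml#isr,calibrate")
>>> p = Pipeline.fromFile("pipelines/example.yaml#isr..calibrate")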
204 """
205 return cls.from_uri(filename)
207 @classmethod
208 def from_uri(cls, uri: Union[str, ButlerURI]) -> Pipeline:
209 """Load a pipeline defined in a pipeline yaml file at a location
210 specified by a URI.
212 Parameters
213 ----------
214 uri: `str` or `ButlerURI`
215 If a string is supplied this should be a URI path that points to a
216 pipeline defined in yaml format. This uri may also supply
217 additional labels to be used in subsetting the loaded Pipeline.
218 These labels are separated from the path by a \\#, and may be
219 specified as a comma separated list, or a range denoted as
220 beginning..end. Beginning or end may be empty, in which case the
221 range will be a half open interval. Unlike python iteration
222 bounds, end bounds are *INCLUDED*. Note that range based selection
223 is not well defined for pipelines that are not linear in nature,
224 and correct behavior is not guaranteed, or may vary from run to
225 run. The same specifiers can be used with a ButlerURI object by
226 passing them as the sole contents of the URI fragment attribute.
228 Returns
229 -------
230 pipeline: `Pipeline`
231 The pipeline loaded from the specified location with appropriate
232 (if any) subsetting.
234 Notes
235 -----
236 This method attempts to prune any contracts that contain labels which
237 are not in the declared subset of labels. This pruning is done using
238 string-based matching due to the nature of contracts and may prune more
239 than it should.
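Examples
--------
Illustrative only; the URI and label names below are hypothetical.
>>> p = Pipeline.from_uri("s3://bucket/pipelines/example.yaml#isr..calibrate")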
240 """
241 # Split up the uri and any labels that were supplied
242 uri, label_specifier = cls._parse_file_specifier(uri)
243 pipeline: Pipeline = cls.fromIR(pipelineIR.PipelineIR.from_uri(uri))
245 # If there are labels supplied, only keep those
246 if label_specifier is not None:
247 pipeline = pipeline.subsetFromLabels(label_specifier)
248 return pipeline
250 def subsetFromLabels(self, labelSpecifier: LabelSpecifier) -> Pipeline:
251 """Subset a pipeline to contain only labels specified in labelSpecifier
253 Parameters
254 ----------
255 labelSpecifier : `LabelSpecifier`
256 Object containing the labels that describe how to subset a pipeline.
258 Returns
259 -------
260 pipeline : `Pipeline`
261 A new pipeline object that is a subset of the old pipeline
263 Raises
264 ------
265 ValueError
266 Raised if there is an issue with specified labels
268 Notes
269 -----
270 This method attempts to prune any contracts that contain labels which
271 are not in the declared subset of labels. This pruning is done using
272 string-based matching due to the nature of contracts and may prune more
273 than it should.
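Examples
--------
A hypothetical sketch, assuming ``pipeline`` is a loaded `Pipeline` that
contains tasks labeled "isr" and "calibrate".
>>> subset = pipeline.subsetFromLabels(LabelSpecifier(labels={"isr"}))
>>> subset = pipeline.subsetFromLabels(
...     LabelSpecifier(begin="isr", end="calibrate"))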
274 """
275 # Labels supplied as a set
276 if labelSpecifier.labels:
277 labelSet = labelSpecifier.labels
278 # Labels supplied as a range: first create a list of all the labels
279 # in the pipeline sorted according to task dependency. Then only
280 # keep labels that lie between the supplied bounds
281 else:
282 # Create a copy of the pipeline to use when assessing the label
283 # ordering. Use a dict for fast searching while preserving order.
284 # Remove contracts so they do not fail in the expansion step. This
285 # is needed because a user may only configure the tasks they intend
286 # to run, which may cause some contracts to fail if they will later
287 # be dropped
288 pipeline = copy.deepcopy(self)
289 pipeline._pipelineIR.contracts = []
290 labels = {taskdef.label: True for taskdef in pipeline.toExpandedPipeline()}
292 # Verify the bounds are in the labels
293 if labelSpecifier.begin is not None:
294 if labelSpecifier.begin not in labels:
295 raise ValueError(f"Beginning of range subset, {labelSpecifier.begin}, not found in "
296 "pipeline definition")
297 if labelSpecifier.end is not None:
298 if labelSpecifier.end not in labels:
299 raise ValueError(f"End of range subset, {labelSpecifier.end}, not found in pipeline "
300 "definition")
302 labelSet = set()
303 for label in labels:
304 if labelSpecifier.begin is not None:
305 if label != labelSpecifier.begin:
306 continue
307 else:
308 labelSpecifier.begin = None
309 labelSet.add(label)
310 if labelSpecifier.end is not None and label == labelSpecifier.end:
311 break
312 return Pipeline.fromIR(self._pipelineIR.subset_from_labels(labelSet))
314 @staticmethod
315 def _parse_file_specifier(uri: Union[str, ButlerURI]
316 ) -> Tuple[ButlerURI, Optional[LabelSpecifier]]:
317 """Split appart a uri and any possible label subsets
318 """
319 if isinstance(uri, str):
320 # This is to support legacy pipelines during transition
321 uri, num_replace = re.subn("[:](?!\\/\\/)", "#", uri)
322 if num_replace:
323 warnings.warn(f"The pipeline file {uri} seems to use the legacy : to separate "
324 "labels, this is deprecated and will be removed after June 2021, please use "
325 "# instead.",
326 category=FutureWarning)
327 if uri.count("#") > 1:
328 raise ValueError("Only one set of labels is allowed when specifying a pipeline to load")
329 uri = ButlerURI(uri)
330 label_subset = uri.fragment or None
332 specifier: Optional[LabelSpecifier]
333 if label_subset is not None:
334 label_subset = urllib.parse.unquote(label_subset)
335 args: Dict[str, Union[Set[str], str, None]]
336 # labels supplied as a list
337 if ',' in label_subset:
338 if '..' in label_subset:
339 raise ValueError("Can only specify a list of labels or a range"
340 "when loading a Pipline not both")
341 args = {"labels": set(label_subset.split(","))}
342 # labels supplied as a range
343 elif '..' in label_subset:
344 # Try to de-structure the labelSubset, this will fail if more
345 # than one range is specified
346 begin, end, *rest = label_subset.split("..")
347 if rest:
348 raise ValueError("Only one range can be specified when loading a pipeline")
349 args = {"begin": begin if begin else None, "end": end if end else None}
350 # Assume anything else is a single label
351 else:
352 args = {"labels": {label_subset}}
354 specifier = LabelSpecifier(**args)
355 else:
356 specifier = None
358 return uri, specifier
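# For illustration, the fragment forms recognized above (file and label
# names are hypothetical):
#   "example.yaml#isr,calibrate"  -> LabelSpecifier(labels={"isr", "calibrate"})
#   "example.yaml#isr..calibrate" -> LabelSpecifier(begin="isr", end="calibrate")
#   "example.yaml#isr"            -> LabelSpecifier(labels={"isr"})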
360 @classmethod
361 def fromString(cls, pipeline_string: str) -> Pipeline:
362 """Create a pipeline from string formatted as a pipeline document.
364 Parameters
365 ----------
366 pipeline_string : `str`
367 A string that is formatted like a pipeline document.
369 Returns
370 -------
371 pipeline: `Pipeline`
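Examples
--------
A minimal sketch; the pipeline document shown defines no tasks.
>>> p = Pipeline.fromString("{description: A demo pipeline, tasks: {}}")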
372 """
373 pipeline = cls.fromIR(pipelineIR.PipelineIR.from_string(pipeline_string))
374 return pipeline
376 @classmethod
377 def fromIR(cls, deserialized_pipeline: pipelineIR.PipelineIR) -> Pipeline:
378 """Create a pipeline from an already created `PipelineIR` object.
380 Parameters
381 ----------
382 deserialized_pipeline: `PipelineIR`
383 An already created pipeline intermediate representation object
385 Returns
386 -------
387 pipeline: `Pipeline`
388 """
389 pipeline = cls.__new__(cls)
390 pipeline._pipelineIR = deserialized_pipeline
391 return pipeline
393 @classmethod
394 def fromPipeline(cls, pipeline: Pipeline) -> Pipeline:
395 """Create a new pipeline by copying an already existing `Pipeline`.
397 Parameters
398 ----------
399 pipeline: `Pipeline`
400 The existing `Pipeline` object to copy.
402 Returns
403 -------
404 pipeline: `Pipeline`
405 """
406 return cls.fromIR(copy.deepcopy(pipeline._pipelineIR))
408 def __str__(self) -> str:
409 return str(self._pipelineIR)
411 def addInstrument(self, instrument: Union[Instrument, str]) -> None:
412 """Add an instrument to the pipeline, or replace an instrument that is
413 already defined.
415 Parameters
416 ----------
417 instrument : `~lsst.obs.base.Instrument` or `str`
418 Either a subclass object of `lsst.obs.base.Instrument` or
419 a string corresponding to a fully qualified
420 `lsst.obs.base.Instrument` name.
421 """
422 if isinstance(instrument, str):
423 pass
424 else:
425 # TODO: assume that this is a subclass of Instrument, no type
426 # checking
427 instrument = f"{instrument.__module__}.{instrument.__qualname__}"
428 self._pipelineIR.instrument = instrument
430 def getInstrument(self) -> Instrument:
431 """Get the instrument from the pipeline.
433 Returns
434 -------
435 instrument : `~lsst.obs.base.Instrument`, `str`, or None
436 A subclass object of `lsst.obs.base.Instrument`, a string
437 corresponding to a fully qualified `lsst.obs.base.Instrument`
438 name, or None if the pipeline does not have an instrument.
439 """
440 return self._pipelineIR.instrument
442 def addTask(self, task: Union[PipelineTask, str], label: str) -> None:
443 """Add a new task to the pipeline, or replace a task that is already
444 associated with the supplied label.
446 Parameters
447 ----------
448 task: `PipelineTask` or `str`
449 Either a derived class object of a `PipelineTask` or a string
450 corresponding to a fully qualified `PipelineTask` name.
451 label: `str`
452 A label that is used to identify the `PipelineTask` being added
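Examples
--------
Illustrative only, assuming ``pipeline`` is an existing `Pipeline`; the
task name and label below are hypothetical.
>>> pipeline.addTask("mypackage.mymodule.MyTask", "myTask")
453 """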
453 """
454 if isinstance(task, str):
455 taskName = task
456 elif issubclass(task, PipelineTask):
457 taskName = f"{task.__module__}.{task.__qualname__}"
458 else:
459 raise ValueError("task must be either a child class of PipelineTask or a string containing"
460 " a fully qualified name to one")
461 if not label:
462 # in some cases (with command line-generated pipeline) tasks can
463 # be defined without label which is not acceptable, use task
464 # _DefaultName in that case
465 if isinstance(task, str):
466 task = doImport(task)
467 label = task._DefaultName
468 self._pipelineIR.tasks[label] = pipelineIR.TaskIR(label, taskName)
470 def removeTask(self, label: str) -> None:
471 """Remove a task from the pipeline.
473 Parameters
474 ----------
475 label : `str`
476 The label used to identify the task that is to be removed
478 Raises
479 ------
480 KeyError
481 Raised if no task with that label exists in the pipeline.
483 """
484 self._pipelineIR.tasks.pop(label)
486 def addConfigOverride(self, label: str, key: str, value: object) -> None:
487 """Apply single config override.
489 Parameters
490 ----------
491 label : `str`
492 Label of the task.
493 key: `str`
494 Fully-qualified field name.
495 value : object
496 Value to be given to a field.
497 """
498 self._addConfigImpl(label, pipelineIR.ConfigIR(rest={key: value}))
500 def addConfigFile(self, label: str, filename: str) -> None:
501 """Add overrides from a specified file.
503 Parameters
504 ----------
505 label : `str`
506 The label used to identify the task associated with the config to
507 modify.
508 filename : `str`
509 Path to the override file.
510 """
511 self._addConfigImpl(label, pipelineIR.ConfigIR(file=[filename]))
513 def addConfigPython(self, label: str, pythonString: str) -> None:
514 """Add Overrides by running a snippet of python code against a config.
516 Parameters
517 ----------
518 label : `str`
519 The label used to identify the task associated with the config to
520 modify.
521 pythonString: `str`
522 A string of valid Python code to be executed. This is done
523 with ``config`` as the only locally accessible value.
524 """
525 self._addConfigImpl(label, pipelineIR.ConfigIR(python=pythonString))
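# For illustration, the three override styles above (the label, field,
# file path, and code snippet are hypothetical):
#     pipeline.addConfigOverride("myTask", "someField", 42)
#     pipeline.addConfigFile("myTask", "/path/to/overrides.py")
#     pipeline.addConfigPython("myTask", "config.someField = 42")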
527 def _addConfigImpl(self, label: str, newConfig: pipelineIR.ConfigIR) -> None:
528 if label == "parameters":
529 if newConfig.rest.keys() - self._pipelineIR.parameters.mapping.keys():
530 raise ValueError("Cannot override parameters that are not defined in pipeline")
531 self._pipelineIR.parameters.mapping.update(newConfig.rest)
532 if newConfig.file:
533 raise ValueError("Setting parameters section with config file is not supported")
534 if newConfig.python:
535 raise ValueError("Setting parameters section using python block in unsupported")
536 return
537 if label not in self._pipelineIR.tasks:
538 raise LookupError(f"There are no tasks labeled '{label}' in the pipeline")
539 self._pipelineIR.tasks[label].add_or_update_config(newConfig)
541 def toFile(self, filename: str) -> None:
542 self._pipelineIR.to_file(filename)
544 def write_to_uri(self, uri: Union[str, ButlerURI]) -> None:
545 self._pipelineIR.write_to_uri(uri)
547 def toExpandedPipeline(self) -> Generator[TaskDef, None, None]:
548 """Returns a generator of TaskDefs which can be used to create quantum
549 graphs.
551 Returns
552 -------
553 generator : generator of `TaskDef`
554 The generator returned will be the sorted iterator of tasks which
555 are to be used in constructing a quantum graph.
557 Raises
558 ------
559 NotImplementedError
560 Raised if a dataId is supplied in a config block. This is in place
561 for future use.
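Examples
--------
A hypothetical sketch of iterating the expanded pipeline in
dependency-sorted order, assuming ``pipeline`` is a loaded `Pipeline`.
>>> for taskDef in pipeline.toExpandedPipeline():
...     print(taskDef.label, taskDef.taskName)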
562 """
563 taskDefs = []
564 for label, taskIR in self._pipelineIR.tasks.items():
565 taskClass = doImport(taskIR.klass)
566 taskName = taskClass.__qualname__
567 config = taskClass.ConfigClass()
568 overrides = ConfigOverrides()
569 if self._pipelineIR.instrument is not None:
570 overrides.addInstrumentOverride(self._pipelineIR.instrument, taskClass._DefaultName)
571 if taskIR.config is not None:
572 for configIR in (configIr.formatted(self._pipelineIR.parameters)
573 for configIr in taskIR.config):
574 if configIR.dataId is not None:
575 raise NotImplementedError("Specializing a config on a partial data id is not yet "
576 "supported in Pipeline definition")
577 # only apply override if it applies to everything
578 if configIR.dataId is None:
579 if configIR.file:
580 for configFile in configIR.file:
581 overrides.addFileOverride(os.path.expandvars(configFile))
582 if configIR.python is not None:
583 overrides.addPythonOverride(configIR.python)
584 for key, value in configIR.rest.items():
585 overrides.addValueOverride(key, value)
586 overrides.applyTo(config)
587 # This may need to be revisited
588 config.validate()
589 taskDefs.append(TaskDef(taskName=taskName, config=config, taskClass=taskClass, label=label))
591 # Now evaluate the contracts
592 if self._pipelineIR.contracts is not None:
593 label_to_config = {x.label: x.config for x in taskDefs}
594 for contract in self._pipelineIR.contracts:
595 # Execute this on its own line so it can raise a good error
596 # message if there were problems with the eval
597 success = eval(contract.contract, None, label_to_config)
598 if not success:
599 extra_info = f": {contract.msg}" if contract.msg is not None else ""
600 raise pipelineIR.ContractError(f"Contract(s) '{contract.contract}' were not "
601 f"satisfied{extra_info}")
603 yield from pipeTools.orderPipeline(taskDefs)
605 def __len__(self):
606 return len(self._pipelineIR.tasks)
608 def __eq__(self, other: object):
609 if not isinstance(other, Pipeline):
610 return False
611 return self._pipelineIR == other._pipelineIR
614@dataclass(frozen=True)
615class TaskDatasetTypes:
616 """An immutable struct that extracts and classifies the dataset types used
617 by a `PipelineTask`.
618 """
620 initInputs: NamedValueSet[DatasetType]
621 """Dataset types that are needed as inputs in order to construct this Task.
623 Task-level `initInputs` may be classified as either
624 `~PipelineDatasetTypes.initInputs` or
625 `~PipelineDatasetTypes.initIntermediates` at the Pipeline level.
626 """
628 initOutputs: NamedValueSet[DatasetType]
629 """Dataset types that may be written after constructing this Task.
631 Task-level `initOutputs` may be classified as either
632 `~PipelineDatasetTypes.initOutputs` or
633 `~PipelineDatasetTypes.initIntermediates` at the Pipeline level.
634 """
636 inputs: NamedValueSet[DatasetType]
637 """Dataset types that are regular inputs to this Task.
639 If an input dataset needed for a Quantum cannot be found in the input
640 collection(s) or produced by another Task in the Pipeline, that Quantum
641 (and all dependent Quanta) will not be produced.
643 Task-level `inputs` may be classified as either
644 `~PipelineDatasetTypes.inputs` or `~PipelineDatasetTypes.intermediates`
645 at the Pipeline level.
646 """
648 prerequisites: NamedValueSet[DatasetType]
649 """Dataset types that are prerequisite inputs to this Task.
651 Prerequisite inputs must exist in the input collection(s) before the
652 pipeline is run, but do not constrain the graph - if a prerequisite is
653 missing for a Quantum, `PrerequisiteMissingError` is raised.
655 Prerequisite inputs are not resolved until the second stage of
656 QuantumGraph generation.
657 """
659 outputs: NamedValueSet[DatasetType]
660 """Dataset types that are produced by this Task.
662 Task-level `outputs` may be classified as either
663 `~PipelineDatasetTypes.outputs` or `~PipelineDatasetTypes.intermediates`
664 at the Pipeline level.
665 """
667 @classmethod
668 def fromTaskDef(
669 cls,
670 taskDef: TaskDef,
671 *,
672 registry: Registry,
673 include_configs: bool = True,
674 ) -> TaskDatasetTypes:
675 """Extract and classify the dataset types from a single `PipelineTask`.
677 Parameters
678 ----------
679 taskDef: `TaskDef`
680 An instance of a `TaskDef` class for a particular `PipelineTask`.
681 registry: `Registry`
682 Registry used to construct normalized `DatasetType` objects and
683 retrieve those that are incomplete.
684 include_configs : `bool`, optional
685 If `True` (default) include config dataset types as
686 ``initOutputs``.
688 Returns
689 -------
690 types: `TaskDatasetTypes`
691 The dataset types used by this task.
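Examples
--------
A hypothetical sketch, assuming ``taskDef`` is a `TaskDef` and
``registry`` is a butler `Registry` already available in scope.
>>> types = TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
>>> regular_inputs = types.inputs.names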
692 """
693 def makeDatasetTypesSet(connectionType: str, freeze: bool = True) -> NamedValueSet[DatasetType]:
694 """Constructs a set of true `DatasetType` objects
696 Parameters
697 ----------
698 connectionType : `str`
699 Name of the connection type to produce a set for; corresponds
700 to an attribute of type `list` on the connection class instance.
701 freeze : `bool`, optional
702 If `True`, call `NamedValueSet.freeze` on the object returned.
704 Returns
705 -------
706 datasetTypes : `NamedValueSet`
707 A set of all datasetTypes which correspond to the input
708 connection type specified in the connection class of this
709 `PipelineTask`
711 Notes
712 -----
713 This function is a closure over the variables ``registry`` and
714 ``taskDef``.
715 """
716 datasetTypes = NamedValueSet()
717 for c in iterConnections(taskDef.connections, connectionType):
718 dimensions = set(getattr(c, 'dimensions', set()))
719 if "skypix" in dimensions:
720 try:
721 datasetType = registry.getDatasetType(c.name)
722 except LookupError as err:
723 raise LookupError(
724 f"DatasetType '{c.name}' referenced by "
725 f"{type(taskDef.connections).__name__} uses 'skypix' as a dimension "
726 f"placeholder, but does not already exist in the registry. "
727 f"Note that reference catalog names are now used as the dataset "
728 f"type name instead of 'ref_cat'."
729 ) from err
730 rest1 = set(registry.dimensions.extract(dimensions - set(["skypix"])).names)
731 rest2 = set(dim.name for dim in datasetType.dimensions
732 if not isinstance(dim, SkyPixDimension))
733 if rest1 != rest2:
734 raise ValueError(f"Non-skypix dimensions for dataset type {c.name} declared in "
735 f"connections ({rest1}) are inconsistent with those in "
736 f"registry's version of this dataset ({rest2}).")
737 else:
738 # Component dataset types are not explicitly in the
739 # registry. This complicates consistency checks with
740 # registry and requires we work out the composite storage
741 # class.
742 registryDatasetType = None
743 try:
744 registryDatasetType = registry.getDatasetType(c.name)
745 except KeyError:
746 compositeName, componentName = DatasetType.splitDatasetTypeName(c.name)
747 parentStorageClass = DatasetType.PlaceholderParentStorageClass \
748 if componentName else None
749 datasetType = c.makeDatasetType(
750 registry.dimensions,
751 parentStorageClass=parentStorageClass
752 )
753 registryDatasetType = datasetType
754 else:
755 datasetType = c.makeDatasetType(
756 registry.dimensions,
757 parentStorageClass=registryDatasetType.parentStorageClass
758 )
760 if registryDatasetType and datasetType != registryDatasetType:
761 raise ValueError(f"Supplied dataset type ({datasetType}) inconsistent with "
762 f"registry definition ({registryDatasetType}) "
763 f"for {taskDef.label}.")
764 datasetTypes.add(datasetType)
765 if freeze:
766 datasetTypes.freeze()
767 return datasetTypes
769 # optionally add initOutput dataset for config
770 initOutputs = makeDatasetTypesSet("initOutputs", freeze=False)
771 if include_configs:
772 initOutputs.add(
773 DatasetType(
774 taskDef.configDatasetName,
775 registry.dimensions.empty,
776 storageClass="Config",
777 )
778 )
779 initOutputs.freeze()
781 # optionally add output dataset for metadata
782 outputs = makeDatasetTypesSet("outputs", freeze=False)
783 if taskDef.metadataDatasetName is not None:
784 # Metadata is supposed to be of the PropertySet type; its
785 # dimensions correspond to a task quantum
786 dimensions = registry.dimensions.extract(taskDef.connections.dimensions)
787 outputs |= {DatasetType(taskDef.metadataDatasetName, dimensions, "PropertySet")}
788 if taskDef.logOutputDatasetName is not None:
789 # Log output dimensions correspond to a task quantum.
790 dimensions = registry.dimensions.extract(taskDef.connections.dimensions)
791 outputs |= {DatasetType(taskDef.logOutputDatasetName, dimensions, "ButlerLogRecords")}
793 outputs.freeze()
795 return cls(
796 initInputs=makeDatasetTypesSet("initInputs"),
797 initOutputs=initOutputs,
798 inputs=makeDatasetTypesSet("inputs"),
799 prerequisites=makeDatasetTypesSet("prerequisiteInputs"),
800 outputs=outputs,
801 )
804@dataclass(frozen=True)
805class PipelineDatasetTypes:
806 """An immutable struct that classifies the dataset types used in a
807 `Pipeline`.
808 """
810 initInputs: NamedValueSet[DatasetType]
811 """Dataset types that are needed as inputs in order to construct the Tasks
812 in this Pipeline.
814 This does not include dataset types that are produced when constructing
815 other Tasks in the Pipeline (these are classified as `initIntermediates`).
816 """
818 initOutputs: NamedValueSet[DatasetType]
819 """Dataset types that may be written after constructing the Tasks in this
820 Pipeline.
822 This does not include dataset types that are also used as inputs when
823 constructing other Tasks in the Pipeline (these are classified as
824 `initIntermediates`).
825 """
827 initIntermediates: NamedValueSet[DatasetType]
828 """Dataset types that are both used when constructing one or more Tasks
829 in the Pipeline and produced as a side-effect of constructing another
830 Task in the Pipeline.
831 """
833 inputs: NamedValueSet[DatasetType]
834 """Dataset types that are regular inputs for the full pipeline.
836 If an input dataset needed for a Quantum cannot be found in the input
837 collection(s), that Quantum (and all dependent Quanta) will not be
838 produced.
839 """
841 prerequisites: NamedValueSet[DatasetType]
842 """Dataset types that are prerequisite inputs for the full Pipeline.
844 Prerequisite inputs must exist in the input collection(s) before the
845 pipeline is run, but do not constrain the graph - if a prerequisite is
846 missing for a Quantum, `PrerequisiteMissingError` is raised.
848 Prerequisite inputs are not resolved until the second stage of
849 QuantumGraph generation.
850 """
852 intermediates: NamedValueSet[DatasetType]
853 """Dataset types that are output by one Task in the Pipeline and consumed
854 as inputs by one or more other Tasks in the Pipeline.
855 """
857 outputs: NamedValueSet[DatasetType]
858 """Dataset types that are output by a Task in the Pipeline and not consumed
859 by any other Task in the Pipeline.
860 """
862 byTask: Mapping[str, TaskDatasetTypes]
863 """Per-Task dataset types, keyed by label in the `Pipeline`.
865 This is guaranteed to be zip-iterable with the `Pipeline` itself (assuming
866 neither has been modified since the dataset types were extracted, of
867 course).
868 """
870 @classmethod
871 def fromPipeline(
872 cls,
873 pipeline: Union[Pipeline, Iterable[TaskDef]],
874 *,
875 registry: Registry,
876 include_configs: bool = True,
877 include_packages: bool = True,
878 ) -> PipelineDatasetTypes:
879 """Extract and classify the dataset types from all tasks in a
880 `Pipeline`.
882 Parameters
883 ----------
884 pipeline: `Pipeline` or `Iterable` [ `TaskDef` ]
885 A dependency-ordered collection of tasks that can be run
886 together.
887 registry: `Registry`
888 Registry used to construct normalized `DatasetType` objects and
889 retrieve those that are incomplete.
890 include_configs : `bool`, optional
891 If `True` (default) include config dataset types as
892 ``initOutputs``.
893 include_packages : `bool`, optional
894 If `True` (default) include the dataset type for software package
895 versions in ``initOutputs``.
897 Returns
898 -------
899 types: `PipelineDatasetTypes`
900 The dataset types used by this `Pipeline`.
902 Raises
903 ------
904 ValueError
905 Raised if Tasks are inconsistent about which datasets are marked
906 prerequisite. This indicates that the Tasks cannot be run as part
907 of the same `Pipeline`.
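Examples
--------
A hypothetical sketch, assuming ``pipeline`` is a loaded `Pipeline` and
``butler`` is an existing `~lsst.daf.butler.Butler`.
>>> dataset_types = PipelineDatasetTypes.fromPipeline(
...     pipeline, registry=butler.registry)
>>> per_task_types = dataset_types.byTask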
908 """
909 allInputs = NamedValueSet()
910 allOutputs = NamedValueSet()
911 allInitInputs = NamedValueSet()
912 allInitOutputs = NamedValueSet()
913 prerequisites = NamedValueSet()
914 byTask = dict()
915 if include_packages:
916 allInitOutputs.add(
917 DatasetType(
918 "packages",
919 registry.dimensions.empty,
920 storageClass="Packages",
921 )
922 )
923 if isinstance(pipeline, Pipeline):
924 pipeline = pipeline.toExpandedPipeline()
925 for taskDef in pipeline:
926 thisTask = TaskDatasetTypes.fromTaskDef(
927 taskDef,
928 registry=registry,
929 include_configs=include_configs,
930 )
931 allInitInputs |= thisTask.initInputs
932 allInitOutputs |= thisTask.initOutputs
933 allInputs |= thisTask.inputs
934 prerequisites |= thisTask.prerequisites
935 allOutputs |= thisTask.outputs
936 byTask[taskDef.label] = thisTask
937 if not prerequisites.isdisjoint(allInputs):
938 raise ValueError("{} marked as both prerequisites and regular inputs".format(
939 {dt.name for dt in allInputs & prerequisites}
940 ))
941 if not prerequisites.isdisjoint(allOutputs):
942 raise ValueError("{} marked as both prerequisites and outputs".format(
943 {dt.name for dt in allOutputs & prerequisites}
944 ))
945 # Make sure that components which are marked as inputs get treated as
946 # intermediates if there is an output which produces the composite
947 # containing the component
948 intermediateComponents = NamedValueSet()
949 intermediateComposites = NamedValueSet()
950 outputNameMapping = {dsType.name: dsType for dsType in allOutputs}
951 for dsType in allInputs:
952 # get the name of a possible component
953 name, component = dsType.nameAndComponent()
954 # If there is a component name, this is a component
955 # DatasetType; if there is an output which produces the parent of
956 # this component, treat this input as an intermediate
957 if component is not None:
958 if name in outputNameMapping:
959 if outputNameMapping[name].dimensions != dsType.dimensions:
960 raise ValueError(f"Component dataset type {dsType.name} has different "
961 f"dimensions ({dsType.dimensions}) than its parent "
962 f"({outputNameMapping[name].dimensions}).")
963 composite = DatasetType(name, dsType.dimensions, outputNameMapping[name].storageClass,
964 universe=registry.dimensions)
965 intermediateComponents.add(dsType)
966 intermediateComposites.add(composite)
968 def checkConsistency(a: NamedValueSet, b: NamedValueSet):
969 common = a.names & b.names
970 for name in common:
971 if a[name] != b[name]:
972 raise ValueError(f"Conflicting definitions for dataset type: {a[name]} != {b[name]}.")
974 checkConsistency(allInitInputs, allInitOutputs)
975 checkConsistency(allInputs, allOutputs)
976 checkConsistency(allInputs, intermediateComposites)
977 checkConsistency(allOutputs, intermediateComposites)
979 def frozen(s: NamedValueSet) -> NamedValueSet:
980 s.freeze()
981 return s
983 return cls(
984 initInputs=frozen(allInitInputs - allInitOutputs),
985 initIntermediates=frozen(allInitInputs & allInitOutputs),
986 initOutputs=frozen(allInitOutputs - allInitInputs),
987 inputs=frozen(allInputs - allOutputs - intermediateComponents),
988 intermediates=frozen(allInputs & allOutputs | intermediateComponents),
989 outputs=frozen(allOutputs - allInputs - intermediateComposites),
990 prerequisites=frozen(prerequisites),
991 byTask=MappingProxyType(byTask), # MappingProxyType -> frozen view of dict for immutability
992 )