Coverage for python/lsst/pipe/base/pipeline.py : 18%

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

"""Module defining Pipeline class and related methods.
"""

__all__ = ["Pipeline", "TaskDef", "TaskDatasetTypes", "PipelineDatasetTypes", "LabelSpecifier"]

# -------------------------------
# Imports of standard modules --
# -------------------------------
from dataclasses import dataclass
import logging
from types import MappingProxyType
from typing import (ClassVar, Dict, Iterable, Iterator, Mapping, Set, Union,
                    Generator, TYPE_CHECKING, Optional, Tuple)

import copy
import re
import os
import urllib.parse
import warnings

# -----------------------------
# Imports for other modules --
from lsst.daf.butler import DatasetType, NamedValueSet, Registry, SkyPixDimension, ButlerURI
from lsst.utils import doImport
from .configOverrides import ConfigOverrides
from .connections import iterConnections
from .pipelineTask import PipelineTask

from . import pipelineIR
from . import pipeTools

if TYPE_CHECKING:  # Imports needed only for type annotations; may be circular.
    from lsst.obs.base import Instrument

# ----------------------------------
# Local non-exported definitions --
# ----------------------------------

_LOG = logging.getLogger(__name__)

# ------------------------
# Exported definitions --
# ------------------------


@dataclass
class LabelSpecifier:
    """A structure to specify a subset of labels to load.

    This structure may contain a set of labels to be used in subsetting a
    pipeline, or a beginning and end point. Beginning or end may be empty,
    in which case the range will be a half-open interval. Unlike Python
    iteration bounds, end bounds are *INCLUDED*. Note that range-based
    selection is not well defined for pipelines that are not linear in
    nature; correct behavior is not guaranteed and may vary from run to run.
    """
    labels: Optional[Set[str]] = None
    begin: Optional[str] = None
    end: Optional[str] = None

    def __post_init__(self):
        if self.labels is not None and (self.begin or self.end):
            raise ValueError("This struct can only be initialized with a labels set or "
                             "a begin (and/or) end specifier")


class TaskDef:
    """TaskDef is a collection of information about a task needed by Pipeline.

    The information includes task name, configuration object and optional
    task class. This class is just a collection of attributes and it exposes
    all of them so that attributes could potentially be modified in place
    (e.g. if configuration needs extra overrides).

    Attributes
    ----------
    taskName : `str`, optional
        `PipelineTask` class name, currently it is not specified whether this
        is a fully-qualified name or partial name (e.g. ``module.TaskClass``).
        Framework should be prepared to handle all cases. If not provided,
        ``taskClass`` must be, and ``taskClass.__name__`` is used.
    config : `lsst.pex.config.Config`, optional
        Instance of the configuration class corresponding to this task class,
        usually with all overrides applied. This config will be frozen. If
        not provided, ``taskClass`` must be provided and
        ``taskClass.ConfigClass()`` will be used.
    taskClass : `type`, optional
        `PipelineTask` class object, can be ``None``. If ``None`` then
        framework will have to locate and load class.
    label : `str`, optional
        Task label, usually a short string unique in a pipeline. If not
        provided, ``taskClass`` must be, and ``taskClass._DefaultName`` will
        be used.
    """
    def __init__(self, taskName=None, config=None, taskClass=None, label=None):
        if taskName is None:
            if taskClass is None:
                raise ValueError("At least one of `taskName` and `taskClass` must be provided.")
            taskName = taskClass.__name__
        if config is None:
            if taskClass is None:
                raise ValueError("`taskClass` must be provided if `config` is not.")
            config = taskClass.ConfigClass()
        if label is None:
            if taskClass is None:
                raise ValueError("`taskClass` must be provided if `label` is not.")
            label = taskClass._DefaultName
        self.taskName = taskName
        config.freeze()
        self.config = config
        self.taskClass = taskClass
        self.label = label
        self.connections = config.connections.ConnectionsClass(config=config)

    @property
    def configDatasetName(self) -> str:
        """Name of a dataset type for configuration of this task (`str`)
        """
        return self.label + "_config"

    @property
    def metadataDatasetName(self) -> Optional[str]:
        """Name of a dataset type for metadata of this task, `None` if
        metadata is not to be saved (`str`)
        """
        if self.config.saveMetadata:
            return self.label + "_metadata"
        else:
            return None

    @property
    def logOutputDatasetName(self) -> Optional[str]:
        """Name of a dataset type for log output from this task, `None` if
        logs are not to be saved (`str`)
        """
        if self.config.saveLogOutput:
            return self.label + "_log"
        else:
            return None

    def __str__(self):
        rep = "TaskDef(" + self.taskName
        if self.label:
            rep += ", label=" + self.label
        rep += ")"
        return rep

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, TaskDef):
            return False
        # This does not consider equality of configs when determining equality
        # as config equality is a difficult thing to define. Should be updated
        # after DM-27847
        return self.taskClass == other.taskClass and self.label == other.label

    def __hash__(self):
        return hash((self.taskClass, self.label))
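
# Illustrative sketch (not part of the original module): a TaskDef can be built
# from a PipelineTask subclass alone, in which case the task name, config, and
# label are derived from the class. ``MyTask`` is a hypothetical PipelineTask
# subclass with ``ConfigClass`` and ``_DefaultName`` defined.
#
#     >>> taskDef = TaskDef(taskClass=MyTask)
#     >>> taskDef.taskName              # MyTask.__name__
#     >>> taskDef.label                 # MyTask._DefaultName
#     >>> taskDef.configDatasetName     # e.g. 'myTask_config' if _DefaultName == "myTask"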


class Pipeline:
    """A `Pipeline` is a representation of a series of tasks to run, and the
    configuration for those tasks.

    Parameters
    ----------
    description : `str`
        A description of what this pipeline does.
    """
    def __init__(self, description: str):
        pipeline_dict = {"description": description, "tasks": {}}
        self._pipelineIR = pipelineIR.PipelineIR(pipeline_dict)

    @classmethod
    def fromFile(cls, filename: str) -> Pipeline:
        """Load a pipeline defined in a pipeline yaml file.

        Parameters
        ----------
        filename: `str`
            A path that points to a pipeline defined in yaml format. This
            filename may also supply additional labels to be used in
            subsetting the loaded Pipeline. These labels are separated from
            the path by a \\#, and may be specified as a comma-separated
            list, or a range denoted as beginning..end. Beginning or end may
            be empty, in which case the range will be a half-open interval.
            Unlike Python iteration bounds, end bounds are *INCLUDED*. Note
            that range-based selection is not well defined for pipelines that
            are not linear in nature; correct behavior is not guaranteed and
            may vary from run to run.

        Returns
        -------
        pipeline: `Pipeline`
            The pipeline loaded from the specified location with appropriate
            (if any) subsetting.

        Notes
        -----
        This method attempts to prune any contracts that contain labels which
        are not in the declared subset of labels. This pruning is done using
        string-based matching due to the nature of contracts and may prune
        more than it should.
        """
        return cls.from_uri(filename)

    @classmethod
    def from_uri(cls, uri: Union[str, ButlerURI]) -> Pipeline:
        """Load a pipeline defined in a pipeline yaml file at a location
        specified by a URI.

        Parameters
        ----------
        uri: `str` or `ButlerURI`
            If a string is supplied this should be a URI path that points to a
            pipeline defined in yaml format. This URI may also supply
            additional labels to be used in subsetting the loaded Pipeline.
            These labels are separated from the path by a \\#, and may be
            specified as a comma-separated list, or a range denoted as
            beginning..end. Beginning or end may be empty, in which case the
            range will be a half-open interval. Unlike Python iteration
            bounds, end bounds are *INCLUDED*. Note that range-based selection
            is not well defined for pipelines that are not linear in nature;
            correct behavior is not guaranteed and may vary from run to run.
            The same specifiers can be used with a ButlerURI object, by being
            the sole contents in the fragments attribute.

        Returns
        -------
        pipeline: `Pipeline`
            The pipeline loaded from the specified location with appropriate
            (if any) subsetting.

        Notes
        -----
        This method attempts to prune any contracts that contain labels which
        are not in the declared subset of labels. This pruning is done using
        string-based matching due to the nature of contracts and may prune
        more than it should.
        """
        # Split up the uri and any labels that were supplied
        uri, label_specifier = cls._parse_file_specifier(uri)
        pipeline: Pipeline = cls.fromIR(pipelineIR.PipelineIR.from_uri(uri))

        # If there are labels supplied, only keep those
        if label_specifier is not None:
            pipeline = pipeline.subsetFromLabels(label_specifier)
        return pipeline
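
    # Illustrative sketch (not part of the original module): the fragment after
    # ``#`` selects a subset of task labels when loading. The path and labels
    # below are hypothetical.
    #
    #     >>> Pipeline.from_uri("/path/to/pipeline.yaml")                 # whole pipeline
    #     >>> Pipeline.from_uri("/path/to/pipeline.yaml#isr,calibrate")   # explicit labels
    #     >>> Pipeline.from_uri("/path/to/pipeline.yaml#isr..calibrate")  # inclusive range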

    def subsetFromLabels(self, labelSpecifier: LabelSpecifier) -> Pipeline:
        """Subset a pipeline to contain only labels specified in labelSpecifier.

        Parameters
        ----------
        labelSpecifier : `LabelSpecifier`
            Object containing labels that describes how to subset a pipeline.

        Returns
        -------
        pipeline : `Pipeline`
            A new pipeline object that is a subset of the old pipeline.

        Raises
        ------
        ValueError
            Raised if there is an issue with specified labels.

        Notes
        -----
        This method attempts to prune any contracts that contain labels which
        are not in the declared subset of labels. This pruning is done using
        string-based matching due to the nature of contracts and may prune
        more than it should.
        """
        # Labels supplied as a set
        if labelSpecifier.labels:
            labelSet = labelSpecifier.labels
        # Labels supplied as a range, first create a list of all the labels
        # in the pipeline sorted according to task dependency. Then only
        # keep labels that lie between the supplied bounds
        else:
            # Create a copy of the pipeline to use when assessing the label
            # ordering. Use a dict for fast searching while preserving order.
            # Remove contracts so they do not fail in the expansion step. This
            # is needed because a user may only configure the tasks they
            # intend to run, which may cause some contracts to fail if they
            # will later be dropped.
            pipeline = copy.deepcopy(self)
            pipeline._pipelineIR.contracts = []
            labels = {taskdef.label: True for taskdef in pipeline.toExpandedPipeline()}

            # Verify the bounds are in the labels
            if labelSpecifier.begin is not None:
                if labelSpecifier.begin not in labels:
                    raise ValueError(f"Beginning of range subset, {labelSpecifier.begin}, not found in "
                                     "pipeline definition")
            if labelSpecifier.end is not None:
                if labelSpecifier.end not in labels:
                    raise ValueError(f"End of range subset, {labelSpecifier.end}, not found in pipeline "
                                     "definition")

            labelSet = set()
            for label in labels:
                if labelSpecifier.begin is not None:
                    if label != labelSpecifier.begin:
                        continue
                    else:
                        labelSpecifier.begin = None
                labelSet.add(label)
                if labelSpecifier.end is not None and label == labelSpecifier.end:
                    break
        return Pipeline.fromIR(self._pipelineIR.subset_from_labels(labelSet))
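
    # Illustrative sketch (not part of the original module): subsetting an
    # already loaded pipeline. The labels are hypothetical.
    #
    #     >>> full = Pipeline.fromFile("pipeline.yaml")
    #     >>> head = full.subsetFromLabels(LabelSpecifier(end="calibrate"))
    #     >>> tail = full.subsetFromLabels(LabelSpecifier(begin="calibrate"))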

    @staticmethod
    def _parse_file_specifier(uri: Union[str, ButlerURI]
                              ) -> Tuple[ButlerURI, Optional[LabelSpecifier]]:
        """Split apart a URI and any possible label subsets
        """
        if isinstance(uri, str):
            # This is to support legacy pipelines during transition
            uri, num_replace = re.subn("[:](?!\\/\\/)", "#", uri)
            if num_replace:
                warnings.warn(f"The pipeline file {uri} seems to use the legacy : to separate "
                              "labels, this is deprecated and will be removed after June 2021, "
                              "please use # instead.",
                              category=FutureWarning)
            if uri.count("#") > 1:
                raise ValueError("Only one set of labels is allowed when specifying a pipeline to load")
            uri = ButlerURI(uri)
        label_subset = uri.fragment or None

        specifier: Optional[LabelSpecifier]
        if label_subset is not None:
            label_subset = urllib.parse.unquote(label_subset)
            args: Dict[str, Union[Set[str], str, None]]
            # labels supplied as a list
            if ',' in label_subset:
                if '..' in label_subset:
                    raise ValueError("Can only specify a list of labels or a range "
                                     "when loading a Pipeline, not both")
                args = {"labels": set(label_subset.split(","))}
            # labels supplied as a range
            elif '..' in label_subset:
                # Try to de-structure the labelSubset, this will fail if more
                # than one range is specified
                begin, end, *rest = label_subset.split("..")
                if rest:
                    raise ValueError("Only one range can be specified when loading a pipeline")
                args = {"begin": begin if begin else None, "end": end if end else None}
            # Assume anything else is a single label
            else:
                args = {"labels": {label_subset}}

            specifier = LabelSpecifier(**args)
        else:
            specifier = None

        return uri, specifier
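
    # Illustrative sketch (not part of the original module) of how the fragment
    # is interpreted by _parse_file_specifier; paths and labels are hypothetical.
    #
    #     "pipeline.yaml"                 -> (uri, None)
    #     "pipeline.yaml#isr,calibrate"   -> LabelSpecifier(labels={"isr", "calibrate"})
    #     "pipeline.yaml#isr..calibrate"  -> LabelSpecifier(begin="isr", end="calibrate")
    #     "pipeline.yaml#..calibrate"     -> LabelSpecifier(begin=None, end="calibrate")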

    @classmethod
    def fromString(cls, pipeline_string: str) -> Pipeline:
        """Create a pipeline from a string formatted as a pipeline document.

        Parameters
        ----------
        pipeline_string : `str`
            A string formatted like a pipeline document.

        Returns
        -------
        pipeline: `Pipeline`
        """
        pipeline = cls.fromIR(pipelineIR.PipelineIR.from_string(pipeline_string))
        return pipeline
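
    # Illustrative sketch (not part of the original module): a minimal pipeline
    # document supplied as a string. The task label and class path are
    # hypothetical, and the exact document schema is assumed here.
    #
    #     >>> p = Pipeline.fromString(
    #     ...     "description: A tiny example\n"
    #     ...     "tasks:\n"
    #     ...     "  isr:\n"
    #     ...     "    class: lsst.ip.isr.IsrTask\n"
    #     ... )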

    @classmethod
    def fromIR(cls, deserialized_pipeline: pipelineIR.PipelineIR) -> Pipeline:
        """Create a pipeline from an already created `PipelineIR` object.

        Parameters
        ----------
        deserialized_pipeline: `PipelineIR`
            An already created pipeline intermediate representation object.

        Returns
        -------
        pipeline: `Pipeline`
        """
        pipeline = cls.__new__(cls)
        pipeline._pipelineIR = deserialized_pipeline
        return pipeline

    @classmethod
    def fromPipeline(cls, pipeline: Pipeline) -> Pipeline:
        """Create a new pipeline by copying an already existing `Pipeline`.

        Parameters
        ----------
        pipeline: `Pipeline`
            An already created `Pipeline` object to copy.

        Returns
        -------
        pipeline: `Pipeline`
        """
        return cls.fromIR(copy.deepcopy(pipeline._pipelineIR))

    def __str__(self) -> str:
        return str(self._pipelineIR)

    def addInstrument(self, instrument: Union[Instrument, str]) -> None:
        """Add an instrument to the pipeline, or replace an instrument that is
        already defined.

        Parameters
        ----------
        instrument : `~lsst.obs.base.Instrument` or `str`
            Either a derived class object of `lsst.obs.base.Instrument` or
            a string corresponding to a fully qualified
            `lsst.obs.base.Instrument` name.
        """
        if isinstance(instrument, str):
            pass
        else:
            # TODO: assume that this is a subclass of Instrument, no type
            # checking
            instrument = f"{instrument.__module__}.{instrument.__qualname__}"
        self._pipelineIR.instrument = instrument

    def getInstrument(self) -> Instrument:
        """Get the instrument from the pipeline.

        Returns
        -------
        instrument : `~lsst.obs.base.Instrument`, `str`, or `None`
            A derived class object of `lsst.obs.base.Instrument`, a string
            corresponding to a fully qualified `lsst.obs.base.Instrument`
            name, or `None` if the pipeline does not have an instrument.
        """
        return self._pipelineIR.instrument

    def addTask(self, task: Union[PipelineTask, str], label: str) -> None:
        """Add a new task to the pipeline, or replace a task that is already
        associated with the supplied label.

        Parameters
        ----------
        task: `PipelineTask` or `str`
            Either a derived class object of a `PipelineTask` or a string
            corresponding to a fully qualified `PipelineTask` name.
        label: `str`
            A label that is used to identify the `PipelineTask` being added.
        """
        if isinstance(task, str):
            taskName = task
        elif issubclass(task, PipelineTask):
            taskName = f"{task.__module__}.{task.__qualname__}"
        else:
            raise ValueError("task must be either a child class of PipelineTask or a string containing"
                             " a fully qualified name to one")
        if not label:
            # in some cases (with command line-generated pipeline) tasks can
            # be defined without a label which is not acceptable, use task
            # _DefaultName in that case
            if isinstance(task, str):
                task = doImport(task)
            label = task._DefaultName
        self._pipelineIR.tasks[label] = pipelineIR.TaskIR(label, taskName)
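
    # Illustrative sketch (not part of the original module): building a pipeline
    # programmatically. The instrument name, task class path, and label are
    # hypothetical.
    #
    #     >>> p = Pipeline("An example pipeline")
    #     >>> p.addInstrument("lsst.obs.subaru.HyperSuprimeCam")
    #     >>> p.addTask("lsst.ip.isr.IsrTask", "isr")
    #     >>> p.addTask(IsrTask, "isr")   # equivalent, given the class object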

    def removeTask(self, label: str) -> None:
        """Remove a task from the pipeline.

        Parameters
        ----------
        label : `str`
            The label used to identify the task that is to be removed.

        Raises
        ------
        KeyError
            If no task with that label exists in the pipeline.

        """
        self._pipelineIR.tasks.pop(label)

    def addConfigOverride(self, label: str, key: str, value: object) -> None:
        """Apply single config override.

        Parameters
        ----------
        label : `str`
            Label of the task.
        key: `str`
            Fully-qualified field name.
        value : object
            Value to be given to a field.
        """
        self._addConfigImpl(label, pipelineIR.ConfigIR(rest={key: value}))

    def addConfigFile(self, label: str, filename: str) -> None:
        """Add overrides from a specified file.

        Parameters
        ----------
        label : `str`
            The label used to identify the task associated with config to
            modify.
        filename : `str`
            Path to the override file.
        """
        self._addConfigImpl(label, pipelineIR.ConfigIR(file=[filename]))

    def addConfigPython(self, label: str, pythonString: str) -> None:
        """Add overrides by running a snippet of python code against a config.

        Parameters
        ----------
        label : `str`
            The label used to identify the task associated with config to
            modify.
        pythonString: `str`
            A string which is valid python code to be executed. This is done
            with config as the only local accessible value.
        """
        self._addConfigImpl(label, pipelineIR.ConfigIR(python=pythonString))

    def _addConfigImpl(self, label: str, newConfig: pipelineIR.ConfigIR) -> None:
        if label == "parameters":
            if newConfig.rest.keys() - self._pipelineIR.parameters.mapping.keys():
                raise ValueError("Cannot override parameters that are not defined in pipeline")
            self._pipelineIR.parameters.mapping.update(newConfig.rest)
            if newConfig.file:
                raise ValueError("Setting parameters section with config file is not supported")
            if newConfig.python:
                raise ValueError("Setting parameters section using python block is unsupported")
            return
        if label not in self._pipelineIR.tasks:
            raise LookupError(f"There are no tasks labeled '{label}' in the pipeline")
        self._pipelineIR.tasks[label].add_or_update_config(newConfig)
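
    # Illustrative sketch (not part of the original module): the three override
    # flavours plus a parameters override. Labels, config fields, and file
    # paths are hypothetical.
    #
    #     >>> p.addConfigOverride("isr", "doDark", False)
    #     >>> p.addConfigFile("isr", "/path/to/isrOverrides.py")
    #     >>> p.addConfigPython("isr", "config.doFlat = False")
    #     >>> p.addConfigOverride("parameters", "someParameter", 42)  # must already be defined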

    def toFile(self, filename: str) -> None:
        self._pipelineIR.to_file(filename)

    def write_to_uri(self, uri: Union[str, ButlerURI]) -> None:
        self._pipelineIR.write_to_uri(uri)

    def toExpandedPipeline(self) -> Generator[TaskDef, None, None]:
        """Returns a generator of TaskDefs which can be used to create quantum
        graphs.

        Returns
        -------
        generator : generator of `TaskDef`
            The generator returned will be the sorted iterator of tasks which
            are to be used in constructing a quantum graph.

        Raises
        ------
        NotImplementedError
            If a dataId is supplied in a config block. This is in place for
            future use.
        """
        taskDefs = []
        for label, taskIR in self._pipelineIR.tasks.items():
            taskClass = doImport(taskIR.klass)
            taskName = taskClass.__qualname__
            config = taskClass.ConfigClass()
            overrides = ConfigOverrides()
            if self._pipelineIR.instrument is not None:
                overrides.addInstrumentOverride(self._pipelineIR.instrument, taskClass._DefaultName)
            if taskIR.config is not None:
                for configIR in (configIr.formatted(self._pipelineIR.parameters)
                                 for configIr in taskIR.config):
                    if configIR.dataId is not None:
                        raise NotImplementedError("Specializing a config on a partial data id is not yet "
                                                  "supported in Pipeline definition")
                    # only apply override if it applies to everything
                    if configIR.dataId is None:
                        if configIR.file:
                            for configFile in configIR.file:
                                overrides.addFileOverride(os.path.expandvars(configFile))
                        if configIR.python is not None:
                            overrides.addPythonOverride(configIR.python)
                        for key, value in configIR.rest.items():
                            overrides.addValueOverride(key, value)
            overrides.applyTo(config)
            # This may need to be revisited
            try:
                config.validate()
            except Exception:
                _LOG.error("Configuration validation failed for task %s (%s)", label, taskName)
                raise
            taskDefs.append(TaskDef(taskName=taskName, config=config, taskClass=taskClass, label=label))

        # Let's evaluate the contracts
        if self._pipelineIR.contracts is not None:
            label_to_config = {x.label: x.config for x in taskDefs}
            for contract in self._pipelineIR.contracts:
                # execute this in its own line so it can raise a good error
                # message if there were problems with the eval
                success = eval(contract.contract, None, label_to_config)
                if not success:
                    extra_info = f": {contract.msg}" if contract.msg is not None else ""
                    raise pipelineIR.ContractError(f"Contract(s) '{contract.contract}' were not "
                                                   f"satisfied{extra_info}")

        yield from pipeTools.orderPipeline(taskDefs)
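
    # Illustrative sketch (not part of the original module): expanding a pipeline
    # into ordered TaskDefs, e.g. to inspect each task's final configuration.
    #
    #     >>> for taskDef in p.toExpandedPipeline():
    #     ...     print(taskDef.label, taskDef.taskName)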

    def __len__(self):
        return len(self._pipelineIR.tasks)

    def __eq__(self, other: object):
        if not isinstance(other, Pipeline):
            return False
        return self._pipelineIR == other._pipelineIR


@dataclass(frozen=True)
class TaskDatasetTypes:
    """An immutable struct that extracts and classifies the dataset types used
    by a `PipelineTask`
    """

    initInputs: NamedValueSet[DatasetType]
    """Dataset types that are needed as inputs in order to construct this Task.

    Task-level `initInputs` may be classified as either
    `~PipelineDatasetTypes.initInputs` or
    `~PipelineDatasetTypes.initIntermediates` at the Pipeline level.
    """

    initOutputs: NamedValueSet[DatasetType]
    """Dataset types that may be written after constructing this Task.

    Task-level `initOutputs` may be classified as either
    `~PipelineDatasetTypes.initOutputs` or
    `~PipelineDatasetTypes.initIntermediates` at the Pipeline level.
    """

    inputs: NamedValueSet[DatasetType]
    """Dataset types that are regular inputs to this Task.

    If an input dataset needed for a Quantum cannot be found in the input
    collection(s) or produced by another Task in the Pipeline, that Quantum
    (and all dependent Quanta) will not be produced.

    Task-level `inputs` may be classified as either
    `~PipelineDatasetTypes.inputs` or `~PipelineDatasetTypes.intermediates`
    at the Pipeline level.
    """

    prerequisites: NamedValueSet[DatasetType]
    """Dataset types that are prerequisite inputs to this Task.

    Prerequisite inputs must exist in the input collection(s) before the
    pipeline is run, but do not constrain the graph - if a prerequisite is
    missing for a Quantum, `PrerequisiteMissingError` is raised.

    Prerequisite inputs are not resolved until the second stage of
    QuantumGraph generation.
    """

    outputs: NamedValueSet[DatasetType]
    """Dataset types that are produced by this Task.

    Task-level `outputs` may be classified as either
    `~PipelineDatasetTypes.outputs` or `~PipelineDatasetTypes.intermediates`
    at the Pipeline level.
    """

    @classmethod
    def fromTaskDef(
        cls,
        taskDef: TaskDef,
        *,
        registry: Registry,
        include_configs: bool = True,
    ) -> TaskDatasetTypes:
        """Extract and classify the dataset types from a single `PipelineTask`.

        Parameters
        ----------
        taskDef: `TaskDef`
            An instance of a `TaskDef` class for a particular `PipelineTask`.
        registry: `Registry`
            Registry used to construct normalized `DatasetType` objects and
            retrieve those that are incomplete.
        include_configs : `bool`, optional
            If `True` (default) include config dataset types as
            ``initOutputs``.

        Returns
        -------
        types: `TaskDatasetTypes`
            The dataset types used by this task.
        """
        def makeDatasetTypesSet(connectionType: str, freeze: bool = True) -> NamedValueSet[DatasetType]:
            """Constructs a set of true `DatasetType` objects

            Parameters
            ----------
            connectionType : `str`
                Name of the connection type to produce a set for, corresponds
                to an attribute of type `list` on the connection class instance
            freeze : `bool`, optional
                If `True`, call `NamedValueSet.freeze` on the object returned.

            Returns
            -------
            datasetTypes : `NamedValueSet`
                A set of all datasetTypes which correspond to the input
                connection type specified in the connection class of this
                `PipelineTask`

            Notes
            -----
            This function is a closure over the variables ``registry`` and
            ``taskDef``.
            """
            datasetTypes = NamedValueSet()
            for c in iterConnections(taskDef.connections, connectionType):
                dimensions = set(getattr(c, 'dimensions', set()))
                if "skypix" in dimensions:
                    try:
                        datasetType = registry.getDatasetType(c.name)
                    except LookupError as err:
                        raise LookupError(
                            f"DatasetType '{c.name}' referenced by "
                            f"{type(taskDef.connections).__name__} uses 'skypix' as a dimension "
                            f"placeholder, but does not already exist in the registry. "
                            f"Note that reference catalog names are now used as the dataset "
                            f"type name instead of 'ref_cat'."
                        ) from err
                    rest1 = set(registry.dimensions.extract(dimensions - set(["skypix"])).names)
                    rest2 = set(dim.name for dim in datasetType.dimensions
                                if not isinstance(dim, SkyPixDimension))
                    if rest1 != rest2:
                        raise ValueError(f"Non-skypix dimensions for dataset type {c.name} declared in "
                                         f"connections ({rest1}) are inconsistent with those in "
                                         f"registry's version of this dataset ({rest2}).")
                else:
                    # Component dataset types are not explicitly in the
                    # registry. This complicates consistency checks with
                    # registry and requires we work out the composite storage
                    # class.
                    registryDatasetType = None
                    try:
                        registryDatasetType = registry.getDatasetType(c.name)
                    except KeyError:
                        compositeName, componentName = DatasetType.splitDatasetTypeName(c.name)
                        parentStorageClass = DatasetType.PlaceholderParentStorageClass \
                            if componentName else None
                        datasetType = c.makeDatasetType(
                            registry.dimensions,
                            parentStorageClass=parentStorageClass
                        )
                        registryDatasetType = datasetType
                    else:
                        datasetType = c.makeDatasetType(
                            registry.dimensions,
                            parentStorageClass=registryDatasetType.parentStorageClass
                        )

                    if registryDatasetType and datasetType != registryDatasetType:
                        try:
                            # Explicitly check for storage class just to make
                            # more specific message.
                            _ = datasetType.storageClass
                        except KeyError:
                            raise ValueError("Storage class does not exist for supplied dataset type "
                                             f"{datasetType} for {taskDef.label}.") from None
                        raise ValueError(f"Supplied dataset type ({datasetType}) inconsistent with "
                                         f"registry definition ({registryDatasetType}) "
                                         f"for {taskDef.label}.")
                datasetTypes.add(datasetType)
            if freeze:
                datasetTypes.freeze()
            return datasetTypes

        # optionally add initOutput dataset for config
        initOutputs = makeDatasetTypesSet("initOutputs", freeze=False)
        if include_configs:
            initOutputs.add(
                DatasetType(
                    taskDef.configDatasetName,
                    registry.dimensions.empty,
                    storageClass="Config",
                )
            )
        initOutputs.freeze()

        # optionally add output dataset for metadata
        outputs = makeDatasetTypesSet("outputs", freeze=False)
        if taskDef.metadataDatasetName is not None:
            # Metadata is supposed to be of the PropertySet type, its
            # dimensions correspond to a task quantum
            dimensions = registry.dimensions.extract(taskDef.connections.dimensions)
            outputs |= {DatasetType(taskDef.metadataDatasetName, dimensions, "PropertySet")}
        if taskDef.logOutputDatasetName is not None:
            # Log output dimensions correspond to a task quantum.
            dimensions = registry.dimensions.extract(taskDef.connections.dimensions)
            outputs |= {DatasetType(taskDef.logOutputDatasetName, dimensions, "ButlerLogRecords")}

        outputs.freeze()

        return cls(
            initInputs=makeDatasetTypesSet("initInputs"),
            initOutputs=initOutputs,
            inputs=makeDatasetTypesSet("inputs"),
            prerequisites=makeDatasetTypesSet("prerequisiteInputs"),
            outputs=outputs,
        )
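
# Illustrative sketch (not part of the original module): classifying one task's
# dataset types. ``butler`` is a hypothetical, already-initialized Butler whose
# registry is used to normalize the DatasetType objects.
#
#     >>> types = TaskDatasetTypes.fromTaskDef(taskDef, registry=butler.registry)
#     >>> sorted(dt.name for dt in types.outputs)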


@dataclass(frozen=True)
class PipelineDatasetTypes:
    """An immutable struct that classifies the dataset types used in a
    `Pipeline`.
    """

    packagesDatasetName: ClassVar[str] = "packages"
    """Name of a dataset type used to save package versions.
    """

    initInputs: NamedValueSet[DatasetType]
    """Dataset types that are needed as inputs in order to construct the Tasks
    in this Pipeline.

    This does not include dataset types that are produced when constructing
    other Tasks in the Pipeline (these are classified as `initIntermediates`).
    """

    initOutputs: NamedValueSet[DatasetType]
    """Dataset types that may be written after constructing the Tasks in this
    Pipeline.

    This does not include dataset types that are also used as inputs when
    constructing other Tasks in the Pipeline (these are classified as
    `initIntermediates`).
    """

    initIntermediates: NamedValueSet[DatasetType]
    """Dataset types that are both used when constructing one or more Tasks
    in the Pipeline and produced as a side-effect of constructing another
    Task in the Pipeline.
    """

    inputs: NamedValueSet[DatasetType]
    """Dataset types that are regular inputs for the full pipeline.

    If an input dataset needed for a Quantum cannot be found in the input
    collection(s), that Quantum (and all dependent Quanta) will not be
    produced.
    """

    prerequisites: NamedValueSet[DatasetType]
    """Dataset types that are prerequisite inputs for the full Pipeline.

    Prerequisite inputs must exist in the input collection(s) before the
    pipeline is run, but do not constrain the graph - if a prerequisite is
    missing for a Quantum, `PrerequisiteMissingError` is raised.

    Prerequisite inputs are not resolved until the second stage of
    QuantumGraph generation.
    """

    intermediates: NamedValueSet[DatasetType]
    """Dataset types that are output by one Task in the Pipeline and consumed
    as inputs by one or more other Tasks in the Pipeline.
    """

    outputs: NamedValueSet[DatasetType]
    """Dataset types that are output by a Task in the Pipeline and not consumed
    by any other Task in the Pipeline.
    """

    byTask: Mapping[str, TaskDatasetTypes]
    """Per-Task dataset types, keyed by label in the `Pipeline`.

    This is guaranteed to be zip-iterable with the `Pipeline` itself (assuming
    neither has been modified since the dataset types were extracted, of
    course).
    """

    @classmethod
    def fromPipeline(
        cls,
        pipeline: Union[Pipeline, Iterable[TaskDef]],
        *,
        registry: Registry,
        include_configs: bool = True,
        include_packages: bool = True,
    ) -> PipelineDatasetTypes:
        """Extract and classify the dataset types from all tasks in a
        `Pipeline`.

        Parameters
        ----------
        pipeline: `Pipeline` or `Iterable` [ `TaskDef` ]
            A collection of tasks that can be run together.
        registry: `Registry`
            Registry used to construct normalized `DatasetType` objects and
            retrieve those that are incomplete.
        include_configs : `bool`, optional
            If `True` (default) include config dataset types as
            ``initOutputs``.
        include_packages : `bool`, optional
            If `True` (default) include the dataset type for software package
            versions in ``initOutputs``.

        Returns
        -------
        types: `PipelineDatasetTypes`
            The dataset types used by this `Pipeline`.

        Raises
        ------
        ValueError
            Raised if Tasks are inconsistent about which datasets are marked
            prerequisite. This indicates that the Tasks cannot be run as part
            of the same `Pipeline`.
        """
        allInputs = NamedValueSet()
        allOutputs = NamedValueSet()
        allInitInputs = NamedValueSet()
        allInitOutputs = NamedValueSet()
        prerequisites = NamedValueSet()
        byTask = dict()
        if include_packages:
            allInitOutputs.add(
                DatasetType(
                    cls.packagesDatasetName,
                    registry.dimensions.empty,
                    storageClass="Packages",
                )
            )
        if isinstance(pipeline, Pipeline):
            pipeline = pipeline.toExpandedPipeline()
        for taskDef in pipeline:
            thisTask = TaskDatasetTypes.fromTaskDef(
                taskDef,
                registry=registry,
                include_configs=include_configs,
            )
            allInitInputs |= thisTask.initInputs
            allInitOutputs |= thisTask.initOutputs
            allInputs |= thisTask.inputs
            prerequisites |= thisTask.prerequisites
            allOutputs |= thisTask.outputs
            byTask[taskDef.label] = thisTask
        if not prerequisites.isdisjoint(allInputs):
            raise ValueError("{} marked as both prerequisites and regular inputs".format(
                {dt.name for dt in allInputs & prerequisites}
            ))
        if not prerequisites.isdisjoint(allOutputs):
            raise ValueError("{} marked as both prerequisites and outputs".format(
                {dt.name for dt in allOutputs & prerequisites}
            ))
        # Make sure that components which are marked as inputs get treated as
        # intermediates if there is an output which produces the composite
        # containing the component
        intermediateComponents = NamedValueSet()
        intermediateComposites = NamedValueSet()
        outputNameMapping = {dsType.name: dsType for dsType in allOutputs}
        for dsType in allInputs:
            # get the name of a possible component
            name, component = dsType.nameAndComponent()
            # if there is a component name, that means this is a component
            # DatasetType, if there is an output which produces the parent of
            # this component, treat this input as an intermediate
            if component is not None:
                if name in outputNameMapping:
                    if outputNameMapping[name].dimensions != dsType.dimensions:
                        raise ValueError(f"Component dataset type {dsType.name} has different "
                                         f"dimensions ({dsType.dimensions}) than its parent "
                                         f"({outputNameMapping[name].dimensions}).")
                    composite = DatasetType(name, dsType.dimensions, outputNameMapping[name].storageClass,
                                            universe=registry.dimensions)
                    intermediateComponents.add(dsType)
                    intermediateComposites.add(composite)

        def checkConsistency(a: NamedValueSet, b: NamedValueSet):
            common = a.names & b.names
            for name in common:
                if a[name] != b[name]:
                    raise ValueError(f"Conflicting definitions for dataset type: {a[name]} != {b[name]}.")

        checkConsistency(allInitInputs, allInitOutputs)
        checkConsistency(allInputs, allOutputs)
        checkConsistency(allInputs, intermediateComposites)
        checkConsistency(allOutputs, intermediateComposites)

        def frozen(s: NamedValueSet) -> NamedValueSet:
            s.freeze()
            return s

        return cls(
            initInputs=frozen(allInitInputs - allInitOutputs),
            initIntermediates=frozen(allInitInputs & allInitOutputs),
            initOutputs=frozen(allInitOutputs - allInitInputs),
            inputs=frozen(allInputs - allOutputs - intermediateComponents),
            intermediates=frozen(allInputs & allOutputs | intermediateComponents),
            outputs=frozen(allOutputs - allInputs - intermediateComposites),
            prerequisites=frozen(prerequisites),
            byTask=MappingProxyType(byTask),  # MappingProxyType -> frozen view of dict for immutability
        )
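
    # Illustrative sketch (not part of the original module): classifying every
    # dataset type in a pipeline. ``butler`` is a hypothetical Butler instance.
    #
    #     >>> datasetTypes = PipelineDatasetTypes.fromPipeline(p, registry=butler.registry)
    #     >>> sorted(dt.name for dt in datasetTypes.intermediates)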

    @classmethod
    def initOutputNames(cls, pipeline: Union[Pipeline, Iterable[TaskDef]], *,
                        include_configs: bool = True, include_packages: bool = True) -> Iterator[str]:
        """Return the names of dataset types of task initOutputs, Configs,
        and package versions for a pipeline.

        Parameters
        ----------
        pipeline: `Pipeline` or `Iterable` [ `TaskDef` ]
            A `Pipeline` instance or collection of `TaskDef` instances.
        include_configs : `bool`, optional
            If `True` (default) include config dataset types.
        include_packages : `bool`, optional
            If `True` (default) include the dataset type for package versions.

        Yields
        ------
        datasetTypeName : `str`
            Name of the dataset type.
        """
        if include_packages:
            # Package versions dataset type
            yield cls.packagesDatasetName

        if isinstance(pipeline, Pipeline):
            pipeline = pipeline.toExpandedPipeline()

        for taskDef in pipeline:
            # all task InitOutputs
            for name in taskDef.connections.initOutputs:
                attribute = getattr(taskDef.connections, name)
                yield attribute.name

            # config dataset name
            if include_configs:
                yield taskDef.configDatasetName
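
    # Illustrative sketch (not part of the original module): listing init-output
    # dataset type names without touching a registry. The task labels implied by
    # the output are hypothetical.
    #
    #     >>> list(PipelineDatasetTypes.initOutputNames(p))
    #     ['packages', 'isr_config', ...]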