Coverage for python/lsst/pipe/base/pipeline.py: 20%
421 statements
coverage.py v6.4.1, created at 2022-06-28 02:09 -0700
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
23"""Module defining Pipeline class and related methods.
24"""
26__all__ = ["Pipeline", "TaskDef", "TaskDatasetTypes", "PipelineDatasetTypes", "LabelSpecifier"]
28import copy
29import logging
30import os
31import re
32import urllib.parse
33import warnings
35# -------------------------------
36# Imports of standard modules --
37# -------------------------------
38from dataclasses import dataclass
39from types import MappingProxyType
40from typing import (
41 TYPE_CHECKING,
42 AbstractSet,
43 Callable,
44 ClassVar,
45 Dict,
46 Generator,
47 Iterable,
48 Iterator,
49 Mapping,
50 Optional,
51 Set,
52 Tuple,
53 Type,
54 Union,
55)
57# -----------------------------
58# Imports for other modules --
59from lsst.daf.butler import DatasetType, NamedValueSet, Registry, SkyPixDimension
60from lsst.resources import ResourcePath, ResourcePathExpression
61from lsst.utils import doImportType
62from lsst.utils.introspection import get_full_type_name
64from . import pipelineIR, pipeTools
65from ._task_metadata import TaskMetadata
66from .configOverrides import ConfigOverrides
67from .connections import iterConnections
68from .pipelineTask import PipelineTask
69from .task import _TASK_METADATA_TYPE
71if TYPE_CHECKING:  # Imports needed only for type annotations; may be circular.
72 from lsst.obs.base import Instrument
73 from lsst.pex.config import Config
75# ----------------------------------
76# Local non-exported definitions --
77# ----------------------------------
79_LOG = logging.getLogger(__name__)
81# ------------------------
82# Exported definitions --
83# ------------------------
86@dataclass
87class LabelSpecifier:
88 """A structure to specify a subset of labels to load
90 This structure may contain a set of labels to be used in subsetting a
91 pipeline, or a beginning and end point. Beginning or end may be empty,
92 in which case the range will be a half open interval. Unlike python
93 iteration bounds, end bounds are *INCLUDED*. Note that range based
94 selection is not well defined for pipelines that are not linear in nature,
95 and correct behavior is not guaranteed, or may vary from run to run.
96 """
98 labels: Optional[Set[str]] = None
99 begin: Optional[str] = None
100 end: Optional[str] = None
102 def __post_init__(self) -> None:
103 if self.labels is not None and (self.begin or self.end):
104 raise ValueError(
105 "This struct can only be initialized with a labels set or a begin (and/or) end specifier"
106 )
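As a quick illustration of the struct above, here is a minimal sketch of constructing `LabelSpecifier` instances; the task labels used ("isr", "calibrate") are hypothetical.

from lsst.pipe.base.pipeline import LabelSpecifier

# An explicit set of labels to keep.
by_set = LabelSpecifier(labels={"isr", "calibrate"})

# An inclusive range of labels; either bound may be omitted to get a
# half-open interval.
by_range = LabelSpecifier(begin="isr", end="calibrate")

# Mixing both forms is rejected in __post_init__.
try:
    LabelSpecifier(labels={"isr"}, end="calibrate")
except ValueError:
    pass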
109class TaskDef:
110 """TaskDef is a collection of information about task needed by Pipeline.
112 The information includes task name, configuration object and optional
113 task class. This class is just a collection of attributes and it exposes
114 all of them so that attributes could potentially be modified in place
115 (e.g. if configuration needs extra overrides).
117 Attributes
118 ----------
119 taskName : `str`, optional
120 The fully-qualified `PipelineTask` class name. If not provided,
121 ``taskClass`` must be.
122 config : `lsst.pex.config.Config`, optional
123 Instance of the configuration class corresponding to this task class,
124 usually with all overrides applied. This config will be frozen. If
125 not provided, ``taskClass`` must be provided and
126 ``taskClass.ConfigClass()`` will be used.
127 taskClass : `type`, optional
128 `PipelineTask` class object; if provided and ``taskName`` is as well,
129 the caller guarantees that they are consistent. If not provided,
130 ``taskName`` is used to import the type.
131 label : `str`, optional
132 Task label, usually a short string unique in a pipeline. If not
133 provided, ``taskClass`` must be, and ``taskClass._DefaultName`` will
134 be used.
135 """
137 def __init__(
138 self,
139 taskName: Optional[str] = None,
140 config: Optional[Config] = None,
141 taskClass: Optional[Type[PipelineTask]] = None,
142 label: Optional[str] = None,
143 ):
144 if taskName is None:
145 if taskClass is None:
146 raise ValueError("At least one of `taskName` and `taskClass` must be provided.")
147 taskName = get_full_type_name(taskClass)
148 elif taskClass is None:
149 taskClass = doImportType(taskName)
150 if config is None:
151 if taskClass is None:
152 raise ValueError("`taskClass` must be provided if `config` is not.")
153 config = taskClass.ConfigClass()
154 if label is None:
155 if taskClass is None:
156 raise ValueError("`taskClass` must be provided if `label` is not.")
157 label = taskClass._DefaultName
158 self.taskName = taskName
159 try:
160 config.validate()
161 except Exception:
162 _LOG.error("Configuration validation failed for task %s (%s)", label, taskName)
163 raise
164 config.freeze()
165 self.config = config
166 self.taskClass = taskClass
167 self.label = label
168 self.connections = config.connections.ConnectionsClass(config=config)
170 @property
171 def configDatasetName(self) -> str:
172 """Name of a dataset type for configuration of this task (`str`)"""
173 return self.label + "_config"
175 @property
176 def metadataDatasetName(self) -> Optional[str]:
177 """Name of a dataset type for metadata of this task, `None` if
178 metadata is not to be saved (`str`)
179 """
180 if self.config.saveMetadata:
181 return self.makeMetadataDatasetName(self.label)
182 else:
183 return None
185 @classmethod
186 def makeMetadataDatasetName(cls, label: str) -> str:
187 """Construct the name of the dataset type for metadata for a task.
189 Parameters
190 ----------
191 label : `str`
192 Label for the task within its pipeline.
194 Returns
195 -------
196 name : `str`
197 Name of the task's metadata dataset type.
198 """
199 return f"{label}_metadata"
201 @property
202 def logOutputDatasetName(self) -> Optional[str]:
203 """Name of a dataset type for log output from this task, `None` if
204 logs are not to be saved (`str`)
205 """
206 if self.config.saveLogOutput:
207 return self.label + "_log"
208 else:
209 return None
211 def __str__(self) -> str:
212 rep = "TaskDef(" + self.taskName
213 if self.label:
214 rep += ", label=" + self.label
215 rep += ")"
216 return rep
218 def __eq__(self, other: object) -> bool:
219 if not isinstance(other, TaskDef):
220 return False
221 # This does not consider equality of configs when determining equality
222 # as config equality is a difficult thing to define. Should be updated
223 # after DM-27847
224 return self.taskClass == other.taskClass and self.label == other.label
226 def __hash__(self) -> int:
227 return hash((self.taskClass, self.label))
229 @classmethod
230 def _unreduce(cls, taskName: str, config: Config, label: str) -> TaskDef:
231 """Custom callable for unpickling.
233 All arguments are forwarded directly to the constructor; this
234 trampoline is only needed because ``__reduce__`` callables can't be
235 called with keyword arguments.
236 """
237 return cls(taskName=taskName, config=config, label=label)
239 def __reduce__(self) -> Tuple[Callable[[str, Config, str], TaskDef], Tuple[str, Config, str]]:
240 return (self._unreduce, (self.taskName, self.config, self.label))
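A minimal sketch of how `TaskDef` is typically used; ``MyTask`` is a hypothetical `PipelineTask` subclass standing in for a real one, and the label "isr" is likewise illustrative.

from lsst.pipe.base.pipeline import TaskDef

# Dataset type names for a label are purely string-derived.
assert TaskDef.makeMetadataDatasetName("isr") == "isr_metadata"

# taskName, config, and label can all be derived from the task class;
# the config is validated and frozen during construction.
task_def = TaskDef(taskClass=MyTask, label="myTask")  # MyTask is hypothetical
assert task_def.configDatasetName == "myTask_config"

# Equality and hashing consider only (taskClass, label); configs are
# deliberately not compared (see DM-27847 above).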
243class Pipeline:
244 """A `Pipeline` is a representation of a series of tasks to run, and the
245 configuration for those tasks.
247 Parameters
248 ----------
249 description : `str`
250 A description of what this pipeline does.
251 """
253 def __init__(self, description: str):
254 pipeline_dict = {"description": description, "tasks": {}}
255 self._pipelineIR = pipelineIR.PipelineIR(pipeline_dict)
257 @classmethod
258 def fromFile(cls, filename: str) -> Pipeline:
259 """Load a pipeline defined in a pipeline yaml file.
261 Parameters
262 ----------
263 filename: `str`
264 A path that points to a pipeline defined in yaml format. This
265 filename may also supply additional labels to be used in
266 subsetting the loaded Pipeline. These labels are separated from
267 the path by a \\#, and may be specified as a comma separated
268 list, or a range denoted as beginning..end. Beginning or end may
269 be empty, in which case the range will be a half open interval.
270 Unlike python iteration bounds, end bounds are *INCLUDED*. Note
271 that range based selection is not well defined for pipelines that
272 are not linear in nature, and correct behavior is not guaranteed,
273 or may vary from run to run.
275 Returns
276 -------
277 pipeline: `Pipeline`
278 The pipeline loaded from specified location with appropriate (if
279 any) subsetting
281 Notes
282 -----
283 This method attempts to prune any contracts that contain labels which
284 are not in the declared subset of labels. This pruning is done using a
285 string based matching due to the nature of contracts and may prune more
286 than it should.
287 """
288 return cls.from_uri(filename)
290 @classmethod
291 def from_uri(cls, uri: ResourcePathExpression) -> Pipeline:
292 """Load a pipeline defined in a pipeline yaml file at a location
293 specified by a URI.
295 Parameters
296 ----------
297 uri: convertible to `ResourcePath`
298 If a string is supplied this should be a URI path that points to a
299 pipeline defined in yaml format, either as a direct path to the
300 yaml file, or as a directory containing a "pipeline.yaml" file (the
301 form used by `write_to_uri` with ``expand=True``). This uri may
302 also supply additional labels to be used in subsetting the loaded
303 Pipeline. These labels are separated from the path by a \\#, and
304 may be specified as a comma separated list, or a range denoted as
305 beginning..end. Beginning or end may be empty, in which case the
306 range will be a half open interval. Unlike python iteration bounds,
307 end bounds are *INCLUDED*. Note that range based selection is not
308 well defined for pipelines that are not linear in nature, and
309 correct behavior is not guaranteed, or may vary from run to run.
310 The same specifiers can be used with a `ResourcePath` object, by
311 being the sole contents in the fragment attribute.
313 Returns
314 -------
315 pipeline: `Pipeline`
316 The pipeline loaded from specified location with appropriate (if
317 any) subsetting
319 Notes
320 -----
321 This method attempts to prune any contracts that contain labels which
322 are not in the declared subset of labels. This pruning is done using a
323 string based matching due to the nature of contracts and may prune more
324 than it should.
325 """
326 # Split up the uri and any labels that were supplied
327 uri, label_specifier = cls._parse_file_specifier(uri)
328 pipeline: Pipeline = cls.fromIR(pipelineIR.PipelineIR.from_uri(uri))
330 # If there are labels supplied, only keep those
331 if label_specifier is not None:
332 pipeline = pipeline.subsetFromLabels(label_specifier)
333 return pipeline
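A brief sketch of loading pipelines with `from_uri`; the YAML path and the labels are hypothetical, and the fragment syntax follows the docstring above.

from lsst.pipe.base.pipeline import Pipeline

# Load the whole pipeline from a (hypothetical) local file.
full = Pipeline.from_uri("/path/to/DRP.yaml")

# Keep only the listed labels, given as a comma-separated list after '#'.
subset = Pipeline.from_uri("/path/to/DRP.yaml#isr,calibrate")

# Keep an inclusive label range; an empty bound leaves that end open.
head = Pipeline.from_uri("/path/to/DRP.yaml#..calibrate")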
335 def subsetFromLabels(self, labelSpecifier: LabelSpecifier) -> Pipeline:
336 """Subset a pipeline to contain only labels specified in labelSpecifier
338 Parameters
339 ----------
340 labelSpecifier : `LabelSpecifier`
341 Object containing labels that describe how to subset a pipeline.
343 Returns
344 -------
345 pipeline : `Pipeline`
346 A new pipeline object that is a subset of the old pipeline
348 Raises
349 ------
350 ValueError
351 Raised if there is an issue with specified labels
353 Notes
354 -----
355 This method attempts to prune any contracts that contain labels which
356 are not in the declared subset of labels. This pruning is done using a
357 string based matching due to the nature of contracts and may prune more
358 than it should.
359 """
360 # Labels supplied as a set
361 if labelSpecifier.labels:
362 labelSet = labelSpecifier.labels
363 # Labels supplied as a range, first create a list of all the labels
364 # in the pipeline sorted according to task dependency. Then only
365 # keep labels that lie between the supplied bounds
366 else:
367 # Create a copy of the pipeline to use when assessing the label
368 # ordering. Use a dict for fast searching while preserving order.
369 # Remove contracts so they do not fail in the expansion step. This
370 # is needed because a user may only configure the tasks they intend
371 # to run, which may cause some contracts to fail if they will later
372 # be dropped
373 pipeline = copy.deepcopy(self)
374 pipeline._pipelineIR.contracts = []
375 labels = {taskdef.label: True for taskdef in pipeline.toExpandedPipeline()}
377 # Verify the bounds are in the labels
378 if labelSpecifier.begin is not None:
379 if labelSpecifier.begin not in labels:
380 raise ValueError(
381 f"Beginning of range subset, {labelSpecifier.begin}, not found in "
382 "pipeline definition"
383 )
384 if labelSpecifier.end is not None:
385 if labelSpecifier.end not in labels:
386 raise ValueError(
387 f"End of range subset, {labelSpecifier.end}, not found in pipeline definition"
388 )
390 labelSet = set()
391 for label in labels:
392 if labelSpecifier.begin is not None:
393 if label != labelSpecifier.begin:
394 continue
395 else:
396 labelSpecifier.begin = None
397 labelSet.add(label)
398 if labelSpecifier.end is not None and label == labelSpecifier.end:
399 break
400 return Pipeline.fromIR(self._pipelineIR.subset_from_labels(labelSet))
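The same subsetting can be requested programmatically with a `LabelSpecifier`, as in this sketch (hypothetical path and labels).

from lsst.pipe.base.pipeline import LabelSpecifier, Pipeline

pipeline = Pipeline.from_uri("/path/to/DRP.yaml")  # hypothetical path

# Equivalent to the "#isr,calibrate" fragment form above.
subset = pipeline.subsetFromLabels(LabelSpecifier(labels={"isr", "calibrate"}))

# Range form: everything from the start of the pipeline through "calibrate".
head = pipeline.subsetFromLabels(LabelSpecifier(end="calibrate"))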
402 @staticmethod
403 def _parse_file_specifier(uri: ResourcePathExpression) -> Tuple[ResourcePath, Optional[LabelSpecifier]]:
404 """Split appart a uri and any possible label subsets"""
405 if isinstance(uri, str):
406 # This is to support legacy pipelines during transition
407 uri, num_replace = re.subn("[:](?!\\/\\/)", "#", uri)
408 if num_replace:
409 warnings.warn(
410 f"The pipeline file {uri} seems to use the legacy : to separate "
411 "labels, this is deprecated and will be removed after June 2021, please use "
412 "# instead.",
413 category=FutureWarning,
414 )
415 if uri.count("#") > 1:
416 raise ValueError("Only one set of labels is allowed when specifying a pipeline to load")
417 # Everything else can be converted directly to ResourcePath.
418 uri = ResourcePath(uri)
419 label_subset = uri.fragment or None
421 specifier: Optional[LabelSpecifier]
422 if label_subset is not None:
423 label_subset = urllib.parse.unquote(label_subset)
424 args: Dict[str, Union[Set[str], str, None]]
425 # labels supplied as a list
426 if "," in label_subset:
427 if ".." in label_subset:
428 raise ValueError(
429 "Can only specify a list of labels or a rangewhen loading a Pipline not both"
430 )
431 args = {"labels": set(label_subset.split(","))}
432 # labels supplied as a range
433 elif ".." in label_subset:
434 # Try to de-structure the labelSubset, this will fail if more
435 # than one range is specified
436 begin, end, *rest = label_subset.split("..")
437 if rest:
438 raise ValueError("Only one range can be specified when loading a pipeline")
439 args = {"begin": begin if begin else None, "end": end if end else None}
440 # Assume anything else is a single label
441 else:
442 args = {"labels": {label_subset}}
444 # MyPy doesn't like how cavalier kwarg construction is with types.
445 specifier = LabelSpecifier(**args) # type: ignore
446 else:
447 specifier = None
449 return uri, specifier
451 @classmethod
452 def fromString(cls, pipeline_string: str) -> Pipeline:
453 """Create a pipeline from string formatted as a pipeline document.
455 Parameters
456 ----------
457 pipeline_string : `str`
458 A string formatted like a pipeline document.
460 Returns
461 -------
462 pipeline: `Pipeline`
463 """
464 pipeline = cls.fromIR(pipelineIR.PipelineIR.from_string(pipeline_string))
465 return pipeline
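A sketch of building a pipeline from an inline document; the task class paths are hypothetical, and the document is assumed to use the same YAML layout accepted by `fromFile`/`from_uri`.

from lsst.pipe.base.pipeline import Pipeline

pipeline_yaml = """
description: A small two-task example pipeline
tasks:
  taskA:
    class: mypackage.tasks.TaskA
  taskB:
    class: mypackage.tasks.TaskB
"""
pipeline = Pipeline.fromString(pipeline_yaml)
print(len(pipeline))  # 2; task classes are only imported when the pipeline is expanded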
467 @classmethod
468 def fromIR(cls, deserialized_pipeline: pipelineIR.PipelineIR) -> Pipeline:
469 """Create a pipeline from an already created `PipelineIR` object.
471 Parameters
472 ----------
473 deserialized_pipeline: `PipelineIR`
474 An already created pipeline intermediate representation object
476 Returns
477 -------
478 pipeline: `Pipeline`
479 """
480 pipeline = cls.__new__(cls)
481 pipeline._pipelineIR = deserialized_pipeline
482 return pipeline
484 @classmethod
485 def fromPipeline(cls, pipeline: Pipeline) -> Pipeline:
486 """Create a new pipeline by copying an already existing `Pipeline`.
488 Parameters
489 ----------
490 pipeline: `Pipeline`
491 An already created `Pipeline` object to copy.
493 Returns
494 -------
495 pipeline: `Pipeline`
496 """
497 return cls.fromIR(copy.deepcopy(pipeline._pipelineIR))
499 def __str__(self) -> str:
500 return str(self._pipelineIR)
502 def addInstrument(self, instrument: Union[Instrument, str]) -> None:
503 """Add an instrument to the pipeline, or replace an instrument that is
504 already defined.
506 Parameters
507 ----------
508 instrument : `~lsst.obs.base.Instrument` or `str`
509 Either a derived class object of `lsst.obs.base.Instrument` or
510 a string corresponding to a fully qualified
511 `lsst.obs.base.Instrument` name.
512 """
513 if isinstance(instrument, str):
514 pass
515 else:
516 # TODO: assume that this is a subclass of Instrument, no type
517 # checking
518 instrument = get_full_type_name(instrument)
519 self._pipelineIR.instrument = instrument
521 def getInstrument(self) -> Optional[str]:
522 """Get the instrument from the pipeline.
524 Returns
525 -------
526 instrument : `str`, or None
527 The fully qualified name of a `lsst.obs.base.Instrument` subclass,
528 or `None` if the pipeline does not have an instrument.
529 """
530 return self._pipelineIR.instrument
532 def addTask(self, task: Union[Type[PipelineTask], str], label: str) -> None:
533 """Add a new task to the pipeline, or replace a task that is already
534 associated with the supplied label.
536 Parameters
537 ----------
538 task: `PipelineTask` or `str`
539 Either a derived class object of a `PipelineTask` or a string
540 corresponding to a fully qualified `PipelineTask` name.
541 label: `str`
542 A label that is used to identify the `PipelineTask` being added
543 """
544 if isinstance(task, str):
545 taskName = task
546 elif issubclass(task, PipelineTask):
547 taskName = get_full_type_name(task)
548 else:
549 raise ValueError(
550 "task must be either a child class of PipelineTask or a string containing"
551 " a fully qualified name to one"
552 )
553 if not label:
554 # in some cases (with a command line-generated pipeline) tasks can
555 # be defined without a label, which is not acceptable; use the task's
556 # _DefaultName in that case
557 if isinstance(task, str):
558 task_class = doImportType(task)
559 label = task_class._DefaultName
560 self._pipelineIR.tasks[label] = pipelineIR.TaskIR(label, taskName)
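A sketch of assembling a pipeline programmatically; the task class paths are hypothetical.

from lsst.pipe.base.pipeline import Pipeline

pipeline = Pipeline("Programmatically built example")

# Tasks can be added by fully qualified class name (or by passing the
# PipelineTask subclass itself); the class is not imported at this point
# because a label is supplied.
pipeline.addTask("mypackage.tasks.TaskA", "taskA")

# Re-using a label replaces the existing entry.
pipeline.addTask("mypackage.tasks.TaskAImproved", "taskA")

# Removal is by label and raises KeyError for unknown labels.
pipeline.removeTask("taskA")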
562 def removeTask(self, label: str) -> None:
563 """Remove a task from the pipeline.
565 Parameters
566 ----------
567 label : `str`
568 The label used to identify the task that is to be removed
570 Raises
571 ------
572 KeyError
573 If no task with that label exists in the pipeline
575 """
576 self._pipelineIR.tasks.pop(label)
578 def addConfigOverride(self, label: str, key: str, value: object) -> None:
579 """Apply single config override.
581 Parameters
582 ----------
583 label : `str`
584 Label of the task.
585 key: `str`
586 Fully-qualified field name.
587 value : object
588 Value to be given to a field.
589 """
590 self._addConfigImpl(label, pipelineIR.ConfigIR(rest={key: value}))
592 def addConfigFile(self, label: str, filename: str) -> None:
593 """Add overrides from a specified file.
595 Parameters
596 ----------
597 label : `str`
598 The label used to identify the task associated with config to
599 modify
600 filename : `str`
601 Path to the override file.
602 """
603 self._addConfigImpl(label, pipelineIR.ConfigIR(file=[filename]))
605 def addConfigPython(self, label: str, pythonString: str) -> None:
606 """Add Overrides by running a snippet of python code against a config.
608 Parameters
609 ----------
610 label : `str`
611 The label used to identify the task associated with config to
612 modify.
613 pythonString: `str`
614 A string which is valid python code to be executed. This is done
615 with config as the only local accessible value.
616 """
617 self._addConfigImpl(label, pipelineIR.ConfigIR(python=pythonString))
619 def _addConfigImpl(self, label: str, newConfig: pipelineIR.ConfigIR) -> None:
620 if label == "parameters":
621 if newConfig.rest.keys() - self._pipelineIR.parameters.mapping.keys():
622 raise ValueError("Cannot override parameters that are not defined in pipeline")
623 self._pipelineIR.parameters.mapping.update(newConfig.rest)
624 if newConfig.file:
625 raise ValueError("Setting parameters section with config file is not supported")
626 if newConfig.python:
627 raise ValueError("Setting parameters section using python block is unsupported")
628 return
629 if label not in self._pipelineIR.tasks:
630 raise LookupError(f"There are no tasks labeled '{label}' in the pipeline")
631 self._pipelineIR.tasks[label].add_or_update_config(newConfig)
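The three public override flavours above all funnel through `_addConfigImpl`; a minimal sketch follows, with hypothetical task, field, and file names.

from lsst.pipe.base.pipeline import Pipeline

pipeline = Pipeline("Override example")
pipeline.addTask("mypackage.tasks.TaskA", "taskA")  # hypothetical task

# Single-value override of a (hypothetical) config field.
pipeline.addConfigOverride("taskA", "doSomething", True)

# Overrides loaded from a config file when the pipeline is expanded.
pipeline.addConfigFile("taskA", "/path/to/overrides.py")

# Arbitrary python executed with ``config`` as the only local name.
pipeline.addConfigPython("taskA", "config.threshold = 5.0")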
633 def toFile(self, filename: str) -> None:
634 self._pipelineIR.to_file(filename)
636 def write_to_uri(self, uri: ResourcePathExpression) -> None:
637 """Write the pipeline to a file or directory.
639 Parameters
640 ----------
641 uri : convertible to `ResourcePath`
642 URI to write to; may have any scheme with `ResourcePath` write
643 support or no scheme for a local file/directory. Should have a
644 ``.yaml``.
645 """
646 self._pipelineIR.write_to_uri(uri)
648 def toExpandedPipeline(self) -> Generator[TaskDef, None, None]:
649 """Returns a generator of TaskDefs which can be used to create quantum
650 graphs.
652 Returns
653 -------
654 generator : generator of `TaskDef`
655 The generator returned will be the sorted iterator of tasks which
656 are to be used in constructing a quantum graph.
658 Raises
659 ------
660 NotImplementedError
661 If a dataId is supplied in a config block. This is in place for
662 future use
663 """
664 taskDefs = []
665 for label in self._pipelineIR.tasks:
666 taskDefs.append(self._buildTaskDef(label))
668 # let's evaluate the contracts
669 if self._pipelineIR.contracts is not None:
670 label_to_config = {x.label: x.config for x in taskDefs}
671 for contract in self._pipelineIR.contracts:
672 # execute this in its own line so it can raise a good error
673 # message if there were problems with the eval
674 success = eval(contract.contract, None, label_to_config)
675 if not success:
676 extra_info = f": {contract.msg}" if contract.msg is not None else ""
677 raise pipelineIR.ContractError(
678 f"Contract(s) '{contract.contract}' were not satisfied{extra_info}"
679 )
681 taskDefs = sorted(taskDefs, key=lambda x: x.label)
682 yield from pipeTools.orderPipeline(taskDefs)
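A sketch of expanding a pipeline into ordered `TaskDef` objects; it assumes a hypothetical pipeline file whose task classes are importable.

from lsst.pipe.base.pipeline import Pipeline

pipeline = Pipeline.from_uri("/path/to/DRP.yaml")  # hypothetical path

# Expansion imports each task class, applies overrides, evaluates the
# contracts, and yields TaskDefs sorted by task dependencies.
for task_def in pipeline.toExpandedPipeline():
    print(task_def.label, task_def.taskName)

# __iter__, __getitem__, and __len__ expose the same information.
n_tasks = len(pipeline)
first = next(iter(pipeline))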
684 def _buildTaskDef(self, label: str) -> TaskDef:
685 if (taskIR := self._pipelineIR.tasks.get(label)) is None:
686 raise NameError(f"Label {label} does not appear in this pipeline")
687 taskClass: Type[PipelineTask] = doImportType(taskIR.klass)
688 taskName = get_full_type_name(taskClass)
689 config = taskClass.ConfigClass()
690 overrides = ConfigOverrides()
691 if self._pipelineIR.instrument is not None:
692 overrides.addInstrumentOverride(self._pipelineIR.instrument, taskClass._DefaultName)
693 if taskIR.config is not None:
694 for configIR in (configIr.formatted(self._pipelineIR.parameters) for configIr in taskIR.config):
695 if configIR.dataId is not None:
696 raise NotImplementedError(
697 "Specializing a config on a partial data id is not yet "
698 "supported in Pipeline definition"
699 )
700 # only apply override if it applies to everything
701 if configIR.dataId is None:
702 if configIR.file:
703 for configFile in configIR.file:
704 overrides.addFileOverride(os.path.expandvars(configFile))
705 if configIR.python is not None:
706 overrides.addPythonOverride(configIR.python)
707 for key, value in configIR.rest.items():
708 overrides.addValueOverride(key, value)
709 overrides.applyTo(config)
710 return TaskDef(taskName=taskName, config=config, taskClass=taskClass, label=label)
712 def __iter__(self) -> Generator[TaskDef, None, None]:
713 return self.toExpandedPipeline()
715 def __getitem__(self, item: str) -> TaskDef:
716 return self._buildTaskDef(item)
718 def __len__(self) -> int:
719 return len(self._pipelineIR.tasks)
721 def __eq__(self, other: object) -> bool:
722 if not isinstance(other, Pipeline):
723 return False
724 elif self._pipelineIR == other._pipelineIR:
725 # Shortcut: if the IR is the same, the expanded pipeline must be
726 # the same as well. But the converse is not true.
727 return True
728 else:
729 self_expanded = {td.label: (td.taskClass,) for td in self}
730 other_expanded = {td.label: (td.taskClass,) for td in other}
731 if self_expanded != other_expanded:
732 return False
733 # After DM-27847, we should compare configuration here, or better,
734 # delegated to TaskDef.__eq__ after making that compare configurations.
735 raise NotImplementedError(
736 "Pipelines cannot be compared because config instances cannot be compared; see DM-27847."
737 )
740@dataclass(frozen=True)
741class TaskDatasetTypes:
742 """An immutable struct that extracts and classifies the dataset types used
743 by a `PipelineTask`
744 """
746 initInputs: NamedValueSet[DatasetType]
747 """Dataset types that are needed as inputs in order to construct this Task.
749 Task-level `initInputs` may be classified as either
750 `~PipelineDatasetTypes.initInputs` or
751 `~PipelineDatasetTypes.initIntermediates` at the Pipeline level.
752 """
754 initOutputs: NamedValueSet[DatasetType]
755 """Dataset types that may be written after constructing this Task.
757 Task-level `initOutputs` may be classified as either
758 `~PipelineDatasetTypes.initOutputs` or
759 `~PipelineDatasetTypes.initIntermediates` at the Pipeline level.
760 """
762 inputs: NamedValueSet[DatasetType]
763 """Dataset types that are regular inputs to this Task.
765 If an input dataset needed for a Quantum cannot be found in the input
766 collection(s) or produced by another Task in the Pipeline, that Quantum
767 (and all dependent Quanta) will not be produced.
769 Task-level `inputs` may be classified as either
770 `~PipelineDatasetTypes.inputs` or `~PipelineDatasetTypes.intermediates`
771 at the Pipeline level.
772 """
774 prerequisites: NamedValueSet[DatasetType]
775 """Dataset types that are prerequisite inputs to this Task.
777 Prerequisite inputs must exist in the input collection(s) before the
778 pipeline is run, but do not constrain the graph - if a prerequisite is
779 missing for a Quantum, `PrerequisiteMissingError` is raised.
781 Prerequisite inputs are not resolved until the second stage of
782 QuantumGraph generation.
783 """
785 outputs: NamedValueSet[DatasetType]
786 """Dataset types that are produced by this Task.
788 Task-level `outputs` may be classified as either
789 `~PipelineDatasetTypes.outputs` or `~PipelineDatasetTypes.intermediates`
790 at the Pipeline level.
791 """
793 @classmethod
794 def fromTaskDef(
795 cls,
796 taskDef: TaskDef,
797 *,
798 registry: Registry,
799 include_configs: bool = True,
800 storage_class_mapping: Optional[Mapping[str, str]] = None,
801 ) -> TaskDatasetTypes:
802 """Extract and classify the dataset types from a single `PipelineTask`.
804 Parameters
805 ----------
806 taskDef: `TaskDef`
807 An instance of a `TaskDef` class for a particular `PipelineTask`.
808 registry: `Registry`
809 Registry used to construct normalized `DatasetType` objects and
810 retrieve those that are incomplete.
811 include_configs : `bool`, optional
812 If `True` (default) include config dataset types as
813 ``initOutputs``.
814 storage_class_mapping : `Mapping` of `str` to `StorageClass`, optional
815 If a taskdef contains a component dataset type that is unknown
816 to the registry, its parent StorageClass will be looked up in this
817 mapping if it is supplied. If the mapping does not contain the
818 composite dataset type, or the mapping is not supplied, an exception
819 will be raised.
821 Returns
822 -------
823 types: `TaskDatasetTypes`
824 The dataset types used by this task.
826 Raises
827 ------
828 ValueError
829 Raised if dataset type connection definition differs from
830 registry definition.
831 LookupError
832 Raised if component parent StorageClass could not be determined
833 and storage_class_mapping does not contain the composite type, or
834 is set to None.
835 """
837 def makeDatasetTypesSet(
838 connectionType: str,
839 is_input: bool,
840 freeze: bool = True,
841 ) -> NamedValueSet[DatasetType]:
842 """Constructs a set of true `DatasetType` objects
844 Parameters
845 ----------
846 connectionType : `str`
847 Name of the connection type to produce a set for; corresponds
848 to an attribute of type `list` on the connection class instance.
849 is_input : `bool`
850 If `True`, these are input dataset types; otherwise they are
851 output dataset types.
852 freeze : `bool`, optional
853 If `True`, call `NamedValueSet.freeze` on the object returned.
855 Returns
856 -------
857 datasetTypes : `NamedValueSet`
858 A set of all datasetTypes which correspond to the input
859 connection type specified in the connection class of this
860 `PipelineTask`
862 Raises
863 ------
864 ValueError
865 Raised if dataset type connection definition differs from
866 registry definition.
867 LookupError
868 Raised if component parent StorageClass could not be determined
869 and storage_class_mapping does not contain the composite type,
870 or is set to None.
872 Notes
873 -----
874 This function is a closure over the variables ``registry``,
875 ``taskDef``, and ``storage_class_mapping``.
876 """
877 datasetTypes = NamedValueSet[DatasetType]()
878 for c in iterConnections(taskDef.connections, connectionType):
879 dimensions = set(getattr(c, "dimensions", set()))
880 if "skypix" in dimensions:
881 try:
882 datasetType = registry.getDatasetType(c.name)
883 except LookupError as err:
884 raise LookupError(
885 f"DatasetType '{c.name}' referenced by "
886 f"{type(taskDef.connections).__name__} uses 'skypix' as a dimension "
887 f"placeholder, but does not already exist in the registry. "
888 f"Note that reference catalog names are now used as the dataset "
889 f"type name instead of 'ref_cat'."
890 ) from err
891 rest1 = set(registry.dimensions.extract(dimensions - set(["skypix"])).names)
892 rest2 = set(
893 dim.name for dim in datasetType.dimensions if not isinstance(dim, SkyPixDimension)
894 )
895 if rest1 != rest2:
896 raise ValueError(
897 f"Non-skypix dimensions for dataset type {c.name} declared in "
898 f"connections ({rest1}) are inconsistent with those in "
899 f"registry's version of this dataset ({rest2})."
900 )
901 else:
902 # Component dataset types are not explicitly in the
903 # registry. This complicates consistency checks with
904 # registry and requires we work out the composite storage
905 # class.
906 registryDatasetType = None
907 try:
908 registryDatasetType = registry.getDatasetType(c.name)
909 except KeyError:
910 compositeName, componentName = DatasetType.splitDatasetTypeName(c.name)
911 if componentName:
912 if storage_class_mapping is None or compositeName not in storage_class_mapping:
913 raise LookupError(
914 "Component parent class cannot be determined, and "
915 "composite name was not in storage class mapping, or no "
916 "storage_class_mapping was supplied"
917 )
918 else:
919 parentStorageClass = storage_class_mapping[compositeName]
920 else:
921 parentStorageClass = None
922 datasetType = c.makeDatasetType(
923 registry.dimensions, parentStorageClass=parentStorageClass
924 )
925 registryDatasetType = datasetType
926 else:
927 datasetType = c.makeDatasetType(
928 registry.dimensions, parentStorageClass=registryDatasetType.parentStorageClass
929 )
931 if registryDatasetType and datasetType != registryDatasetType:
932 # The dataset types differ but first check to see if
933 # they are compatible before raising.
934 if is_input:
935 # This DatasetType must be compatible on get.
936 is_compatible = datasetType.is_compatible_with(registryDatasetType)
937 else:
938 # Has to be able to be converted to the expected type
939 # on put.
940 is_compatible = registryDatasetType.is_compatible_with(datasetType)
941 if is_compatible:
942 # For inputs we want the pipeline to use the
943 # pipeline definition, for outputs it should use
944 # the registry definition.
945 if not is_input:
946 datasetType = registryDatasetType
947 _LOG.debug(
948 "Dataset types differ (task %s != registry %s) but are compatible"
949 " for %s in %s.",
950 datasetType,
951 registryDatasetType,
952 "input" if is_input else "output",
953 taskDef.label,
954 )
955 else:
956 try:
957 # Explicitly check for storage class just to
958 # make more specific message.
959 _ = datasetType.storageClass
960 except KeyError:
961 raise ValueError(
962 "Storage class does not exist for supplied dataset type "
963 f"{datasetType} for {taskDef.label}."
964 ) from None
965 raise ValueError(
966 f"Supplied dataset type ({datasetType}) inconsistent with "
967 f"registry definition ({registryDatasetType}) "
968 f"for {taskDef.label}."
969 )
970 datasetTypes.add(datasetType)
971 if freeze:
972 datasetTypes.freeze()
973 return datasetTypes
975 # optionally add initOutput dataset for config
976 initOutputs = makeDatasetTypesSet("initOutputs", is_input=False, freeze=False)
977 if include_configs:
978 initOutputs.add(
979 DatasetType(
980 taskDef.configDatasetName,
981 registry.dimensions.empty,
982 storageClass="Config",
983 )
984 )
985 initOutputs.freeze()
987 # optionally add output dataset for metadata
988 outputs = makeDatasetTypesSet("outputs", is_input=False, freeze=False)
989 if taskDef.metadataDatasetName is not None:
990 # Metadata is supposed to be of the TaskMetadata type; its
991 # dimensions correspond to a task quantum.
992 dimensions = registry.dimensions.extract(taskDef.connections.dimensions)
994 # Allow the storage class definition to be read from the existing
995 # dataset type definition if present.
996 try:
997 current = registry.getDatasetType(taskDef.metadataDatasetName)
998 except KeyError:
999 # No previous definition so use the default.
1000 storageClass = "TaskMetadata" if _TASK_METADATA_TYPE is TaskMetadata else "PropertySet"
1001 else:
1002 storageClass = current.storageClass.name
1004 outputs.update({DatasetType(taskDef.metadataDatasetName, dimensions, storageClass)})
1005 if taskDef.logOutputDatasetName is not None:
1006 # Log output dimensions correspond to a task quantum.
1007 dimensions = registry.dimensions.extract(taskDef.connections.dimensions)
1008 outputs.update({DatasetType(taskDef.logOutputDatasetName, dimensions, "ButlerLogRecords")})
1010 outputs.freeze()
1012 return cls(
1013 initInputs=makeDatasetTypesSet("initInputs", is_input=True),
1014 initOutputs=initOutputs,
1015 inputs=makeDatasetTypesSet("inputs", is_input=True),
1016 prerequisites=makeDatasetTypesSet("prerequisiteInputs", is_input=True),
1017 outputs=outputs,
1018 )
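A sketch of classifying one task's dataset types; the butler repository and pipeline file paths are hypothetical.

from lsst.daf.butler import Butler
from lsst.pipe.base.pipeline import Pipeline, TaskDatasetTypes

butler = Butler("/path/to/repo")  # hypothetical repository
pipeline = Pipeline.from_uri("/path/to/DRP.yaml")  # hypothetical path

task_def = next(iter(pipeline))
types = TaskDatasetTypes.fromTaskDef(task_def, registry=butler.registry)

# Regular and prerequisite inputs, plus outputs (including the config,
# metadata, and log dataset types added above).
print(sorted(dt.name for dt in types.inputs))
print(sorted(dt.name for dt in types.prerequisites))
print(sorted(dt.name for dt in types.outputs))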
1021@dataclass(frozen=True)
1022class PipelineDatasetTypes:
1023 """An immutable struct that classifies the dataset types used in a
1024 `Pipeline`.
1025 """
1027 packagesDatasetName: ClassVar[str] = "packages"
1028 """Name of a dataset type used to save package versions.
1029 """
1031 initInputs: NamedValueSet[DatasetType]
1032 """Dataset types that are needed as inputs in order to construct the Tasks
1033 in this Pipeline.
1035 This does not include dataset types that are produced when constructing
1036 other Tasks in the Pipeline (these are classified as `initIntermediates`).
1037 """
1039 initOutputs: NamedValueSet[DatasetType]
1040 """Dataset types that may be written after constructing the Tasks in this
1041 Pipeline.
1043 This does not include dataset types that are also used as inputs when
1044 constructing other Tasks in the Pipeline (these are classified as
1045 `initIntermediates`).
1046 """
1048 initIntermediates: NamedValueSet[DatasetType]
1049 """Dataset types that are both used when constructing one or more Tasks
1050 in the Pipeline and produced as a side-effect of constructing another
1051 Task in the Pipeline.
1052 """
1054 inputs: NamedValueSet[DatasetType]
1055 """Dataset types that are regular inputs for the full pipeline.
1057 If an input dataset needed for a Quantum cannot be found in the input
1058 collection(s), that Quantum (and all dependent Quanta) will not be
1059 produced.
1060 """
1062 prerequisites: NamedValueSet[DatasetType]
1063 """Dataset types that are prerequisite inputs for the full Pipeline.
1065 Prerequisite inputs must exist in the input collection(s) before the
1066 pipeline is run, but do not constrain the graph - if a prerequisite is
1067 missing for a Quantum, `PrerequisiteMissingError` is raised.
1069 Prerequisite inputs are not resolved until the second stage of
1070 QuantumGraph generation.
1071 """
1073 intermediates: NamedValueSet[DatasetType]
1074 """Dataset types that are output by one Task in the Pipeline and consumed
1075 as inputs by one or more other Tasks in the Pipeline.
1076 """
1078 outputs: NamedValueSet[DatasetType]
1079 """Dataset types that are output by a Task in the Pipeline and not consumed
1080 by any other Task in the Pipeline.
1081 """
1083 byTask: Mapping[str, TaskDatasetTypes]
1084 """Per-Task dataset types, keyed by label in the `Pipeline`.
1086 This is guaranteed to be zip-iterable with the `Pipeline` itself (assuming
1087 neither has been modified since the dataset types were extracted, of
1088 course).
1089 """
1091 @classmethod
1092 def fromPipeline(
1093 cls,
1094 pipeline: Union[Pipeline, Iterable[TaskDef]],
1095 *,
1096 registry: Registry,
1097 include_configs: bool = True,
1098 include_packages: bool = True,
1099 ) -> PipelineDatasetTypes:
1100 """Extract and classify the dataset types from all tasks in a
1101 `Pipeline`.
1103 Parameters
1104 ----------
1105 pipeline: `Pipeline` or `Iterable` [ `TaskDef` ]
1106 A collection of tasks that can be run together.
1107 registry: `Registry`
1108 Registry used to construct normalized `DatasetType` objects and
1109 retrieve those that are incomplete.
1110 include_configs : `bool`, optional
1111 If `True` (default) include config dataset types as
1112 ``initOutputs``.
1113 include_packages : `bool`, optional
1114 If `True` (default) include the dataset type for software package
1115 versions in ``initOutputs``.
1117 Returns
1118 -------
1119 types: `PipelineDatasetTypes`
1120 The dataset types used by this `Pipeline`.
1122 Raises
1123 ------
1124 ValueError
1125 Raised if Tasks are inconsistent about which datasets are marked
1126 prerequisite. This indicates that the Tasks cannot be run as part
1127 of the same `Pipeline`.
1128 """
1129 allInputs = NamedValueSet[DatasetType]()
1130 allOutputs = NamedValueSet[DatasetType]()
1131 allInitInputs = NamedValueSet[DatasetType]()
1132 allInitOutputs = NamedValueSet[DatasetType]()
1133 prerequisites = NamedValueSet[DatasetType]()
1134 byTask = dict()
1135 if include_packages:
1136 allInitOutputs.add(
1137 DatasetType(
1138 cls.packagesDatasetName,
1139 registry.dimensions.empty,
1140 storageClass="Packages",
1141 )
1142 )
1143 # create a list of TaskDefs in case the input is a generator
1144 pipeline = list(pipeline)
1146 # collect all the output dataset types
1147 typeStorageclassMap: Dict[str, str] = {}
1148 for taskDef in pipeline:
1149 for outConnection in iterConnections(taskDef.connections, "outputs"):
1150 typeStorageclassMap[outConnection.name] = outConnection.storageClass
1152 for taskDef in pipeline:
1153 thisTask = TaskDatasetTypes.fromTaskDef(
1154 taskDef,
1155 registry=registry,
1156 include_configs=include_configs,
1157 storage_class_mapping=typeStorageclassMap,
1158 )
1159 allInitInputs.update(thisTask.initInputs)
1160 allInitOutputs.update(thisTask.initOutputs)
1161 allInputs.update(thisTask.inputs)
1162 prerequisites.update(thisTask.prerequisites)
1163 allOutputs.update(thisTask.outputs)
1164 byTask[taskDef.label] = thisTask
1165 if not prerequisites.isdisjoint(allInputs):
1166 raise ValueError(
1167 "{} marked as both prerequisites and regular inputs".format(
1168 {dt.name for dt in allInputs & prerequisites}
1169 )
1170 )
1171 if not prerequisites.isdisjoint(allOutputs):
1172 raise ValueError(
1173 "{} marked as both prerequisites and outputs".format(
1174 {dt.name for dt in allOutputs & prerequisites}
1175 )
1176 )
1177 # Make sure that components which are marked as inputs get treated as
1178 # intermediates if there is an output which produces the composite
1179 # containing the component
1180 intermediateComponents = NamedValueSet[DatasetType]()
1181 intermediateComposites = NamedValueSet[DatasetType]()
1182 outputNameMapping = {dsType.name: dsType for dsType in allOutputs}
1183 for dsType in allInputs:
1184 # get the name of a possible component
1185 name, component = dsType.nameAndComponent()
1186 # if there is a component name, that means this is a component
1187 # DatasetType, if there is an output which produces the parent of
1188 # this component, treat this input as an intermediate
1189 if component is not None:
1190 # This needs to be in this if block, because someone might have
1191 # a composite that is a pure input from existing data
1192 if name in outputNameMapping:
1193 intermediateComponents.add(dsType)
1194 intermediateComposites.add(outputNameMapping[name])
1196 def checkConsistency(a: NamedValueSet, b: NamedValueSet) -> None:
1197 common = a.names & b.names
1198 for name in common:
1199 # Any compatibility is allowed. This function does not know
1200 # if a dataset type is to be used for input or output.
1201 if not (a[name].is_compatible_with(b[name]) or b[name].is_compatible_with(a[name])):
1202 raise ValueError(f"Conflicting definitions for dataset type: {a[name]} != {b[name]}.")
1204 checkConsistency(allInitInputs, allInitOutputs)
1205 checkConsistency(allInputs, allOutputs)
1206 checkConsistency(allInputs, intermediateComposites)
1207 checkConsistency(allOutputs, intermediateComposites)
1209 def frozen(s: AbstractSet[DatasetType]) -> NamedValueSet[DatasetType]:
1210 assert isinstance(s, NamedValueSet)
1211 s.freeze()
1212 return s
1214 return cls(
1215 initInputs=frozen(allInitInputs - allInitOutputs),
1216 initIntermediates=frozen(allInitInputs & allInitOutputs),
1217 initOutputs=frozen(allInitOutputs - allInitInputs),
1218 inputs=frozen(allInputs - allOutputs - intermediateComponents),
1219 # If there are storage class differences in inputs and outputs
1220 # the intermediates have to choose priority. Here choose that
1221 # inputs to tasks must match the requested storage class by
1222 # applying the inputs over the top of the outputs.
1223 intermediates=frozen(allOutputs & allInputs | intermediateComponents),
1224 outputs=frozen(allOutputs - allInputs - intermediateComposites),
1225 prerequisites=frozen(prerequisites),
1226 byTask=MappingProxyType(byTask), # MappingProxyType -> frozen view of dict for immutability
1227 )
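A sketch of the pipeline-level classification; again the repository and pipeline paths are hypothetical.

from lsst.daf.butler import Butler
from lsst.pipe.base.pipeline import Pipeline, PipelineDatasetTypes

butler = Butler("/path/to/repo")  # hypothetical repository
pipeline = Pipeline.from_uri("/path/to/DRP.yaml")  # hypothetical path

dataset_types = PipelineDatasetTypes.fromPipeline(pipeline, registry=butler.registry)

# Overall inputs are produced by no task in the pipeline; intermediates are
# produced by one task and consumed by another.
print(sorted(dt.name for dt in dataset_types.inputs))
print(sorted(dt.name for dt in dataset_types.intermediates))

# Per-task classification, keyed by label and zip-iterable with the pipeline.
for label, task_types in dataset_types.byTask.items():
    print(label, len(task_types.outputs))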
1229 @classmethod
1230 def initOutputNames(
1231 cls,
1232 pipeline: Union[Pipeline, Iterable[TaskDef]],
1233 *,
1234 include_configs: bool = True,
1235 include_packages: bool = True,
1236 ) -> Iterator[str]:
1237 """Return the names of dataset types ot task initOutputs, Configs,
1238 and package versions for a pipeline.
1240 Parameters
1241 ----------
1242 pipeline: `Pipeline` or `Iterable` [ `TaskDef` ]
1243 A `Pipeline` instance or collection of `TaskDef` instances.
1244 include_configs : `bool`, optional
1245 If `True` (default) include config dataset types.
1246 include_packages : `bool`, optional
1247 If `True` (default) include the dataset type for package versions.
1249 Yields
1250 ------
1251 datasetTypeName : `str`
1252 Name of the dataset type.
1253 """
1254 if include_packages:
1255 # Package versions dataset type
1256 yield cls.packagesDatasetName
1258 if isinstance(pipeline, Pipeline):
1259 pipeline = pipeline.toExpandedPipeline()
1261 for taskDef in pipeline:
1263 # all task InitOutputs
1264 for name in taskDef.connections.initOutputs:
1265 attribute = getattr(taskDef.connections, name)
1266 yield attribute.name
1268 # config dataset name
1269 if include_configs:
1270 yield taskDef.configDatasetName
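A sketch of listing init-output dataset type names without touching a registry; the pipeline path is hypothetical.

from lsst.pipe.base.pipeline import Pipeline, PipelineDatasetTypes

pipeline = Pipeline.from_uri("/path/to/DRP.yaml")  # hypothetical path

# Yields "packages", each task's declared initOutputs, and a
# "<label>_config" entry per task (both defaults left enabled).
names = list(PipelineDatasetTypes.initOutputNames(pipeline))
print(names)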