Coverage for python/lsst/pipe/base/pipeline.py: 20%
422 statements
coverage.py v6.5.0, created at 2022-12-08 14:36 -0800
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
23"""Module defining Pipeline class and related methods.
24"""
26__all__ = ["Pipeline", "TaskDef", "TaskDatasetTypes", "PipelineDatasetTypes", "LabelSpecifier"]
28import copy
29import logging
30import os
31import re
32import urllib.parse
33import warnings
35# -------------------------------
36# Imports of standard modules --
37# -------------------------------
38from dataclasses import dataclass
39from types import MappingProxyType
40from typing import (
41 TYPE_CHECKING,
42 AbstractSet,
43 Callable,
44 ClassVar,
45 Dict,
46 Generator,
47 Iterable,
48 Iterator,
49 Mapping,
50 Optional,
51 Set,
52 Tuple,
53 Type,
54 Union,
55 cast,
56)
58# -----------------------------
59# Imports for other modules --
60from lsst.daf.butler import DatasetType, NamedValueSet, Registry, SkyPixDimension
61from lsst.resources import ResourcePath, ResourcePathExpression
62from lsst.utils import doImportType
63from lsst.utils.introspection import get_full_type_name
65from . import pipelineIR, pipeTools
66from ._task_metadata import TaskMetadata
67from .config import PipelineTaskConfig
68from .configOverrides import ConfigOverrides
69from .connections import iterConnections
70from .pipelineTask import PipelineTask
71from .task import _TASK_METADATA_TYPE
73if TYPE_CHECKING:  # Imports needed only for type annotations; may be circular.
74 from lsst.obs.base import Instrument
75 from lsst.pex.config import Config
77# ----------------------------------
78# Local non-exported definitions --
79# ----------------------------------
81_LOG = logging.getLogger(__name__)
83# ------------------------
84# Exported definitions --
85# ------------------------
88@dataclass
89class LabelSpecifier:
90 """A structure to specify a subset of labels to load
92 This structure may contain a set of labels to be used in subsetting a
93 pipeline, or a beginning and end point. Beginning or end may be empty,
94 in which case the range will be a half open interval. Unlike python
95 iteration bounds, end bounds are *INCLUDED*. Note that range based
96 selection is not well defined for pipelines that are not linear in nature,
97 and correct behavior is not guaranteed, or may vary from run to run.
98 """
100 labels: Optional[Set[str]] = None
101 begin: Optional[str] = None
102 end: Optional[str] = None
104 def __post_init__(self) -> None:
105 if self.labels is not None and (self.begin or self.end):
106 raise ValueError(
107 "This struct can only be initialized with a labels set or a begin (and/or) end specifier"
108 )
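# Illustrative usage (added for clarity; not part of the original module).
# A LabelSpecifier holds either an explicit label set or an inclusive
# begin/end range, never both. The label names below are hypothetical:
#
#     spec = LabelSpecifier(labels={"isr", "calibrate"})
#     spec = LabelSpecifier(begin="isr", end="calibrate")
#     LabelSpecifier(labels={"isr"}, begin="isr")  # raises ValueError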
111class TaskDef:
112 """TaskDef is a collection of information about task needed by Pipeline.
114 The information includes task name, configuration object and optional
115 task class. This class is just a collection of attributes and it exposes
116 all of them so that attributes could potentially be modified in place
117 (e.g. if configuration needs extra overrides).
119 Attributes
120 ----------
121 taskName : `str`, optional
122 The fully-qualified `PipelineTask` class name. If not provided,
123 ``taskClass`` must be.
124 config : `lsst.pipe.base.config.PipelineTaskConfig`, optional
125 Instance of the configuration class corresponding to this task class,
126 usually with all overrides applied. This config will be frozen. If
127 not provided, ``taskClass`` must be provided and
128 ``taskClass.ConfigClass()`` will be used.
129 taskClass : `type`, optional
130 `PipelineTask` class object; if provided and ``taskName`` is as well,
131 the caller guarantees that they are consistent. If not provided,
132 ``taskName`` is used to import the type.
133 label : `str`, optional
134 Task label, usually a short string unique in a pipeline. If not
135 provided, ``taskClass`` must be, and ``taskClass._DefaultName`` will
136 be used.
137 """
139 def __init__(
140 self,
141 taskName: Optional[str] = None,
142 config: Optional[PipelineTaskConfig] = None,
143 taskClass: Optional[Type[PipelineTask]] = None,
144 label: Optional[str] = None,
145 ):
146 if taskName is None:
147 if taskClass is None:
148 raise ValueError("At least one of `taskName` and `taskClass` must be provided.")
149 taskName = get_full_type_name(taskClass)
150 elif taskClass is None:
151 taskClass = doImportType(taskName)
152 if config is None:
153 if taskClass is None:
154 raise ValueError("`taskClass` must be provided if `config` is not.")
155 config = taskClass.ConfigClass()
156 if label is None:
157 if taskClass is None:
158 raise ValueError("`taskClass` must be provided if `label` is not.")
159 label = taskClass._DefaultName
160 self.taskName = taskName
161 try:
162 config.validate()
163 except Exception:
164 _LOG.error("Configuration validation failed for task %s (%s)", label, taskName)
165 raise
166 config.freeze()
167 self.config = config
168 self.taskClass = taskClass
169 self.label = label
170 self.connections = config.connections.ConnectionsClass(config=config)
172 @property
173 def configDatasetName(self) -> str:
174 """Name of a dataset type for configuration of this task (`str`)"""
175 return self.label + "_config"
177 @property
178 def metadataDatasetName(self) -> Optional[str]:
179 """Name of a dataset type for metadata of this task, `None` if
180 metadata is not to be saved (`str`)
181 """
182 if self.config.saveMetadata:
183 return self.makeMetadataDatasetName(self.label)
184 else:
185 return None
187 @classmethod
188 def makeMetadataDatasetName(cls, label: str) -> str:
189 """Construct the name of the dataset type for metadata for a task.
191 Parameters
192 ----------
193 label : `str`
194 Label for the task within its pipeline.
196 Returns
197 -------
198 name : `str`
199 Name of the task's metadata dataset type.
200 """
201 return f"{label}_metadata"
203 @property
204 def logOutputDatasetName(self) -> Optional[str]:
205 """Name of a dataset type for log output from this task, `None` if
206 logs are not to be saved (`str`)
207 """
208 if cast(PipelineTaskConfig, self.config).saveLogOutput:
209 return self.label + "_log"
210 else:
211 return None
213 def __str__(self) -> str:
214 rep = "TaskDef(" + self.taskName
215 if self.label:
216 rep += ", label=" + self.label
217 rep += ")"
218 return rep
220 def __eq__(self, other: object) -> bool:
221 if not isinstance(other, TaskDef):
222 return False
223 # This does not consider equality of configs when determining equality
224 # as config equality is a difficult thing to define. Should be updated
225 # after DM-27847
226 return self.taskClass == other.taskClass and self.label == other.label
228 def __hash__(self) -> int:
229 return hash((self.taskClass, self.label))
231 @classmethod
232 def _unreduce(cls, taskName: str, config: PipelineTaskConfig, label: str) -> TaskDef:
233 """Custom callable for unpickling.
235 All arguments are forwarded directly to the constructor; this
236 trampoline is only needed because ``__reduce__`` callables can't be
237 called with keyword arguments.
238 """
239 return cls(taskName=taskName, config=config, label=label)
241 def __reduce__(self) -> Tuple[Callable[[str, PipelineTaskConfig, str], TaskDef], Tuple[str, Config, str]]:
242 return (self._unreduce, (self.taskName, self.config, self.label))
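# Illustrative usage (added for clarity; not part of the original module).
# A TaskDef can be built from a fully-qualified task name alone; the class is
# imported, a default config is constructed, validated, and frozen, and the
# task's _DefaultName becomes the label. The task name below is hypothetical:
#
#     task_def = TaskDef(taskName="mypackage.tasks.MyTask")
#     task_def.configDatasetName  # "<label>_config"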
245class Pipeline:
246 """A `Pipeline` is a representation of a series of tasks to run, and the
247 configuration for those tasks.
249 Parameters
250 ----------
251 description : `str`
252 A description of what this pipeline does.
253 """
255 def __init__(self, description: str):
256 pipeline_dict = {"description": description, "tasks": {}}
257 self._pipelineIR = pipelineIR.PipelineIR(pipeline_dict)
259 @classmethod
260 def fromFile(cls, filename: str) -> Pipeline:
261 """Load a pipeline defined in a pipeline yaml file.
263 Parameters
264 ----------
265 filename: `str`
266 A path that points to a pipeline defined in yaml format. This
267 filename may also supply additional labels to be used in
268 subsetting the loaded Pipeline. These labels are separated from
269 the path by a \\#, and may be specified as a comma separated
270 list, or a range denoted as beginning..end. Beginning or end may
271 be empty, in which case the range will be a half open interval.
272 Unlike python iteration bounds, end bounds are *INCLUDED*. Note
273 that range based selection is not well defined for pipelines that
274 are not linear in nature, and correct behavior is not guaranteed,
275 or may vary from run to run.
277 Returns
278 -------
279 pipeline: `Pipeline`
280 The pipeline loaded from the specified location with appropriate (if
281 any) subsetting.
283 Notes
284 -----
285 This method attempts to prune any contracts that contain labels which
286 are not in the declared subset of labels. This pruning is done using a
287 string based matching due to the nature of contracts and may prune more
288 than it should.
289 """
290 return cls.from_uri(filename)
292 @classmethod
293 def from_uri(cls, uri: ResourcePathExpression) -> Pipeline:
294 """Load a pipeline defined in a pipeline yaml file at a location
295 specified by a URI.
297 Parameters
298 ----------
299 uri: convertible to `ResourcePath`
300 If a string is supplied this should be a URI path that points to a
301 pipeline defined in yaml format, either as a direct path to the
302 yaml file, or as a directory containing a "pipeline.yaml" file (the
303 form used by `write_to_uri` with ``expand=True``). This uri may
304 also supply additional labels to be used in subsetting the loaded
305 Pipeline. These labels are separated from the path by a \\#, and
306 may be specified as a comma separated list, or a range denoted as
307 beginning..end. Beginning or end may be empty, in which case the
308 range will be a half open interval. Unlike python iteration bounds,
309 end bounds are *INCLUDED*. Note that range based selection is not
310 well defined for pipelines that are not linear in nature, and
311 correct behavior is not guaranteed, or may vary from run to run.
312 The same specifiers can be used with a `ResourcePath` object, by
313 being the sole contents of the fragment attribute.
315 Returns
316 -------
317 pipeline: `Pipeline`
318 The pipeline loaded from the specified location with appropriate (if
319 any) subsetting.
321 Notes
322 -----
323 This method attempts to prune any contracts that contain labels which
324 are not in the declared subset of labels. This pruning is done using a
325 string based matching due to the nature of contracts and may prune more
326 than it should.
327 """
328 # Split up the uri and any labels that were supplied
329 uri, label_specifier = cls._parse_file_specifier(uri)
330 pipeline: Pipeline = cls.fromIR(pipelineIR.PipelineIR.from_uri(uri))
332 # If there are labels supplied, only keep those
333 if label_specifier is not None:
334 pipeline = pipeline.subsetFromLabels(label_specifier)
335 return pipeline
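# Illustrative usage (added for clarity; not part of the original module).
# The URI fragment selects a subset of the pipeline; the file name and labels
# below are hypothetical:
#
#     Pipeline.from_uri("myPipeline.yaml")                 # whole pipeline
#     Pipeline.from_uri("myPipeline.yaml#isr,calibrate")   # explicit labels
#     Pipeline.from_uri("myPipeline.yaml#isr..calibrate")  # inclusive range
#     Pipeline.from_uri("myPipeline.yaml#..calibrate")     # half-open range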
337 def subsetFromLabels(self, labelSpecifier: LabelSpecifier) -> Pipeline:
338 """Subset a pipeline to contain only labels specified in labelSpecifier
340 Parameters
341 ----------
342 labelSpecifier : `LabelSpecifier`
343 Object containing labels that describes how to subset a pipeline.
345 Returns
346 -------
347 pipeline : `Pipeline`
348 A new pipeline object that is a subset of the old pipeline
350 Raises
351 ------
352 ValueError
353 Raised if there is an issue with specified labels
355 Notes
356 -----
357 This method attempts to prune any contracts that contain labels which
358 are not in the declared subset of labels. This pruning is done using a
359 string based matching due to the nature of contracts and may prune more
360 than it should.
361 """
362 # Labels supplied as a set
363 if labelSpecifier.labels:
364 labelSet = labelSpecifier.labels
365 # Labels supplied as a range, first create a list of all the labels
366 # in the pipeline sorted according to task dependency. Then only
367 # keep labels that lie between the supplied bounds
368 else:
369 # Create a copy of the pipeline to use when assessing the label
370 # ordering. Use a dict for fast searching while preserving order.
371 # Remove contracts so they do not fail in the expansion step. This
372 # is needed because a user may only configure the tasks they intend
373 # to run, which may cause some contracts to fail if they will later
374 # be dropped
375 pipeline = copy.deepcopy(self)
376 pipeline._pipelineIR.contracts = []
377 labels = {taskdef.label: True for taskdef in pipeline.toExpandedPipeline()}
379 # Verify the bounds are in the labels
380 if labelSpecifier.begin is not None:
381 if labelSpecifier.begin not in labels:
382 raise ValueError(
383 f"Beginning of range subset, {labelSpecifier.begin}, not found in "
384 "pipeline definition"
385 )
386 if labelSpecifier.end is not None:
387 if labelSpecifier.end not in labels:
388 raise ValueError(
389 f"End of range subset, {labelSpecifier.end}, not found in pipeline definition"
390 )
392 labelSet = set()
393 for label in labels:
394 if labelSpecifier.begin is not None:
395 if label != labelSpecifier.begin:
396 continue
397 else:
398 labelSpecifier.begin = None
399 labelSet.add(label)
400 if labelSpecifier.end is not None and label == labelSpecifier.end:
401 break
402 return Pipeline.fromIR(self._pipelineIR.subset_from_labels(labelSet))
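# Illustrative usage (added for clarity; not part of the original module),
# assuming a pipeline containing a hypothetical "calibrate" label:
#
#     subset = pipeline.subsetFromLabels(LabelSpecifier(labels={"calibrate"}))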
404 @staticmethod
405 def _parse_file_specifier(uri: ResourcePathExpression) -> Tuple[ResourcePath, Optional[LabelSpecifier]]:
406 """Split appart a uri and any possible label subsets"""
407 if isinstance(uri, str):
408 # This is to support legacy pipelines during transition
409 uri, num_replace = re.subn("[:](?!\\/\\/)", "#", uri)
410 if num_replace:
411 warnings.warn(
412 f"The pipeline file {uri} seems to use the legacy : to separate "
413 "labels, this is deprecated and will be removed after June 2021, please use "
414 "# instead.",
415 category=FutureWarning,
416 )
417 if uri.count("#") > 1:
418 raise ValueError("Only one set of labels is allowed when specifying a pipeline to load")
419 # Everything else can be converted directly to ResourcePath.
420 uri = ResourcePath(uri)
421 label_subset = uri.fragment or None
423 specifier: Optional[LabelSpecifier]
424 if label_subset is not None:
425 label_subset = urllib.parse.unquote(label_subset)
426 args: Dict[str, Union[Set[str], str, None]]
427 # labels supplied as a list
428 if "," in label_subset:
429 if ".." in label_subset:
430 raise ValueError(
431 "Can only specify a list of labels or a rangewhen loading a Pipline not both"
432 )
433 args = {"labels": set(label_subset.split(","))}
434 # labels supplied as a range
435 elif ".." in label_subset:
436 # Try to de-structure the labelSubset, this will fail if more
437 # than one range is specified
438 begin, end, *rest = label_subset.split("..")
439 if rest:
440 raise ValueError("Only one range can be specified when loading a pipeline")
441 args = {"begin": begin if begin else None, "end": end if end else None}
442 # Assume anything else is a single label
443 else:
444 args = {"labels": {label_subset}}
446 # MyPy doesn't like how cavalier kwarg construction is with types.
447 specifier = LabelSpecifier(**args) # type: ignore
448 else:
449 specifier = None
451 return uri, specifier
453 @classmethod
454 def fromString(cls, pipeline_string: str) -> Pipeline:
455 """Create a pipeline from string formatted as a pipeline document.
457 Parameters
458 ----------
459 pipeline_string : `str`
460 A string formatted like a pipeline document.
462 Returns
463 -------
464 pipeline: `Pipeline`
465 """
466 pipeline = cls.fromIR(pipelineIR.PipelineIR.from_string(pipeline_string))
467 return pipeline
469 @classmethod
470 def fromIR(cls, deserialized_pipeline: pipelineIR.PipelineIR) -> Pipeline:
471 """Create a pipeline from an already created `PipelineIR` object.
473 Parameters
474 ----------
475 deserialized_pipeline: `PipelineIR`
476 An already created pipeline intermediate representation object
478 Returns
479 -------
480 pipeline: `Pipeline`
481 """
482 pipeline = cls.__new__(cls)
483 pipeline._pipelineIR = deserialized_pipeline
484 return pipeline
486 @classmethod
487 def fromPipeline(cls, pipeline: Pipeline) -> Pipeline:
488 """Create a new pipeline by copying an already existing `Pipeline`.
490 Parameters
491 ----------
492 pipeline: `Pipeline`
493 An already created `Pipeline` object to be copied.
495 Returns
496 -------
497 pipeline: `Pipeline`
498 """
499 return cls.fromIR(copy.deepcopy(pipeline._pipelineIR))
501 def __str__(self) -> str:
502 return str(self._pipelineIR)
504 def addInstrument(self, instrument: Union[Instrument, str]) -> None:
505 """Add an instrument to the pipeline, or replace an instrument that is
506 already defined.
508 Parameters
509 ----------
510 instrument : `~lsst.obs.base.Instrument` or `str`
511 Either an instance of an `lsst.obs.base.Instrument` subclass or
512 a string corresponding to a fully qualified
513 `lsst.obs.base.Instrument` name.
514 """
515 if isinstance(instrument, str):
516 pass
517 else:
518 # TODO: assume that this is a subclass of Instrument, no type
519 # checking
520 instrument = get_full_type_name(instrument)
521 self._pipelineIR.instrument = instrument
523 def getInstrument(self) -> Optional[str]:
524 """Get the instrument from the pipeline.
526 Returns
527 -------
528 instrument : `str`, or None
529 The fully qualified name of a `lsst.obs.base.Instrument` subclass,
530 or None if the pipeline does not have an instrument.
531 """
532 return self._pipelineIR.instrument
534 def addTask(self, task: Union[Type[PipelineTask], str], label: str) -> None:
535 """Add a new task to the pipeline, or replace a task that is already
536 associated with the supplied label.
538 Parameters
539 ----------
540 task: `PipelineTask` or `str`
541 Either a `PipelineTask` subclass (the class itself, not an instance)
542 or a string corresponding to a fully qualified `PipelineTask` name.
543 label: `str`
544 A label that is used to identify the `PipelineTask` being added
545 """
546 if isinstance(task, str):
547 taskName = task
548 elif issubclass(task, PipelineTask):
549 taskName = get_full_type_name(task)
550 else:
551 raise ValueError(
552 "task must be either a child class of PipelineTask or a string containing"
553 " a fully qualified name to one"
554 )
555 if not label:
556 # In some cases (e.g. with a command line-generated pipeline) tasks can
557 # be defined without a label, which is not acceptable; use the task's
558 # _DefaultName in that case.
559 if isinstance(task, str):
560 task_class = doImportType(task)
561 label = task_class._DefaultName
562 self._pipelineIR.tasks[label] = pipelineIR.TaskIR(label, taskName)
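# Illustrative usage (added for clarity; not part of the original module).
# A pipeline can be built programmatically; the task path and label below
# are hypothetical:
#
#     pipeline = Pipeline("An example pipeline")
#     pipeline.addTask("mypackage.tasks.MyTask", "myTask")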
564 def removeTask(self, label: str) -> None:
565 """Remove a task from the pipeline.
567 Parameters
568 ----------
569 label : `str`
570 The label used to identify the task that is to be removed
572 Raises
573 ------
574 KeyError
575 If no task with that label exists in the pipeline
577 """
578 self._pipelineIR.tasks.pop(label)
580 def addConfigOverride(self, label: str, key: str, value: object) -> None:
581 """Apply single config override.
583 Parameters
584 ----------
585 label : `str`
586 Label of the task.
587 key: `str`
588 Fully-qualified field name.
589 value : object
590 Value to be given to a field.
591 """
592 self._addConfigImpl(label, pipelineIR.ConfigIR(rest={key: value}))
594 def addConfigFile(self, label: str, filename: str) -> None:
595 """Add overrides from a specified file.
597 Parameters
598 ----------
599 label : `str`
600 The label used to identify the task associated with config to
601 modify
602 filename : `str`
603 Path to the override file.
604 """
605 self._addConfigImpl(label, pipelineIR.ConfigIR(file=[filename]))
607 def addConfigPython(self, label: str, pythonString: str) -> None:
608 """Add Overrides by running a snippet of python code against a config.
610 Parameters
611 ----------
612 label : `str`
613 The label used to identify the task associated with config to
614 modify.
615 pythonString: `str`
616 A string which is valid python code to be executed. This is done
617 with config as the only local accessible value.
618 """
619 self._addConfigImpl(label, pipelineIR.ConfigIR(python=pythonString))
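# Illustrative usage (added for clarity; not part of the original module).
# The three override methods accumulate ConfigIR blocks for a task; they are
# applied only when the pipeline is expanded. The label and field names below
# are hypothetical:
#
#     pipeline.addConfigOverride("myTask", "doSomething", True)
#     pipeline.addConfigFile("myTask", "myTaskOverrides.py")
#     pipeline.addConfigPython("myTask", "config.doSomething = True")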
621 def _addConfigImpl(self, label: str, newConfig: pipelineIR.ConfigIR) -> None:
622 if label == "parameters":
623 if newConfig.rest.keys() - self._pipelineIR.parameters.mapping.keys():
624 raise ValueError("Cannot override parameters that are not defined in pipeline")
625 self._pipelineIR.parameters.mapping.update(newConfig.rest)
626 if newConfig.file:
627 raise ValueError("Setting parameters section with config file is not supported")
628 if newConfig.python:
629 raise ValueError("Setting parameters section using python block in unsupported")
630 return
631 if label not in self._pipelineIR.tasks:
632 raise LookupError(f"There are no tasks labeled '{label}' in the pipeline")
633 self._pipelineIR.tasks[label].add_or_update_config(newConfig)
635 def toFile(self, filename: str) -> None:
636 self._pipelineIR.to_file(filename)
638 def write_to_uri(self, uri: ResourcePathExpression) -> None:
639 """Write the pipeline to a file or directory.
641 Parameters
642 ----------
643 uri : convertible to `ResourcePath`
644 URI to write to; may have any scheme with `ResourcePath` write
645 support or no scheme for a local file/directory. Should have a
646 ``.yaml`` extension.
647 """
648 self._pipelineIR.write_to_uri(uri)
650 def toExpandedPipeline(self) -> Generator[TaskDef, None, None]:
651 """Returns a generator of TaskDefs which can be used to create quantum
652 graphs.
654 Returns
655 -------
656 generator : generator of `TaskDef`
657 The generator returned will be the sorted iterator of tasks which
658 are to be used in constructing a quantum graph.
660 Raises
661 ------
662 NotImplementedError
663 Raised if a dataId is supplied in a config block. This is in place
664 for future use.
665 """
666 taskDefs = []
667 for label in self._pipelineIR.tasks:
668 taskDefs.append(self._buildTaskDef(label))
670 # Let's evaluate the contracts
671 if self._pipelineIR.contracts is not None:
672 label_to_config = {x.label: x.config for x in taskDefs}
673 for contract in self._pipelineIR.contracts:
674 # execute this in its own line so it can raise a good error
675 # message if there were problems with the eval
676 success = eval(contract.contract, None, label_to_config)
677 if not success:
678 extra_info = f": {contract.msg}" if contract.msg is not None else ""
679 raise pipelineIR.ContractError(
680 f"Contract(s) '{contract.contract}' were not satisfied{extra_info}"
681 )
683 taskDefs = sorted(taskDefs, key=lambda x: x.label)
684 yield from pipeTools.orderPipeline(taskDefs)
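# Illustrative usage (added for clarity; not part of the original module).
# Iterating a Pipeline expands it: TaskDefs are yielded in dependency-sorted
# order with all overrides applied and contracts checked:
#
#     for task_def in pipeline:
#         print(task_def.label, task_def.taskName)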
686 def _buildTaskDef(self, label: str) -> TaskDef:
687 if (taskIR := self._pipelineIR.tasks.get(label)) is None:
688 raise NameError(f"Label {label} does not appear in this pipeline")
689 taskClass: Type[PipelineTask] = doImportType(taskIR.klass)
690 taskName = get_full_type_name(taskClass)
691 config = taskClass.ConfigClass()
692 overrides = ConfigOverrides()
693 if self._pipelineIR.instrument is not None:
694 overrides.addInstrumentOverride(self._pipelineIR.instrument, taskClass._DefaultName)
695 if taskIR.config is not None:
696 for configIR in (configIr.formatted(self._pipelineIR.parameters) for configIr in taskIR.config):
697 if configIR.dataId is not None:
698 raise NotImplementedError(
699 "Specializing a config on a partial data id is not yet "
700 "supported in Pipeline definition"
701 )
702 # only apply override if it applies to everything
703 if configIR.dataId is None:
704 if configIR.file:
705 for configFile in configIR.file:
706 overrides.addFileOverride(os.path.expandvars(configFile))
707 if configIR.python is not None:
708 overrides.addPythonOverride(configIR.python)
709 for key, value in configIR.rest.items():
710 overrides.addValueOverride(key, value)
711 overrides.applyTo(config)
712 return TaskDef(taskName=taskName, config=config, taskClass=taskClass, label=label)
714 def __iter__(self) -> Generator[TaskDef, None, None]:
715 return self.toExpandedPipeline()
717 def __getitem__(self, item: str) -> TaskDef:
718 return self._buildTaskDef(item)
720 def __len__(self) -> int:
721 return len(self._pipelineIR.tasks)
723 def __eq__(self, other: object) -> bool:
724 if not isinstance(other, Pipeline):
725 return False
726 elif self._pipelineIR == other._pipelineIR:
727 # Shortcut: if the IR is the same, the expanded pipeline must be
728 # the same as well. But the converse is not true.
729 return True
730 else:
731 self_expanded = {td.label: (td.taskClass,) for td in self}
732 other_expanded = {td.label: (td.taskClass,) for td in other}
733 if self_expanded != other_expanded:
734 return False
735 # After DM-27847, we should compare configuration here, or better,
736 # delegated to TaskDef.__eq__ after making that compare configurations.
737 raise NotImplementedError(
738 "Pipelines cannot be compared because config instances cannot be compared; see DM-27847."
739 )
742@dataclass(frozen=True)
743class TaskDatasetTypes:
744 """An immutable struct that extracts and classifies the dataset types used
745 by a `PipelineTask`.
746 """
748 initInputs: NamedValueSet[DatasetType]
749 """Dataset types that are needed as inputs in order to construct this Task.
751 Task-level `initInputs` may be classified as either
752 `~PipelineDatasetTypes.initInputs` or
753 `~PipelineDatasetTypes.initIntermediates` at the Pipeline level.
754 """
756 initOutputs: NamedValueSet[DatasetType]
757 """Dataset types that may be written after constructing this Task.
759 Task-level `initOutputs` may be classified as either
760 `~PipelineDatasetTypes.initOutputs` or
761 `~PipelineDatasetTypes.initIntermediates` at the Pipeline level.
762 """
764 inputs: NamedValueSet[DatasetType]
765 """Dataset types that are regular inputs to this Task.
767 If an input dataset needed for a Quantum cannot be found in the input
768 collection(s) or produced by another Task in the Pipeline, that Quantum
769 (and all dependent Quanta) will not be produced.
771 Task-level `inputs` may be classified as either
772 `~PipelineDatasetTypes.inputs` or `~PipelineDatasetTypes.intermediates`
773 at the Pipeline level.
774 """
776 prerequisites: NamedValueSet[DatasetType]
777 """Dataset types that are prerequisite inputs to this Task.
779 Prerequisite inputs must exist in the input collection(s) before the
780 pipeline is run, but do not constrain the graph - if a prerequisite is
781 missing for a Quantum, `PrerequisiteMissingError` is raised.
783 Prerequisite inputs are not resolved until the second stage of
784 QuantumGraph generation.
785 """
787 outputs: NamedValueSet[DatasetType]
788 """Dataset types that are produced by this Task.
790 Task-level `outputs` may be classified as either
791 `~PipelineDatasetTypes.outputs` or `~PipelineDatasetTypes.intermediates`
792 at the Pipeline level.
793 """
795 @classmethod
796 def fromTaskDef(
797 cls,
798 taskDef: TaskDef,
799 *,
800 registry: Registry,
801 include_configs: bool = True,
802 storage_class_mapping: Optional[Mapping[str, str]] = None,
803 ) -> TaskDatasetTypes:
804 """Extract and classify the dataset types from a single `PipelineTask`.
806 Parameters
807 ----------
808 taskDef: `TaskDef`
809 An instance of a `TaskDef` class for a particular `PipelineTask`.
810 registry: `Registry`
811 Registry used to construct normalized `DatasetType` objects and
812 retrieve those that are incomplete.
813 include_configs : `bool`, optional
814 If `True` (default) include config dataset types as
815 ``initOutputs``.
816 storage_class_mapping : `Mapping` of `str` to `StorageClass`, optional
817 If a taskdef contains a component dataset type that is unknown
818 to the registry, its parent StorageClass will be looked up in this
819 mapping if it is supplied. If the mapping does not contain the
820 composite dataset type, or the mapping is not supplied, an exception
821 will be raised.
823 Returns
824 -------
825 types: `TaskDatasetTypes`
826 The dataset types used by this task.
828 Raises
829 ------
830 ValueError
831 Raised if dataset type connection definition differs from
832 registry definition.
833 LookupError
834 Raised if component parent StorageClass could not be determined
835 and storage_class_mapping does not contain the composite type, or
836 is set to None.
837 """
839 def makeDatasetTypesSet(
840 connectionType: str,
841 is_input: bool,
842 freeze: bool = True,
843 ) -> NamedValueSet[DatasetType]:
844 """Constructs a set of true `DatasetType` objects
846 Parameters
847 ----------
848 connectionType : `str`
849 Name of the connection type to produce a set for; corresponds
850 to an attribute of type `list` on the connection class instance.
851 is_input : `bool`
852 If `True`, these are input dataset types; otherwise they are
853 output dataset types.
854 freeze : `bool`, optional
855 If `True`, call `NamedValueSet.freeze` on the object returned.
857 Returns
858 -------
859 datasetTypes : `NamedValueSet`
860 A set of all datasetTypes which correspond to the input
861 connection type specified in the connection class of this
862 `PipelineTask`
864 Raises
865 ------
866 ValueError
867 Raised if dataset type connection definition differs from
868 registry definition.
869 LookupError
870 Raised if component parent StorageClass could not be determined
871 and storage_class_mapping does not contain the composite type,
872 or is set to None.
874 Notes
875 -----
876 This function is a closure over the variables ``registry``,
877 ``taskDef``, and ``storage_class_mapping``.
878 """
879 datasetTypes = NamedValueSet[DatasetType]()
880 for c in iterConnections(taskDef.connections, connectionType):
881 dimensions = set(getattr(c, "dimensions", set()))
882 if "skypix" in dimensions:
883 try:
884 datasetType = registry.getDatasetType(c.name)
885 except LookupError as err:
886 raise LookupError(
887 f"DatasetType '{c.name}' referenced by "
888 f"{type(taskDef.connections).__name__} uses 'skypix' as a dimension "
889 f"placeholder, but does not already exist in the registry. "
890 f"Note that reference catalog names are now used as the dataset "
891 f"type name instead of 'ref_cat'."
892 ) from err
893 rest1 = set(registry.dimensions.extract(dimensions - set(["skypix"])).names)
894 rest2 = set(
895 dim.name for dim in datasetType.dimensions if not isinstance(dim, SkyPixDimension)
896 )
897 if rest1 != rest2:
898 raise ValueError(
899 f"Non-skypix dimensions for dataset type {c.name} declared in "
900 f"connections ({rest1}) are inconsistent with those in "
901 f"registry's version of this dataset ({rest2})."
902 )
903 else:
904 # Component dataset types are not explicitly in the
905 # registry. This complicates consistency checks with
906 # registry and requires we work out the composite storage
907 # class.
908 registryDatasetType = None
909 try:
910 registryDatasetType = registry.getDatasetType(c.name)
911 except KeyError:
912 compositeName, componentName = DatasetType.splitDatasetTypeName(c.name)
913 if componentName:
914 if storage_class_mapping is None or compositeName not in storage_class_mapping:
915 raise LookupError(
916 "Component parent class cannot be determined, and "
917 "composite name was not in storage class mapping, or no "
918 "storage_class_mapping was supplied"
919 )
920 else:
921 parentStorageClass = storage_class_mapping[compositeName]
922 else:
923 parentStorageClass = None
924 datasetType = c.makeDatasetType(
925 registry.dimensions, parentStorageClass=parentStorageClass
926 )
927 registryDatasetType = datasetType
928 else:
929 datasetType = c.makeDatasetType(
930 registry.dimensions, parentStorageClass=registryDatasetType.parentStorageClass
931 )
933 if registryDatasetType and datasetType != registryDatasetType:
934 # The dataset types differ but first check to see if
935 # they are compatible before raising.
936 if is_input:
937 # This DatasetType must be compatible on get.
938 is_compatible = datasetType.is_compatible_with(registryDatasetType)
939 else:
940 # Has to be able to be converted to the expected type
941 # on put.
942 is_compatible = registryDatasetType.is_compatible_with(datasetType)
943 if is_compatible:
944 # For inputs we want the pipeline to use the
945 # pipeline definition, for outputs it should use
946 # the registry definition.
947 if not is_input:
948 datasetType = registryDatasetType
949 _LOG.debug(
950 "Dataset types differ (task %s != registry %s) but are compatible"
951 " for %s in %s.",
952 datasetType,
953 registryDatasetType,
954 "input" if is_input else "output",
955 taskDef.label,
956 )
957 else:
958 try:
959 # Explicitly check for storage class just to
960 # make more specific message.
961 _ = datasetType.storageClass
962 except KeyError:
963 raise ValueError(
964 "Storage class does not exist for supplied dataset type "
965 f"{datasetType} for {taskDef.label}."
966 ) from None
967 raise ValueError(
968 f"Supplied dataset type ({datasetType}) inconsistent with "
969 f"registry definition ({registryDatasetType}) "
970 f"for {taskDef.label}."
971 )
972 datasetTypes.add(datasetType)
973 if freeze:
974 datasetTypes.freeze()
975 return datasetTypes
977 # optionally add initOutput dataset for config
978 initOutputs = makeDatasetTypesSet("initOutputs", is_input=False, freeze=False)
979 if include_configs:
980 initOutputs.add(
981 DatasetType(
982 taskDef.configDatasetName,
983 registry.dimensions.empty,
984 storageClass="Config",
985 )
986 )
987 initOutputs.freeze()
989 # optionally add output dataset for metadata
990 outputs = makeDatasetTypesSet("outputs", is_input=False, freeze=False)
991 if taskDef.metadataDatasetName is not None:
992 # Metadata is supposed to be of the TaskMetadata type, its
993 # dimensions correspond to a task quantum.
994 dimensions = registry.dimensions.extract(taskDef.connections.dimensions)
996 # Allow the storage class definition to be read from the existing
997 # dataset type definition if present.
998 try:
999 current = registry.getDatasetType(taskDef.metadataDatasetName)
1000 except KeyError:
1001 # No previous definition so use the default.
1002 storageClass = "TaskMetadata" if _TASK_METADATA_TYPE is TaskMetadata else "PropertySet"
1003 else:
1004 storageClass = current.storageClass.name
1006 outputs.update({DatasetType(taskDef.metadataDatasetName, dimensions, storageClass)})
1007 if taskDef.logOutputDatasetName is not None:
1008 # Log output dimensions correspond to a task quantum.
1009 dimensions = registry.dimensions.extract(taskDef.connections.dimensions)
1010 outputs.update({DatasetType(taskDef.logOutputDatasetName, dimensions, "ButlerLogRecords")})
1012 outputs.freeze()
1014 return cls(
1015 initInputs=makeDatasetTypesSet("initInputs", is_input=True),
1016 initOutputs=initOutputs,
1017 inputs=makeDatasetTypesSet("inputs", is_input=True),
1018 prerequisites=makeDatasetTypesSet("prerequisiteInputs", is_input=True),
1019 outputs=outputs,
1020 )
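# Illustrative usage (added for clarity; not part of the original module),
# assuming ``butler`` is an `lsst.daf.butler.Butler` whose registry knows the
# relevant dataset types:
#
#     task_types = TaskDatasetTypes.fromTaskDef(task_def, registry=butler.registry)
#     task_types.inputs.names, task_types.outputs.names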
1023@dataclass(frozen=True)
1024class PipelineDatasetTypes:
1025 """An immutable struct that classifies the dataset types used in a
1026 `Pipeline`.
1027 """
1029 packagesDatasetName: ClassVar[str] = "packages"
1030 """Name of a dataset type used to save package versions.
1031 """
1033 initInputs: NamedValueSet[DatasetType]
1034 """Dataset types that are needed as inputs in order to construct the Tasks
1035 in this Pipeline.
1037 This does not include dataset types that are produced when constructing
1038 other Tasks in the Pipeline (these are classified as `initIntermediates`).
1039 """
1041 initOutputs: NamedValueSet[DatasetType]
1042 """Dataset types that may be written after constructing the Tasks in this
1043 Pipeline.
1045 This does not include dataset types that are also used as inputs when
1046 constructing other Tasks in the Pipeline (these are classified as
1047 `initIntermediates`).
1048 """
1050 initIntermediates: NamedValueSet[DatasetType]
1051 """Dataset types that are both used when constructing one or more Tasks
1052 in the Pipeline and produced as a side-effect of constructing another
1053 Task in the Pipeline.
1054 """
1056 inputs: NamedValueSet[DatasetType]
1057 """Dataset types that are regular inputs for the full pipeline.
1059 If an input dataset needed for a Quantum cannot be found in the input
1060 collection(s), that Quantum (and all dependent Quanta) will not be
1061 produced.
1062 """
1064 prerequisites: NamedValueSet[DatasetType]
1065 """Dataset types that are prerequisite inputs for the full Pipeline.
1067 Prerequisite inputs must exist in the input collection(s) before the
1068 pipeline is run, but do not constrain the graph - if a prerequisite is
1069 missing for a Quantum, `PrerequisiteMissingError` is raised.
1071 Prerequisite inputs are not resolved until the second stage of
1072 QuantumGraph generation.
1073 """
1075 intermediates: NamedValueSet[DatasetType]
1076 """Dataset types that are output by one Task in the Pipeline and consumed
1077 as inputs by one or more other Tasks in the Pipeline.
1078 """
1080 outputs: NamedValueSet[DatasetType]
1081 """Dataset types that are output by a Task in the Pipeline and not consumed
1082 by any other Task in the Pipeline.
1083 """
1085 byTask: Mapping[str, TaskDatasetTypes]
1086 """Per-Task dataset types, keyed by label in the `Pipeline`.
1088 This is guaranteed to be zip-iterable with the `Pipeline` itself (assuming
1089 neither has been modified since the dataset types were extracted, of
1090 course).
1091 """
1093 @classmethod
1094 def fromPipeline(
1095 cls,
1096 pipeline: Union[Pipeline, Iterable[TaskDef]],
1097 *,
1098 registry: Registry,
1099 include_configs: bool = True,
1100 include_packages: bool = True,
1101 ) -> PipelineDatasetTypes:
1102 """Extract and classify the dataset types from all tasks in a
1103 `Pipeline`.
1105 Parameters
1106 ----------
1107 pipeline: `Pipeline` or `Iterable` [ `TaskDef` ]
1108 A collection of tasks that can be run together.
1109 registry: `Registry`
1110 Registry used to construct normalized `DatasetType` objects and
1111 retrieve those that are incomplete.
1112 include_configs : `bool`, optional
1113 If `True` (default) include config dataset types as
1114 ``initOutputs``.
1115 include_packages : `bool`, optional
1116 If `True` (default) include the dataset type for software package
1117 versions in ``initOutputs``.
1119 Returns
1120 -------
1121 types: `PipelineDatasetTypes`
1122 The dataset types used by this `Pipeline`.
1124 Raises
1125 ------
1126 ValueError
1127 Raised if Tasks are inconsistent about which datasets are marked
1128 prerequisite. This indicates that the Tasks cannot be run as part
1129 of the same `Pipeline`.
1130 """
1131 allInputs = NamedValueSet[DatasetType]()
1132 allOutputs = NamedValueSet[DatasetType]()
1133 allInitInputs = NamedValueSet[DatasetType]()
1134 allInitOutputs = NamedValueSet[DatasetType]()
1135 prerequisites = NamedValueSet[DatasetType]()
1136 byTask = dict()
1137 if include_packages:
1138 allInitOutputs.add(
1139 DatasetType(
1140 cls.packagesDatasetName,
1141 registry.dimensions.empty,
1142 storageClass="Packages",
1143 )
1144 )
1145 # create a list of TaskDefs in case the input is a generator
1146 pipeline = list(pipeline)
1148 # collect all the output dataset types
1149 typeStorageclassMap: Dict[str, str] = {}
1150 for taskDef in pipeline:
1151 for outConnection in iterConnections(taskDef.connections, "outputs"):
1152 typeStorageclassMap[outConnection.name] = outConnection.storageClass
1154 for taskDef in pipeline:
1155 thisTask = TaskDatasetTypes.fromTaskDef(
1156 taskDef,
1157 registry=registry,
1158 include_configs=include_configs,
1159 storage_class_mapping=typeStorageclassMap,
1160 )
1161 allInitInputs.update(thisTask.initInputs)
1162 allInitOutputs.update(thisTask.initOutputs)
1163 allInputs.update(thisTask.inputs)
1164 prerequisites.update(thisTask.prerequisites)
1165 allOutputs.update(thisTask.outputs)
1166 byTask[taskDef.label] = thisTask
1167 if not prerequisites.isdisjoint(allInputs):
1168 raise ValueError(
1169 "{} marked as both prerequisites and regular inputs".format(
1170 {dt.name for dt in allInputs & prerequisites}
1171 )
1172 )
1173 if not prerequisites.isdisjoint(allOutputs):
1174 raise ValueError(
1175 "{} marked as both prerequisites and outputs".format(
1176 {dt.name for dt in allOutputs & prerequisites}
1177 )
1178 )
1179 # Make sure that components which are marked as inputs get treated as
1180 # intermediates if there is an output which produces the composite
1181 # containing the component
1182 intermediateComponents = NamedValueSet[DatasetType]()
1183 intermediateComposites = NamedValueSet[DatasetType]()
1184 outputNameMapping = {dsType.name: dsType for dsType in allOutputs}
1185 for dsType in allInputs:
1186 # get the name of a possible component
1187 name, component = dsType.nameAndComponent()
1188 # if there is a component name, that means this is a component
1189 # DatasetType, if there is an output which produces the parent of
1190 # this component, treat this input as an intermediate
1191 if component is not None:
1192 # This needs to be in this if block, because someone might have
1193 # a composite that is a pure input from existing data
1194 if name in outputNameMapping:
1195 intermediateComponents.add(dsType)
1196 intermediateComposites.add(outputNameMapping[name])
1198 def checkConsistency(a: NamedValueSet, b: NamedValueSet) -> None:
1199 common = a.names & b.names
1200 for name in common:
1201 # Any compatibility is allowed. This function does not know
1202 # if a dataset type is to be used for input or output.
1203 if not (a[name].is_compatible_with(b[name]) or b[name].is_compatible_with(a[name])):
1204 raise ValueError(f"Conflicting definitions for dataset type: {a[name]} != {b[name]}.")
1206 checkConsistency(allInitInputs, allInitOutputs)
1207 checkConsistency(allInputs, allOutputs)
1208 checkConsistency(allInputs, intermediateComposites)
1209 checkConsistency(allOutputs, intermediateComposites)
1211 def frozen(s: AbstractSet[DatasetType]) -> NamedValueSet[DatasetType]:
1212 assert isinstance(s, NamedValueSet)
1213 s.freeze()
1214 return s
1216 return cls(
1217 initInputs=frozen(allInitInputs - allInitOutputs),
1218 initIntermediates=frozen(allInitInputs & allInitOutputs),
1219 initOutputs=frozen(allInitOutputs - allInitInputs),
1220 inputs=frozen(allInputs - allOutputs - intermediateComponents),
1221 # If there are storage class differences in inputs and outputs
1222 # the intermediates have to choose priority. Here choose that
1223 # inputs to tasks must match the requested storage class by
1224 # applying the inputs over the top of the outputs.
1225 intermediates=frozen(allOutputs & allInputs | intermediateComponents),
1226 outputs=frozen(allOutputs - allInputs - intermediateComposites),
1227 prerequisites=frozen(prerequisites),
1228 byTask=MappingProxyType(byTask), # MappingProxyType -> frozen view of dict for immutability
1229 )
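# Illustrative usage (added for clarity; not part of the original module),
# assuming ``pipeline`` and ``butler`` as above; "myTask" is a hypothetical
# label:
#
#     ptypes = PipelineDatasetTypes.fromPipeline(pipeline, registry=butler.registry)
#     ptypes.inputs.names              # overall-input dataset types
#     ptypes.byTask["myTask"].outputs  # per-task classification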
1231 @classmethod
1232 def initOutputNames(
1233 cls,
1234 pipeline: Union[Pipeline, Iterable[TaskDef]],
1235 *,
1236 include_configs: bool = True,
1237 include_packages: bool = True,
1238 ) -> Iterator[str]:
1239 """Return the names of dataset types ot task initOutputs, Configs,
1240 and package versions for a pipeline.
1242 Parameters
1243 ----------
1244 pipeline: `Pipeline` or `Iterable` [ `TaskDef` ]
1245 A `Pipeline` instance or collection of `TaskDef` instances.
1246 include_configs : `bool`, optional
1247 If `True` (default) include config dataset types.
1248 include_packages : `bool`, optional
1249 If `True` (default) include the dataset type for package versions.
1251 Yields
1252 ------
1253 datasetTypeName : `str`
1254 Name of the dataset type.
1255 """
1256 if include_packages:
1257 # Package versions dataset type
1258 yield cls.packagesDatasetName
1260 if isinstance(pipeline, Pipeline):
1261 pipeline = pipeline.toExpandedPipeline()
1263 for taskDef in pipeline:
1265 # all task InitOutputs
1266 for name in taskDef.connections.initOutputs:
1267 attribute = getattr(taskDef.connections, name)
1268 yield attribute.name
1270 # config dataset name
1271 if include_configs:
1272 yield taskDef.configDatasetName
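# Illustrative usage (added for clarity; not part of the original module).
# initOutputNames does not need a registry, so it can cheaply list the
# init-output dataset type names for a pipeline:
#
#     names = set(PipelineDatasetTypes.initOutputNames(pipeline))
#     # includes "packages", each task's connection init-outputs, and the
#     # "<label>_config" dataset for each task (when include_configs=True)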