Coverage for python/lsst/pipe/base/pipeline.py: 20%
432 statements
coverage.py v6.5.0, created at 2023-10-26 15:47 +0000
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
23"""Module defining Pipeline class and related methods.
24"""
26__all__ = ["Pipeline", "TaskDef", "TaskDatasetTypes", "PipelineDatasetTypes", "LabelSpecifier"]
28import copy
29import logging
30import os
31import re
32import urllib.parse
33import warnings
35# -------------------------------
36# Imports of standard modules --
37# -------------------------------
38from dataclasses import dataclass
39from types import MappingProxyType
40from typing import (
41 TYPE_CHECKING,
42 AbstractSet,
43 Callable,
44 ClassVar,
45 Dict,
46 Generator,
47 Iterable,
48 Iterator,
49 Mapping,
50 Optional,
51 Set,
52 Tuple,
53 Type,
54 Union,
55 cast,
56)
58# -----------------------------
59# Imports for other modules --
60from lsst.daf.butler import DatasetType, NamedValueSet, Registry, SkyPixDimension
61from lsst.resources import ResourcePath, ResourcePathExpression
62from lsst.utils import doImportType
63from lsst.utils.introspection import get_full_type_name
65from . import pipelineIR, pipeTools
66from ._task_metadata import TaskMetadata
67from .config import PipelineTaskConfig
68from .configOverrides import ConfigOverrides
69from .connections import iterConnections
70from .connectionTypes import Input
71from .pipelineTask import PipelineTask
72from .task import _TASK_METADATA_TYPE
74if TYPE_CHECKING:  # Imports needed only for type annotations; may be circular.
75 from lsst.obs.base import Instrument
76 from lsst.pex.config import Config
78# ----------------------------------
79# Local non-exported definitions --
80# ----------------------------------
82_LOG = logging.getLogger(__name__)
84# ------------------------
85# Exported definitions --
86# ------------------------
89@dataclass
90class LabelSpecifier:
91 """A structure to specify a subset of labels to load
93 This structure may contain a set of labels to be used in subsetting a
94 pipeline, or a beginning and end point. Beginning or end may be empty,
95 in which case the range will be a half open interval. Unlike python
96 iteration bounds, end bounds are *INCLUDED*. Note that range based
97 selection is not well defined for pipelines that are not linear in nature,
98 and correct behavior is not guaranteed, or may vary from run to run.
99 """
101 labels: Optional[Set[str]] = None
102 begin: Optional[str] = None
103 end: Optional[str] = None
105 def __post_init__(self) -> None:
106 if self.labels is not None and (self.begin or self.end):
107 raise ValueError(
108 "This struct can only be initialized with a labels set or a begin (and/or) end specifier"
109 )
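# A minimal sketch of constructing a LabelSpecifier, following the field
# definitions above; the task labels are hypothetical:
#
#     spec = LabelSpecifier(labels={"isr", "characterizeImage"})
#     spec = LabelSpecifier(begin="isr", end="calibrate")  # inclusive range
#
# Supplying both a labels set and a begin/end bound raises ValueError.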
112class TaskDef:
113 """TaskDef is a collection of information about task needed by Pipeline.
115 The information includes task name, configuration object and optional
116 task class. This class is just a collection of attributes and it exposes
117 all of them so that attributes could potentially be modified in place
118 (e.g. if configuration needs extra overrides).
120 Attributes
121 ----------
122 taskName : `str`, optional
123 The fully-qualified `PipelineTask` class name. If not provided,
124 ``taskClass`` must be.
125 config : `lsst.pipe.base.config.PipelineTaskConfig`, optional
126 Instance of the configuration class corresponding to this task class,
127 usually with all overrides applied. This config will be frozen. If
128 not provided, ``taskClass`` must be provided and
129 ``taskClass.ConfigClass()`` will be used.
130 taskClass : `type`, optional
131 `PipelineTask` class object; if provided and ``taskName`` is as well,
132 the caller guarantees that they are consistent. If not provided,
133 ``taskName`` is used to import the type.
134 label : `str`, optional
135 Task label, usually a short string unique in a pipeline. If not
136 provided, ``taskClass`` must be, and ``taskClass._DefaultName`` will
137 be used.
138 """
140 def __init__(
141 self,
142 taskName: Optional[str] = None,
143 config: Optional[PipelineTaskConfig] = None,
144 taskClass: Optional[Type[PipelineTask]] = None,
145 label: Optional[str] = None,
146 ):
147 if taskName is None:
148 if taskClass is None:
149 raise ValueError("At least one of `taskName` and `taskClass` must be provided.")
150 taskName = get_full_type_name(taskClass)
151 elif taskClass is None:
152 taskClass = doImportType(taskName)
153 if config is None:
154 if taskClass is None:
155 raise ValueError("`taskClass` must be provided if `config` is not.")
156 config = taskClass.ConfigClass()
157 if label is None:
158 if taskClass is None:
159 raise ValueError("`taskClass` must be provided if `label` is not.")
160 label = taskClass._DefaultName
161 self.taskName = taskName
162 try:
163 config.validate()
164 except Exception:
165 _LOG.error("Configuration validation failed for task %s (%s)", label, taskName)
166 raise
167 config.freeze()
168 self.config = config
169 self.taskClass = taskClass
170 self.label = label
171 self.connections = config.connections.ConnectionsClass(config=config)
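# A sketch of constructing a TaskDef directly; the module path, class, and
# label below are hypothetical:
#
#     taskDef = TaskDef(taskName="mypackage.tasks.MyTask", label="myTask")
#     taskDef = TaskDef(taskClass=MyTask)  # label defaults to MyTask._DefaultName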
173 @property
174 def configDatasetName(self) -> str:
175 """Name of a dataset type for configuration of this task (`str`)"""
176 return self.label + "_config"
178 @property
179 def metadataDatasetName(self) -> Optional[str]:
180 """Name of a dataset type for metadata of this task, `None` if
181 metadata is not to be saved (`str`)
182 """
183 if self.config.saveMetadata:
184 return self.makeMetadataDatasetName(self.label)
185 else:
186 return None
188 @classmethod
189 def makeMetadataDatasetName(cls, label: str) -> str:
190 """Construct the name of the dataset type for metadata for a task.
192 Parameters
193 ----------
194 label : `str`
195 Label for the task within its pipeline.
197 Returns
198 -------
199 name : `str`
200 Name of the task's metadata dataset type.
201 """
202 return f"{label}_metadata"
204 @property
205 def logOutputDatasetName(self) -> Optional[str]:
206 """Name of a dataset type for log output from this task, `None` if
207 logs are not to be saved (`str`)
208 """
209 if cast(PipelineTaskConfig, self.config).saveLogOutput:
210 return self.label + "_log"
211 else:
212 return None
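# Sketch of the per-task dataset type naming convention implemented by the
# properties above, for a hypothetical label "isr":
#
#     TaskDef.makeMetadataDatasetName("isr")  # -> "isr_metadata"
#     taskDef.configDatasetName               # -> "isr_config"
#     taskDef.logOutputDatasetName            # -> "isr_log", or None if disabled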
214 def __str__(self) -> str:
215 rep = "TaskDef(" + self.taskName
216 if self.label:
217 rep += ", label=" + self.label
218 rep += ")"
219 return rep
221 def __eq__(self, other: object) -> bool:
222 if not isinstance(other, TaskDef):
223 return False
224 # This does not consider equality of configs when determining equality
225 # as config equality is a difficult thing to define. Should be updated
226 # after DM-27847
227 return self.taskClass == other.taskClass and self.label == other.label
229 def __hash__(self) -> int:
230 return hash((self.taskClass, self.label))
232 @classmethod
233 def _unreduce(cls, taskName: str, config: PipelineTaskConfig, label: str) -> TaskDef:
234 """Custom callable for unpickling.
236 All arguments are forwarded directly to the constructor; this
237 trampoline is only needed because ``__reduce__`` callables can't be
238 called with keyword arguments.
239 """
240 return cls(taskName=taskName, config=config, label=label)
242 def __reduce__(self) -> Tuple[Callable[[str, PipelineTaskConfig, str], TaskDef], Tuple[str, Config, str]]:
243 return (self._unreduce, (self.taskName, self.config, self.label))
246class Pipeline:
247 """A `Pipeline` is a representation of a series of tasks to run, and the
248 configuration for those tasks.
250 Parameters
251 ----------
252 description : `str`
252 A description of what this pipeline does.
254 """
256 def __init__(self, description: str):
257 pipeline_dict = {"description": description, "tasks": {}}
258 self._pipelineIR = pipelineIR.PipelineIR(pipeline_dict)
260 @classmethod
261 def fromFile(cls, filename: str) -> Pipeline:
262 """Load a pipeline defined in a pipeline yaml file.
264 Parameters
265 ----------
266 filename: `str`
267 A path that points to a pipeline defined in yaml format. This
268 filename may also supply additional labels to be used in
269 subsetting the loaded Pipeline. These labels are separated from
270 the path by a \\#, and may be specified as a comma separated
271 list, or a range denoted as beginning..end. Beginning or end may
272 be empty, in which case the range will be a half open interval.
273 Unlike python iteration bounds, end bounds are *INCLUDED*. Note
274 that range based selection is not well defined for pipelines that
275 are not linear in nature, and correct behavior is not guaranteed,
276 or may vary from run to run.
278 Returns
279 -------
280 pipeline: `Pipeline`
281 The pipeline loaded from the specified location with appropriate (if
282 any) subsetting.
284 Notes
285 -----
286 This method attempts to prune any contracts that contain labels which
287 are not in the declared subset of labels. This pruning is done using a
288 string based matching due to the nature of contracts and may prune more
289 than it should.
290 """
291 return cls.from_uri(filename)
293 @classmethod
294 def from_uri(cls, uri: ResourcePathExpression) -> Pipeline:
295 """Load a pipeline defined in a pipeline yaml file at a location
296 specified by a URI.
298 Parameters
299 ----------
300 uri: convertible to `ResourcePath`
301 If a string is supplied this should be a URI path that points to a
302 pipeline defined in yaml format, either as a direct path to the
303 yaml file, or as a directory containing a "pipeline.yaml" file (the
304 form used by `write_to_uri` with ``expand=True``). This uri may
305 also supply additional labels to be used in subsetting the loaded
306 Pipeline. These labels are separated from the path by a \\#, and
307 may be specified as a comma separated list, or a range denoted as
308 beginning..end. Beginning or end may be empty, in which case the
309 range will be a half open interval. Unlike python iteration bounds,
310 end bounds are *INCLUDED*. Note that range based selection is not
311 well defined for pipelines that are not linear in nature, and
312 correct behavior is not guaranteed, or may vary from run to run.
313 The same specifiers can be used with a `ResourcePath` object, by
314 being the sole contents of the fragment attribute.
316 Returns
317 -------
318 pipeline: `Pipeline`
319 The pipeline loaded from the specified location with appropriate (if
320 any) subsetting.
322 Notes
323 -----
324 This method attempts to prune any contracts that contain labels which
325 are not in the declared subset of labels. This pruning is done using a
326 string based matching due to the nature of contracts and may prune more
327 than it should.
328 """
329 # Split up the uri and any labels that were supplied
330 uri, label_specifier = cls._parse_file_specifier(uri)
331 pipeline: Pipeline = cls.fromIR(pipelineIR.PipelineIR.from_uri(uri))
333 # If there are labels supplied, only keep those
334 if label_specifier is not None:
335 pipeline = pipeline.subsetFromLabels(label_specifier)
336 return pipeline
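# A sketch of loading a pipeline with an optional label subset carried in the
# URI fragment, as described in the docstring above; the path and labels are
# hypothetical:
#
#     pipeline = Pipeline.from_uri("pipelines/DRP.yaml")
#     subset = Pipeline.from_uri("pipelines/DRP.yaml#isr,characterizeImage")
#     ranged = Pipeline.from_uri("pipelines/DRP.yaml#isr..calibrate")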
338 def subsetFromLabels(self, labelSpecifier: LabelSpecifier) -> Pipeline:
339 """Subset a pipeline to contain only labels specified in labelSpecifier
341 Parameters
342 ----------
343 labelSpecifier : `LabelSpecifier`
344 Object containing labels that describes how to subset a pipeline.
346 Returns
347 -------
348 pipeline : `Pipeline`
349 A new pipeline object that is a subset of the old pipeline
351 Raises
352 ------
353 ValueError
354 Raised if there is an issue with specified labels
356 Notes
357 -----
358 This method attempts to prune any contracts that contain labels which
359 are not in the declared subset of labels. This pruning is done using a
360 string based matching due to the nature of contracts and may prune more
361 than it should.
362 """
363 # Labels supplied as a set
364 if labelSpecifier.labels:
365 labelSet = labelSpecifier.labels
366 # Labels supplied as a range, first create a list of all the labels
367 # in the pipeline sorted according to task dependency. Then only
368 # keep labels that lie between the supplied bounds
369 else:
370 # Create a copy of the pipeline to use when assessing the label
371 # ordering. Use a dict for fast searching while preserving order.
372 # Remove contracts so they do not fail in the expansion step. This
373 # is needed because a user may only configure the tasks they intend
374 # to run, which may cause some contracts to fail if they will later
375 # be dropped
376 pipeline = copy.deepcopy(self)
377 pipeline._pipelineIR.contracts = []
378 labels = {taskdef.label: True for taskdef in pipeline.toExpandedPipeline()}
380 # Verify the bounds are in the labels
381 if labelSpecifier.begin is not None:
382 if labelSpecifier.begin not in labels:
383 raise ValueError(
384 f"Beginning of range subset, {labelSpecifier.begin}, not found in "
385 "pipeline definition"
386 )
387 if labelSpecifier.end is not None:
388 if labelSpecifier.end not in labels:
389 raise ValueError(
390 f"End of range subset, {labelSpecifier.end}, not found in pipeline definition"
391 )
393 labelSet = set()
394 for label in labels:
395 if labelSpecifier.begin is not None:
396 if label != labelSpecifier.begin:
397 continue
398 else:
399 labelSpecifier.begin = None
400 labelSet.add(label)
401 if labelSpecifier.end is not None and label == labelSpecifier.end:
402 break
403 return Pipeline.fromIR(self._pipelineIR.subset_from_labels(labelSet))
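# Sketch of subsetting an already-loaded pipeline; the label names are
# hypothetical:
#
#     subset = pipeline.subsetFromLabels(LabelSpecifier(begin="isr", end="calibrate"))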
405 @staticmethod
406 def _parse_file_specifier(uri: ResourcePathExpression) -> Tuple[ResourcePath, Optional[LabelSpecifier]]:
407 """Split appart a uri and any possible label subsets"""
408 if isinstance(uri, str):
409 # This is to support legacy pipelines during transition
410 uri, num_replace = re.subn("[:](?!\\/\\/)", "#", uri)
411 if num_replace:
412 warnings.warn(
413 f"The pipeline file {uri} seems to use the legacy : to separate "
414 "labels, this is deprecated and will be removed after June 2021, please use "
415 "# instead.",
416 category=FutureWarning,
417 )
418 if uri.count("#") > 1:
419 raise ValueError("Only one set of labels is allowed when specifying a pipeline to load")
420 # Everything else can be converted directly to ResourcePath.
421 uri = ResourcePath(uri)
422 label_subset = uri.fragment or None
424 specifier: Optional[LabelSpecifier]
425 if label_subset is not None:
426 label_subset = urllib.parse.unquote(label_subset)
427 args: Dict[str, Union[Set[str], str, None]]
428 # labels supplied as a list
429 if "," in label_subset:
430 if ".." in label_subset:
431 raise ValueError(
432 "Can only specify a list of labels or a rangewhen loading a Pipline not both"
433 )
434 args = {"labels": set(label_subset.split(","))}
435 # labels supplied as a range
436 elif ".." in label_subset:
437 # Try to de-structure the labelSubset, this will fail if more
438 # than one range is specified
439 begin, end, *rest = label_subset.split("..")
440 if rest:
441 raise ValueError("Only one range can be specified when loading a pipeline")
442 args = {"begin": begin if begin else None, "end": end if end else None}
443 # Assume anything else is a single label
444 else:
445 args = {"labels": {label_subset}}
447 # MyPy doesn't like how cavalier kwarg construction is with types.
448 specifier = LabelSpecifier(**args) # type: ignore
449 else:
450 specifier = None
452 return uri, specifier
454 @classmethod
455 def fromString(cls, pipeline_string: str) -> Pipeline:
456 """Create a pipeline from string formatted as a pipeline document.
458 Parameters
459 ----------
460 pipeline_string : `str`
460 A string that is formatted like a pipeline document.
463 Returns
464 -------
465 pipeline: `Pipeline`
466 """
467 pipeline = cls.fromIR(pipelineIR.PipelineIR.from_string(pipeline_string))
468 return pipeline
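# Sketch of a pipeline document passed to fromString, assuming the usual YAML
# layout with a "class" entry per task label (the label and task names are
# hypothetical):
#
#     pipeline = Pipeline.fromString(
#         "description: A tiny example\n"
#         "tasks:\n"
#         "  isr:\n"
#         "    class: lsst.ip.isr.IsrTask\n"
#     )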
470 @classmethod
471 def fromIR(cls, deserialized_pipeline: pipelineIR.PipelineIR) -> Pipeline:
472 """Create a pipeline from an already created `PipelineIR` object.
474 Parameters
475 ----------
476 deserialized_pipeline: `PipelineIR`
477 An already created pipeline intermediate representation object
479 Returns
480 -------
481 pipeline: `Pipeline`
482 """
483 pipeline = cls.__new__(cls)
484 pipeline._pipelineIR = deserialized_pipeline
485 return pipeline
487 @classmethod
488 def fromPipeline(cls, pipeline: Pipeline) -> Pipeline:
489 """Create a new pipeline by copying an already existing `Pipeline`.
491 Parameters
492 ----------
493 pipeline: `Pipeline`
494 An already created pipeline intermediate representation object
496 Returns
497 -------
498 pipeline: `Pipeline`
499 """
500 return cls.fromIR(copy.deepcopy(pipeline._pipelineIR))
502 def __str__(self) -> str:
503 return str(self._pipelineIR)
505 def addInstrument(self, instrument: Union[Instrument, str]) -> None:
506 """Add an instrument to the pipeline, or replace an instrument that is
507 already defined.
509 Parameters
510 ----------
511 instrument : `~lsst.obs.base.Instrument` or `str`
512 Either a derived class object of `lsst.obs.base.Instrument` or
513 a string corresponding to a fully qualified
514 `lsst.obs.base.Instrument` name.
515 """
516 if isinstance(instrument, str):
517 pass
518 else:
519 # TODO: assume that this is a subclass of Instrument, no type
520 # checking
521 instrument = get_full_type_name(instrument)
522 self._pipelineIR.instrument = instrument
524 def getInstrument(self) -> Optional[str]:
525 """Get the instrument from the pipeline.
527 Returns
528 -------
529 instrument : `str`, or `None`
530 The fully qualified name of a `lsst.obs.base.Instrument` subclass,
531 or `None` if the pipeline does not have an instrument.
532 """
533 return self._pipelineIR.instrument
535 def addTask(self, task: Union[Type[PipelineTask], str], label: str) -> None:
536 """Add a new task to the pipeline, or replace a task that is already
537 associated with the supplied label.
539 Parameters
540 ----------
541 task: `PipelineTask` or `str`
542 Either a derived class object of a `PipelineTask` or a string
543 corresponding to a fully qualified `PipelineTask` name.
544 label: `str`
545 A label that is used to identify the `PipelineTask` being added
546 """
547 if isinstance(task, str):
548 taskName = task
549 elif issubclass(task, PipelineTask):
550 taskName = get_full_type_name(task)
551 else:
552 raise ValueError(
553 "task must be either a child class of PipelineTask or a string containing"
554 " a fully qualified name to one"
555 )
556 if not label:
557 # in some cases (with a command line-generated pipeline) tasks can
558 # be defined without a label, which is not acceptable; use the task's
559 # _DefaultName in that case
560 if isinstance(task, str):
561 task_class = doImportType(task)
562 label = task_class._DefaultName
563 self._pipelineIR.tasks[label] = pipelineIR.TaskIR(label, taskName)
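# Sketch of building a pipeline programmatically with the methods above; the
# instrument and task names are hypothetical:
#
#     pipeline = Pipeline("A tiny example")
#     pipeline.addInstrument("lsst.obs.subaru.HyperSuprimeCam")
#     pipeline.addTask("lsst.ip.isr.IsrTask", "isr")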
565 def removeTask(self, label: str) -> None:
566 """Remove a task from the pipeline.
568 Parameters
569 ----------
570 label : `str`
571 The label used to identify the task that is to be removed
573 Raises
574 ------
575 KeyError
576 If no task with that label exists in the pipeline
578 """
579 self._pipelineIR.tasks.pop(label)
581 def addConfigOverride(self, label: str, key: str, value: object) -> None:
582 """Apply single config override.
584 Parameters
585 ----------
586 label : `str`
587 Label of the task.
588 key: `str`
589 Fully-qualified field name.
590 value : object
591 Value to be given to a field.
592 """
593 self._addConfigImpl(label, pipelineIR.ConfigIR(rest={key: value}))
595 def addConfigFile(self, label: str, filename: str) -> None:
596 """Add overrides from a specified file.
598 Parameters
599 ----------
600 label : `str`
601 The label used to identify the task associated with config to
602 modify
603 filename : `str`
604 Path to the override file.
605 """
606 self._addConfigImpl(label, pipelineIR.ConfigIR(file=[filename]))
608 def addConfigPython(self, label: str, pythonString: str) -> None:
609 """Add Overrides by running a snippet of python code against a config.
611 Parameters
612 ----------
613 label : `str`
614 The label used to identify the task associated with config to
615 modify.
616 pythonString: `str`
617 A string which is valid python code to be executed. This is done
618 with config as the only local accessible value.
619 """
620 self._addConfigImpl(label, pipelineIR.ConfigIR(python=pythonString))
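# Sketch of the three override flavours applied to a hypothetical task labeled
# "isr"; the config field and file names are hypothetical:
#
#     pipeline.addConfigOverride("isr", "doWrite", False)
#     pipeline.addConfigFile("isr", "isrOverrides.py")
#     pipeline.addConfigPython("isr", "config.doWrite = False")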
622 def _addConfigImpl(self, label: str, newConfig: pipelineIR.ConfigIR) -> None:
623 if label == "parameters":
624 if newConfig.rest.keys() - self._pipelineIR.parameters.mapping.keys():
625 raise ValueError("Cannot override parameters that are not defined in pipeline")
626 self._pipelineIR.parameters.mapping.update(newConfig.rest)
627 if newConfig.file:
628 raise ValueError("Setting parameters section with config file is not supported")
629 if newConfig.python:
630 raise ValueError("Setting parameters section using python block in unsupported")
631 return
632 if label not in self._pipelineIR.tasks:
633 raise LookupError(f"There are no tasks labeled '{label}' in the pipeline")
634 self._pipelineIR.tasks[label].add_or_update_config(newConfig)
636 def toFile(self, filename: str) -> None:
637 self._pipelineIR.to_file(filename)
639 def write_to_uri(self, uri: ResourcePathExpression) -> None:
640 """Write the pipeline to a file or directory.
642 Parameters
643 ----------
644 uri : convertible to `ResourcePath`
645 URI to write to; may have any scheme with `ResourcePath` write
646 support or no scheme for a local file/directory. Should have a
647 ``.yaml`` extension.
648 """
649 self._pipelineIR.write_to_uri(uri)
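# Sketch of persisting a pipeline definition; the destination path is
# hypothetical:
#
#     pipeline.write_to_uri("exported_pipeline.yaml")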
651 def toExpandedPipeline(self) -> Generator[TaskDef, None, None]:
652 """Returns a generator of TaskDefs which can be used to create quantum
653 graphs.
655 Returns
656 -------
657 generator : generator of `TaskDef`
658 The generator returned will be the sorted iterator of tasks which
659 are to be used in constructing a quantum graph.
661 Raises
662 ------
663 NotImplementedError
664 If a dataId is supplied in a config block. This is in place for
665 future use
666 """
667 taskDefs = []
668 for label in self._pipelineIR.tasks:
669 taskDefs.append(self._buildTaskDef(label))
671 # let's evaluate the contracts
672 if self._pipelineIR.contracts is not None:
673 label_to_config = {x.label: x.config for x in taskDefs}
674 for contract in self._pipelineIR.contracts:
675 # execute this in its own line so it can raise a good error
676 # message if there were problems with the eval
677 success = eval(contract.contract, None, label_to_config)
678 if not success:
679 extra_info = f": {contract.msg}" if contract.msg is not None else ""
680 raise pipelineIR.ContractError(
681 f"Contract(s) '{contract.contract}' were not satisfied{extra_info}"
682 )
684 taskDefs = sorted(taskDefs, key=lambda x: x.label)
685 yield from pipeTools.orderPipeline(taskDefs)
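# Sketch of consuming the expanded pipeline; iteration delegates to
# toExpandedPipeline, and the attributes come from TaskDef above:
#
#     for taskDef in pipeline:
#         print(taskDef.label, taskDef.taskName)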
687 def _buildTaskDef(self, label: str) -> TaskDef:
688 if (taskIR := self._pipelineIR.tasks.get(label)) is None:
689 raise NameError(f"Label {label} does not appear in this pipeline")
690 taskClass: Type[PipelineTask] = doImportType(taskIR.klass)
691 taskName = get_full_type_name(taskClass)
692 config = taskClass.ConfigClass()
693 overrides = ConfigOverrides()
694 if self._pipelineIR.instrument is not None:
695 overrides.addInstrumentOverride(self._pipelineIR.instrument, taskClass._DefaultName)
696 if taskIR.config is not None:
697 for configIR in (configIr.formatted(self._pipelineIR.parameters) for configIr in taskIR.config):
698 if configIR.dataId is not None:
699 raise NotImplementedError(
700 "Specializing a config on a partial data id is not yet "
701 "supported in Pipeline definition"
702 )
703 # only apply override if it applies to everything
704 if configIR.dataId is None:
705 if configIR.file:
706 for configFile in configIR.file:
707 overrides.addFileOverride(os.path.expandvars(configFile))
708 if configIR.python is not None:
709 overrides.addPythonOverride(configIR.python)
710 for key, value in configIR.rest.items():
711 overrides.addValueOverride(key, value)
712 overrides.applyTo(config)
713 return TaskDef(taskName=taskName, config=config, taskClass=taskClass, label=label)
715 def __iter__(self) -> Generator[TaskDef, None, None]:
716 return self.toExpandedPipeline()
718 def __getitem__(self, item: str) -> TaskDef:
719 return self._buildTaskDef(item)
721 def __len__(self) -> int:
722 return len(self._pipelineIR.tasks)
724 def __eq__(self, other: object) -> bool:
725 if not isinstance(other, Pipeline):
726 return False
727 elif self._pipelineIR == other._pipelineIR:
728 # Shortcut: if the IR is the same, the expanded pipeline must be
729 # the same as well. But the converse is not true.
730 return True
731 else:
732 self_expanded = {td.label: (td.taskClass,) for td in self}
733 other_expanded = {td.label: (td.taskClass,) for td in other}
734 if self_expanded != other_expanded:
735 return False
736 # After DM-27847, we should compare configuration here, or better,
737 # delegated to TaskDef.__eq__ after making that compare configurations.
738 raise NotImplementedError(
739 "Pipelines cannot be compared because config instances cannot be compared; see DM-27847."
740 )
743@dataclass(frozen=True)
744class TaskDatasetTypes:
745 """An immutable struct that extracts and classifies the dataset types used
746 by a `PipelineTask`
747 """
749 initInputs: NamedValueSet[DatasetType]
750 """Dataset types that are needed as inputs in order to construct this Task.
752 Task-level `initInputs` may be classified as either
753 `~PipelineDatasetTypes.initInputs` or
754 `~PipelineDatasetTypes.initIntermediates` at the Pipeline level.
755 """
757 initOutputs: NamedValueSet[DatasetType]
758 """Dataset types that may be written after constructing this Task.
760 Task-level `initOutputs` may be classified as either
761 `~PipelineDatasetTypes.initOutputs` or
762 `~PipelineDatasetTypes.initIntermediates` at the Pipeline level.
763 """
765 inputs: NamedValueSet[DatasetType]
766 """Dataset types that are regular inputs to this Task.
768 If an input dataset needed for a Quantum cannot be found in the input
769 collection(s) or produced by another Task in the Pipeline, that Quantum
770 (and all dependent Quanta) will not be produced.
772 Task-level `inputs` may be classified as either
773 `~PipelineDatasetTypes.inputs` or `~PipelineDatasetTypes.intermediates`
774 at the Pipeline level.
775 """
777 queryConstraints: NamedValueSet[DatasetType]
778 """Regular inputs that should not be used as constraints on the initial
779 QuantumGraph generation data ID query, according to their tasks
780 (`NamedValueSet`).
781 """
783 prerequisites: NamedValueSet[DatasetType]
784 """Dataset types that are prerequisite inputs to this Task.
786 Prerequisite inputs must exist in the input collection(s) before the
787 pipeline is run, but do not constrain the graph - if a prerequisite is
788 missing for a Quantum, `PrerequisiteMissingError` is raised.
790 Prerequisite inputs are not resolved until the second stage of
791 QuantumGraph generation.
792 """
794 outputs: NamedValueSet[DatasetType]
795 """Dataset types that are produced by this Task.
797 Task-level `outputs` may be classified as either
798 `~PipelineDatasetTypes.outputs` or `~PipelineDatasetTypes.intermediates`
799 at the Pipeline level.
800 """
802 @classmethod
803 def fromTaskDef(
804 cls,
805 taskDef: TaskDef,
806 *,
807 registry: Registry,
808 include_configs: bool = True,
809 storage_class_mapping: Optional[Mapping[str, str]] = None,
810 ) -> TaskDatasetTypes:
811 """Extract and classify the dataset types from a single `PipelineTask`.
813 Parameters
814 ----------
815 taskDef: `TaskDef`
816 An instance of a `TaskDef` class for a particular `PipelineTask`.
817 registry: `Registry`
818 Registry used to construct normalized `DatasetType` objects and
819 retrieve those that are incomplete.
820 include_configs : `bool`, optional
821 If `True` (default) include config dataset types as
822 ``initOutputs``.
823 storage_class_mapping : `Mapping` of `str` to `StorageClass`, optional
824 If a taskdef contains a component dataset type that is unknown
825 to the registry, its parent StorageClass will be looked up in this
826 mapping if it is supplied. If the mapping does not contain the
827 composite dataset type, or the mapping is not supplied an exception
828 will be raised.
830 Returns
831 -------
832 types: `TaskDatasetTypes`
833 The dataset types used by this task.
835 Raises
836 ------
837 ValueError
838 Raised if dataset type connection definition differs from
839 registry definition.
840 LookupError
841 Raised if component parent StorageClass could not be determined
842 and storage_class_mapping does not contain the composite type, or
843 is set to None.
844 """
846 def makeDatasetTypesSet(
847 connectionType: str,
848 is_input: bool,
849 freeze: bool = True,
850 ) -> NamedValueSet[DatasetType]:
851 """Constructs a set of true `DatasetType` objects
853 Parameters
854 ----------
855 connectionType : `str`
856 Name of the connection type to produce a set for, corresponds
857 to an attribute of type `list` on the connection class instance
858 is_input : `bool`
859 If `True`, these are input dataset types, else they are output dataset
860 types.
861 freeze : `bool`, optional
862 If `True`, call `NamedValueSet.freeze` on the object returned.
864 Returns
865 -------
866 datasetTypes : `NamedValueSet`
867 A set of all datasetTypes which correspond to the input
868 connection type specified in the connection class of this
869 `PipelineTask`
871 Raises
872 ------
873 ValueError
874 Raised if dataset type connection definition differs from
875 registry definition.
876 LookupError
877 Raised if component parent StorageClass could not be determined
878 and storage_class_mapping does not contain the composite type,
879 or is set to None.
881 Notes
882 -----
883 This function is a closure over the variables ``registry``,
884 ``taskDef``, and ``storage_class_mapping``.
885 """
886 datasetTypes = NamedValueSet[DatasetType]()
887 for c in iterConnections(taskDef.connections, connectionType):
888 dimensions = set(getattr(c, "dimensions", set()))
889 if "skypix" in dimensions:
890 try:
891 datasetType = registry.getDatasetType(c.name)
892 except LookupError as err:
893 raise LookupError(
894 f"DatasetType '{c.name}' referenced by "
895 f"{type(taskDef.connections).__name__} uses 'skypix' as a dimension "
896 f"placeholder, but does not already exist in the registry. "
897 f"Note that reference catalog names are now used as the dataset "
898 f"type name instead of 'ref_cat'."
899 ) from err
900 rest1 = set(registry.dimensions.extract(dimensions - set(["skypix"])).names)
901 rest2 = set(
902 dim.name for dim in datasetType.dimensions if not isinstance(dim, SkyPixDimension)
903 )
904 if rest1 != rest2:
905 raise ValueError(
906 f"Non-skypix dimensions for dataset type {c.name} declared in "
907 f"connections ({rest1}) are inconsistent with those in "
908 f"registry's version of this dataset ({rest2})."
909 )
910 else:
911 # Component dataset types are not explicitly in the
912 # registry. This complicates consistency checks with
913 # registry and requires we work out the composite storage
914 # class.
915 registryDatasetType = None
916 try:
917 registryDatasetType = registry.getDatasetType(c.name)
918 except KeyError:
919 compositeName, componentName = DatasetType.splitDatasetTypeName(c.name)
920 if componentName:
921 if storage_class_mapping is None or compositeName not in storage_class_mapping:
922 raise LookupError(
923 "Component parent class cannot be determined, and "
924 "composite name was not in storage class mapping, or no "
925 "storage_class_mapping was supplied"
926 )
927 else:
928 parentStorageClass = storage_class_mapping[compositeName]
929 else:
930 parentStorageClass = None
931 datasetType = c.makeDatasetType(
932 registry.dimensions, parentStorageClass=parentStorageClass
933 )
934 registryDatasetType = datasetType
935 else:
936 datasetType = c.makeDatasetType(
937 registry.dimensions, parentStorageClass=registryDatasetType.parentStorageClass
938 )
940 if registryDatasetType and datasetType != registryDatasetType:
941 # The dataset types differ but first check to see if
942 # they are compatible before raising.
943 if is_input:
944 # This DatasetType must be compatible on get.
945 is_compatible = datasetType.is_compatible_with(registryDatasetType)
946 else:
947 # Has to be able to be converted to the expected type
948 # on put.
949 is_compatible = registryDatasetType.is_compatible_with(datasetType)
950 if is_compatible:
951 # For inputs we want the pipeline to use the
952 # pipeline definition, for outputs it should use
953 # the registry definition.
954 if not is_input:
955 datasetType = registryDatasetType
956 _LOG.debug(
957 "Dataset types differ (task %s != registry %s) but are compatible"
958 " for %s in %s.",
959 datasetType,
960 registryDatasetType,
961 "input" if is_input else "output",
962 taskDef.label,
963 )
964 else:
965 try:
966 # Explicitly check for storage class just to
967 # make more specific message.
968 _ = datasetType.storageClass
969 except KeyError:
970 raise ValueError(
971 "Storage class does not exist for supplied dataset type "
972 f"{datasetType} for {taskDef.label}."
973 ) from None
974 raise ValueError(
975 f"Supplied dataset type ({datasetType}) inconsistent with "
976 f"registry definition ({registryDatasetType}) "
977 f"for {taskDef.label}."
978 )
979 datasetTypes.add(datasetType)
980 if freeze:
981 datasetTypes.freeze()
982 return datasetTypes
984 # optionally add initOutput dataset for config
985 initOutputs = makeDatasetTypesSet("initOutputs", is_input=False, freeze=False)
986 if include_configs:
987 initOutputs.add(
988 DatasetType(
989 taskDef.configDatasetName,
990 registry.dimensions.empty,
991 storageClass="Config",
992 )
993 )
994 initOutputs.freeze()
996 # optionally add output dataset for metadata
997 outputs = makeDatasetTypesSet("outputs", is_input=False, freeze=False)
998 if taskDef.metadataDatasetName is not None:
999 # Metadata is supposed to be of the TaskMetadata type, its
1000 # dimensions correspond to a task quantum.
1001 dimensions = registry.dimensions.extract(taskDef.connections.dimensions)
1003 # Allow the storage class definition to be read from the existing
1004 # dataset type definition if present.
1005 try:
1006 current = registry.getDatasetType(taskDef.metadataDatasetName)
1007 except KeyError:
1008 # No previous definition so use the default.
1009 storageClass = "TaskMetadata" if _TASK_METADATA_TYPE is TaskMetadata else "PropertySet"
1010 else:
1011 storageClass = current.storageClass.name
1013 outputs.update({DatasetType(taskDef.metadataDatasetName, dimensions, storageClass)})
1014 if taskDef.logOutputDatasetName is not None:
1015 # Log output dimensions correspond to a task quantum.
1016 dimensions = registry.dimensions.extract(taskDef.connections.dimensions)
1017 outputs.update({DatasetType(taskDef.logOutputDatasetName, dimensions, "ButlerLogRecords")})
1019 outputs.freeze()
1021 inputs = makeDatasetTypesSet("inputs", is_input=True)
1022 queryConstraints = NamedValueSet(
1023 inputs[c.name]
1024 for c in cast(Iterable[Input], iterConnections(taskDef.connections, "inputs"))
1025 if not c.deferGraphConstraint
1026 )
1028 return cls(
1029 initInputs=makeDatasetTypesSet("initInputs", is_input=True),
1030 initOutputs=initOutputs,
1031 inputs=inputs,
1032 queryConstraints=queryConstraints,
1033 prerequisites=makeDatasetTypesSet("prerequisiteInputs", is_input=True),
1034 outputs=outputs,
1035 )
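# Sketch of classifying one task's dataset types, assuming ``registry`` is an
# `lsst.daf.butler.Registry` (e.g. ``butler.registry``) and ``taskDef`` comes
# from an expanded pipeline:
#
#     types = TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
#     print(types.inputs.names, types.outputs.names)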
1038@dataclass(frozen=True)
1039class PipelineDatasetTypes:
1040 """An immutable struct that classifies the dataset types used in a
1041 `Pipeline`.
1042 """
1044 packagesDatasetName: ClassVar[str] = "packages"
1045 """Name of a dataset type used to save package versions.
1046 """
1048 initInputs: NamedValueSet[DatasetType]
1049 """Dataset types that are needed as inputs in order to construct the Tasks
1050 in this Pipeline.
1052 This does not include dataset types that are produced when constructing
1053 other Tasks in the Pipeline (these are classified as `initIntermediates`).
1054 """
1056 initOutputs: NamedValueSet[DatasetType]
1057 """Dataset types that may be written after constructing the Tasks in this
1058 Pipeline.
1060 This does not include dataset types that are also used as inputs when
1061 constructing other Tasks in the Pipeline (these are classified as
1062 `initIntermediates`).
1063 """
1065 initIntermediates: NamedValueSet[DatasetType]
1066 """Dataset types that are both used when constructing one or more Tasks
1067 in the Pipeline and produced as a side-effect of constructing another
1068 Task in the Pipeline.
1069 """
1071 inputs: NamedValueSet[DatasetType]
1072 """Dataset types that are regular inputs for the full pipeline.
1074 If an input dataset needed for a Quantum cannot be found in the input
1075 collection(s), that Quantum (and all dependent Quanta) will not be
1076 produced.
1077 """
1079 queryConstraints: NamedValueSet[DatasetType]
1080 """Regular inputs that should be used as constraints on the initial
1081 QuantumGraph generation data ID query, according to their tasks
1082 (`NamedValueSet`).
1083 """
1085 prerequisites: NamedValueSet[DatasetType]
1086 """Dataset types that are prerequisite inputs for the full Pipeline.
1088 Prerequisite inputs must exist in the input collection(s) before the
1089 pipeline is run, but do not constrain the graph - if a prerequisite is
1090 missing for a Quantum, `PrerequisiteMissingError` is raised.
1092 Prerequisite inputs are not resolved until the second stage of
1093 QuantumGraph generation.
1094 """
1096 intermediates: NamedValueSet[DatasetType]
1097 """Dataset types that are output by one Task in the Pipeline and consumed
1098 as inputs by one or more other Tasks in the Pipeline.
1099 """
1101 outputs: NamedValueSet[DatasetType]
1102 """Dataset types that are output by a Task in the Pipeline and not consumed
1103 by any other Task in the Pipeline.
1104 """
1106 byTask: Mapping[str, TaskDatasetTypes]
1107 """Per-Task dataset types, keyed by label in the `Pipeline`.
1109 This is guaranteed to be zip-iterable with the `Pipeline` itself (assuming
1110 neither has been modified since the dataset types were extracted, of
1111 course).
1112 """
1114 @classmethod
1115 def fromPipeline(
1116 cls,
1117 pipeline: Union[Pipeline, Iterable[TaskDef]],
1118 *,
1119 registry: Registry,
1120 include_configs: bool = True,
1121 include_packages: bool = True,
1122 ) -> PipelineDatasetTypes:
1123 """Extract and classify the dataset types from all tasks in a
1124 `Pipeline`.
1126 Parameters
1127 ----------
1128 pipeline: `Pipeline` or `Iterable` [ `TaskDef` ]
1129 A collection of tasks that can be run together.
1130 registry: `Registry`
1131 Registry used to construct normalized `DatasetType` objects and
1132 retrieve those that are incomplete.
1133 include_configs : `bool`, optional
1134 If `True` (default) include config dataset types as
1135 ``initOutputs``.
1136 include_packages : `bool`, optional
1137 If `True` (default) include the dataset type for software package
1138 versions in ``initOutputs``.
1140 Returns
1141 -------
1142 types: `PipelineDatasetTypes`
1143 The dataset types used by this `Pipeline`.
1145 Raises
1146 ------
1147 ValueError
1148 Raised if Tasks are inconsistent about which datasets are marked
1149 prerequisite. This indicates that the Tasks cannot be run as part
1150 of the same `Pipeline`.
1151 """
1152 allInputs = NamedValueSet[DatasetType]()
1153 allOutputs = NamedValueSet[DatasetType]()
1154 allInitInputs = NamedValueSet[DatasetType]()
1155 allInitOutputs = NamedValueSet[DatasetType]()
1156 prerequisites = NamedValueSet[DatasetType]()
1157 queryConstraints = NamedValueSet[DatasetType]()
1158 byTask = dict()
1159 if include_packages:
1160 allInitOutputs.add(
1161 DatasetType(
1162 cls.packagesDatasetName,
1163 registry.dimensions.empty,
1164 storageClass="Packages",
1165 )
1166 )
1167 # create a list of TaskDefs in case the input is a generator
1168 pipeline = list(pipeline)
1170 # collect all the output dataset types
1171 typeStorageclassMap: Dict[str, str] = {}
1172 for taskDef in pipeline:
1173 for outConnection in iterConnections(taskDef.connections, "outputs"):
1174 typeStorageclassMap[outConnection.name] = outConnection.storageClass
1176 for taskDef in pipeline:
1177 thisTask = TaskDatasetTypes.fromTaskDef(
1178 taskDef,
1179 registry=registry,
1180 include_configs=include_configs,
1181 storage_class_mapping=typeStorageclassMap,
1182 )
1183 allInitInputs.update(thisTask.initInputs)
1184 allInitOutputs.update(thisTask.initOutputs)
1185 allInputs.update(thisTask.inputs)
1186 # Inputs are query constraints if any task considers them a query
1187 # constraint.
1188 queryConstraints.update(thisTask.queryConstraints)
1189 prerequisites.update(thisTask.prerequisites)
1190 allOutputs.update(thisTask.outputs)
1191 byTask[taskDef.label] = thisTask
1192 if not prerequisites.isdisjoint(allInputs):
1193 raise ValueError(
1194 "{} marked as both prerequisites and regular inputs".format(
1195 {dt.name for dt in allInputs & prerequisites}
1196 )
1197 )
1198 if not prerequisites.isdisjoint(allOutputs):
1199 raise ValueError(
1200 "{} marked as both prerequisites and outputs".format(
1201 {dt.name for dt in allOutputs & prerequisites}
1202 )
1203 )
1204 # Make sure that components which are marked as inputs get treated as
1205 # intermediates if there is an output which produces the composite
1206 # containing the component
1207 intermediateComponents = NamedValueSet[DatasetType]()
1208 intermediateComposites = NamedValueSet[DatasetType]()
1209 outputNameMapping = {dsType.name: dsType for dsType in allOutputs}
1210 for dsType in allInputs:
1211 # get the name of a possible component
1212 name, component = dsType.nameAndComponent()
1213 # if there is a component name, that means this is a component
1214 # DatasetType, if there is an output which produces the parent of
1215 # this component, treat this input as an intermediate
1216 if component is not None:
1217 # This needs to be in this if block, because someone might have
1218 # a composite that is a pure input from existing data
1219 if name in outputNameMapping:
1220 intermediateComponents.add(dsType)
1221 intermediateComposites.add(outputNameMapping[name])
1223 def checkConsistency(a: NamedValueSet, b: NamedValueSet) -> None:
1224 common = a.names & b.names
1225 for name in common:
1226 # Any compatibility is allowed. This function does not know
1227 # if a dataset type is to be used for input or output.
1228 if not (a[name].is_compatible_with(b[name]) or b[name].is_compatible_with(a[name])):
1229 raise ValueError(f"Conflicting definitions for dataset type: {a[name]} != {b[name]}.")
1231 checkConsistency(allInitInputs, allInitOutputs)
1232 checkConsistency(allInputs, allOutputs)
1233 checkConsistency(allInputs, intermediateComposites)
1234 checkConsistency(allOutputs, intermediateComposites)
1236 def frozen(s: AbstractSet[DatasetType]) -> NamedValueSet[DatasetType]:
1237 assert isinstance(s, NamedValueSet)
1238 s.freeze()
1239 return s
1241 inputs = frozen(allInputs - allOutputs - intermediateComponents)
1243 return cls(
1244 initInputs=frozen(allInitInputs - allInitOutputs),
1245 initIntermediates=frozen(allInitInputs & allInitOutputs),
1246 initOutputs=frozen(allInitOutputs - allInitInputs),
1247 inputs=inputs,
1248 queryConstraints=frozen(queryConstraints & inputs),
1249 # If there are storage class differences in inputs and outputs
1250 # the intermediates have to choose priority. Here choose that
1251 # inputs to tasks must match the requested storage class by
1252 # applying the inputs over the top of the outputs.
1253 intermediates=frozen(allOutputs & allInputs | intermediateComponents),
1254 outputs=frozen(allOutputs - allInputs - intermediateComposites),
1255 prerequisites=frozen(prerequisites),
1256 byTask=MappingProxyType(byTask), # MappingProxyType -> frozen view of dict for immutability
1257 )
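# Sketch of classifying dataset types for a whole pipeline, assuming
# ``registry`` is an `lsst.daf.butler.Registry`; the task label is
# hypothetical:
#
#     dataset_types = PipelineDatasetTypes.fromPipeline(pipeline, registry=registry)
#     print(dataset_types.inputs.names)           # pure inputs
#     print(dataset_types.intermediates.names)    # produced and consumed internally
#     print(dataset_types.byTask["isr"].outputs)  # per-task view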
1259 @classmethod
1260 def initOutputNames(
1261 cls,
1262 pipeline: Union[Pipeline, Iterable[TaskDef]],
1263 *,
1264 include_configs: bool = True,
1265 include_packages: bool = True,
1266 ) -> Iterator[str]:
1267 """Return the names of dataset types ot task initOutputs, Configs,
1268 and package versions for a pipeline.
1270 Parameters
1271 ----------
1272 pipeline: `Pipeline` or `Iterable` [ `TaskDef` ]
1273 A `Pipeline` instance or collection of `TaskDef` instances.
1274 include_configs : `bool`, optional
1275 If `True` (default) include config dataset types.
1276 include_packages : `bool`, optional
1277 If `True` (default) include the dataset type for package versions.
1279 Yields
1280 ------
1281 datasetTypeName : `str`
1282 Name of the dataset type.
1283 """
1284 if include_packages:
1285 # Package versions dataset type
1286 yield cls.packagesDatasetName
1288 if isinstance(pipeline, Pipeline):
1289 pipeline = pipeline.toExpandedPipeline()
1291 for taskDef in pipeline:
1292 # all task InitOutputs
1293 for name in taskDef.connections.initOutputs:
1294 attribute = getattr(taskDef.connections, name)
1295 yield attribute.name
1297 # config dataset name
1298 if include_configs:
1299 yield taskDef.configDatasetName
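# Sketch of listing init-output dataset type names without a registry; for a
# hypothetical single-task pipeline labeled "isr" this would yield "packages",
# the task's connection init-output names, and "isr_config":
#
#     names = list(PipelineDatasetTypes.initOutputNames(pipeline))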