Coverage for python/lsst/pipe/base/pipeline.py: 23%
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
28"""Module defining Pipeline class and related methods.
29"""
31from __future__ import annotations
33__all__ = ["Pipeline", "TaskDef", "TaskDatasetTypes", "PipelineDatasetTypes", "LabelSpecifier"]
35import copy
36import logging
37import re
38import urllib.parse
40# -------------------------------
41# Imports of standard modules --
42# -------------------------------
43from collections.abc import Callable, Generator, Iterable, Iterator, Mapping, Set
44from dataclasses import dataclass
45from types import MappingProxyType
46from typing import TYPE_CHECKING, ClassVar, cast
48# -----------------------------
49# Imports for other modules --
50from lsst.daf.butler import DataCoordinate, DatasetType, DimensionUniverse, NamedValueSet, Registry
51from lsst.resources import ResourcePath, ResourcePathExpression
52from lsst.utils import doImportType
53from lsst.utils.introspection import get_full_type_name
55from . import automatic_connection_constants as acc
56from . import pipeline_graph, pipelineIR
57from ._instrument import Instrument as PipeBaseInstrument
58from .config import PipelineTaskConfig
59from .connections import PipelineTaskConnections, iterConnections
60from .connectionTypes import Input
61from .pipelineTask import PipelineTask
63if TYPE_CHECKING: # Imports needed only for type annotations; may be circular.
64 from lsst.obs.base import Instrument
65 from lsst.pex.config import Config
67# ----------------------------------
68# Local non-exported definitions --
69# ----------------------------------
71_LOG = logging.getLogger(__name__)
73# ------------------------
74# Exported definitions --
75# ------------------------
78@dataclass
79class LabelSpecifier:
80 """A structure to specify a subset of labels to load
82 This structure may contain a set of labels to be used in subsetting a
83 pipeline, or a beginning and end point. Beginning or end may be empty,
84 in which case the range will be a half open interval. Unlike python
85 iteration bounds, end bounds are *INCLUDED*. Note that range based
86 selection is not well defined for pipelines that are not linear in nature,
87 and correct behavior is not guaranteed, or may vary from run to run.
88 """
90 labels: set[str] | None = None
91 begin: str | None = None
92 end: str | None = None
94 def __post_init__(self) -> None:
95 if self.labels is not None and (self.begin or self.end):
96 raise ValueError(
97 "This struct can only be initialized with a labels set or a begin (and/or) end specifier"
98 )
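# A minimal usage sketch for `LabelSpecifier` (illustrative only; the task
# labels below are hypothetical). A specifier holds either an explicit label
# set or an inclusive begin/end range, never both:
#
#     only_these = LabelSpecifier(labels={"isr", "characterizeImage"})
#     from_calibrate_on = LabelSpecifier(begin="calibrate")
#     LabelSpecifier(labels={"isr"}, end="calibrate")  # raises ValueError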
101class TaskDef:
102 """TaskDef is a collection of information about task needed by Pipeline.
104 The information includes task name, configuration object and optional
105 task class. This class is just a collection of attributes and it exposes
106 all of them so that attributes could potentially be modified in place
107 (e.g. if configuration needs extra overrides).
109 Attributes
110 ----------
111 taskName : `str`, optional
112 The fully-qualified `PipelineTask` class name. If not provided,
113 ``taskClass`` must be.
114 config : `lsst.pipe.base.config.PipelineTaskConfig`, optional
115 Instance of the configuration class corresponding to this task class,
116 usually with all overrides applied. This config will be frozen. If
117 not provided, ``taskClass`` must be provided and
118 ``taskClass.ConfigClass()`` will be used.
119 taskClass : `type`, optional
120 `PipelineTask` class object; if provided and ``taskName`` is as well,
121 the caller guarantees that they are consistent. If not provided,
122 ``taskName`` is used to import the type.
123 label : `str`, optional
124 Task label, usually a short string unique in a pipeline. If not
125 provided, ``taskClass`` must be, and ``taskClass._DefaultName`` will
126 be used.
127 connections : `PipelineTaskConnections`, optional
128 Object that describes the dataset types used by the task. If not
129 provided, one will be constructed from the given configuration. If
130 provided, it is assumed that ``config`` has already been validated
131 and frozen.
132 """
134 def __init__(
135 self,
136 taskName: str | None = None,
137 config: PipelineTaskConfig | None = None,
138 taskClass: type[PipelineTask] | None = None,
139 label: str | None = None,
140 connections: PipelineTaskConnections | None = None,
141 ):
142 if taskName is None:
143 if taskClass is None:
144 raise ValueError("At least one of `taskName` and `taskClass` must be provided.")
145 taskName = get_full_type_name(taskClass)
146 elif taskClass is None:
147 taskClass = doImportType(taskName)
148 if config is None:
149 if taskClass is None:
150 raise ValueError("`taskClass` must be provided if `config` is not.")
151 config = taskClass.ConfigClass()
152 if label is None:
153 if taskClass is None:
154 raise ValueError("`taskClass` must be provided if `label` is not.")
155 label = taskClass._DefaultName
156 self.taskName = taskName
157 if connections is None:
158 # If we don't have connections yet, assume the config hasn't been
159 # validated yet.
160 try:
161 config.validate()
162 except Exception:
163 _LOG.error("Configuration validation failed for task %s (%s)", label, taskName)
164 raise
165 config.freeze()
166 connections = config.connections.ConnectionsClass(config=config)
167 self.config = config
168 self.taskClass = taskClass
169 self.label = label
170 self.connections = connections
172 @property
173 def configDatasetName(self) -> str:
174 """Name of a dataset type for configuration of this task (`str`)"""
175 return acc.CONFIG_INIT_OUTPUT_TEMPLATE.format(label=self.label)
177 @property
178 def metadataDatasetName(self) -> str:
179 """Name of a dataset type for metadata of this task (`str`)"""
180 return self.makeMetadataDatasetName(self.label)
182 @classmethod
183 def makeMetadataDatasetName(cls, label: str) -> str:
184 """Construct the name of the dataset type for metadata for a task.
186 Parameters
187 ----------
188 label : `str`
189 Label for the task within its pipeline.
191 Returns
192 -------
193 name : `str`
194 Name of the task's metadata dataset type.
195 """
196 return acc.METADATA_OUTPUT_TEMPLATE.format(label=label)
198 @property
199 def logOutputDatasetName(self) -> str | None:
200 """Name of a dataset type for log output from this task, `None` if
201 logs are not to be saved (`str`)
202 """
203 if self.config.saveLogOutput:
204 return acc.LOG_OUTPUT_TEMPLATE.format(label=self.label)
205 else:
206 return None
208 def __str__(self) -> str:
209 rep = "TaskDef(" + self.taskName
210 if self.label:
211 rep += ", label=" + self.label
212 rep += ")"
213 return rep
215 def __eq__(self, other: object) -> bool:
216 if not isinstance(other, TaskDef):
217 return False
218 # This does not consider equality of configs when determining equality
219 # as config equality is a difficult thing to define. Should be updated
220 # after DM-27847
221 return self.taskClass == other.taskClass and self.label == other.label
223 def __hash__(self) -> int:
224 return hash((self.taskClass, self.label))
226 @classmethod
227 def _unreduce(cls, taskName: str, config: PipelineTaskConfig, label: str) -> TaskDef:
228 """Unpickle pickle. Custom callable for unpickling.
230 All arguments are forwarded directly to the constructor; this
231 trampoline is only needed because ``__reduce__`` callables can't be
232 called with keyword arguments.
233 """
234 return cls(taskName=taskName, config=config, label=label)
236 def __reduce__(self) -> tuple[Callable[[str, PipelineTaskConfig, str], TaskDef], tuple[str, Config, str]]:
237 return (self._unreduce, (self.taskName, self.config, self.label))
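# Illustrative sketch of constructing a `TaskDef` directly (they are normally
# produced by expanding a `Pipeline`). ``MyTask`` is a hypothetical
# `PipelineTask` subclass; only one of ``taskName``/``taskClass`` is required,
# and the config defaults to ``taskClass.ConfigClass()``:
#
#     task_def = TaskDef(taskClass=MyTask, label="myTask")
#     task_def.metadataDatasetName  # name of the task's metadata dataset type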
240class Pipeline:
241 """A `Pipeline` is a representation of a series of tasks to run, and the
242 configuration for those tasks.
244 Parameters
245 ----------
246 description : `str`
247 A description of what this pipeline does.
248 """
250 PipelineSubsetCtrl = pipelineIR.PipelineSubsetCtrl
252 def __init__(self, description: str):
253 pipeline_dict = {"description": description, "tasks": {}}
254 self._pipelineIR = pipelineIR.PipelineIR(pipeline_dict)
256 @classmethod
257 def fromFile(cls, filename: str) -> Pipeline:
258 """Load a pipeline defined in a pipeline yaml file.
260 Parameters
261 ----------
262 filename: `str`
263 A path that points to a pipeline defined in yaml format. This
264 filename may also supply additional labels to be used in
265 subsetting the loaded Pipeline. These labels are separated from
266 the path by a ``#``, and may be specified as a comma separated
267 list, or a range denoted as beginning..end. Beginning or end may
268 be empty, in which case the range will be a half open interval.
269 Unlike python iteration bounds, end bounds are *INCLUDED*. Note
270 that range based selection is not well defined for pipelines that
271 are not linear in nature, and correct behavior is not guaranteed,
272 or may vary from run to run.
274 Returns
275 -------
276 pipeline: `Pipeline`
277 The pipeline loaded from specified location with appropriate (if
278 any) subsetting.
280 Notes
281 -----
282 This method attempts to prune any contracts that contain labels which
283 are not in the declared subset of labels. This pruning is done using a
284 string based matching due to the nature of contracts and may prune more
285 than it should.
286 """
287 return cls.from_uri(filename)
289 @classmethod
290 def from_uri(cls, uri: ResourcePathExpression) -> Pipeline:
291 """Load a pipeline defined in a pipeline yaml file at a location
292 specified by a URI.
294 Parameters
295 ----------
296 uri : convertible to `~lsst.resources.ResourcePath`
297 If a string is supplied this should be a URI path that points to a
298 pipeline defined in yaml format, either as a direct path to the
299 yaml file, or as a directory containing a ``pipeline.yaml`` file
300 (the form used by `write_to_uri` with ``expand=True``). This URI may
301 also supply additional labels to be used in subsetting the loaded
302 `Pipeline`. These labels are separated from the path by a ``#``,
303 and may be specified as a comma separated list, or a range denoted
304 as beginning..end. Beginning or end may be empty, in which case the
305 range will be a half open interval. Unlike python iteration bounds,
306 end bounds are *INCLUDED*. Note that range based selection is not
307 well defined for pipelines that are not linear in nature, and
308 correct behavior is not guaranteed, or may vary from run to run.
309 The same specifiers can be used with a
310 `~lsst.resources.ResourcePath` object by placing them as the sole
311 contents of its ``fragment`` attribute.
313 Returns
314 -------
315 pipeline : `Pipeline`
316 The pipeline loaded from specified location with appropriate (if
317 any) subsetting.
319 Notes
320 -----
321 This method attempts to prune any contracts that contain labels which
322 are not in the declared subset of labels. This pruning is done using a
323 string based matching due to the nature of contracts and may prune more
324 than it should.
325 """
326 # Split up the uri and any labels that were supplied
327 uri, label_specifier = cls._parse_file_specifier(uri)
328 pipeline: Pipeline = cls.fromIR(pipelineIR.PipelineIR.from_uri(uri))
330 # If there are labels supplied, only keep those
331 if label_specifier is not None:
332 pipeline = pipeline.subsetFromLabels(label_specifier)
333 return pipeline
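# Hedged usage sketch for `from_uri`/`fromFile`; the file names and labels are
# hypothetical. Labels after ``#`` subset the loaded pipeline, either as a
# comma-separated list or as an inclusive ``begin..end`` range:
#
#     p1 = Pipeline.from_uri("pipelines/DRP.yaml")
#     p2 = Pipeline.from_uri("pipelines/DRP.yaml#isr,calibrate")
#     p3 = Pipeline.fromFile("pipelines/DRP.yaml#isr..calibrate")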
335 def subsetFromLabels(
336 self,
337 labelSpecifier: LabelSpecifier,
338 subsetCtrl: pipelineIR.PipelineSubsetCtrl = PipelineSubsetCtrl.DROP,
339 ) -> Pipeline:
340 """Subset a pipeline to contain only labels specified in labelSpecifier
342 Parameters
343 ----------
344 labelSpecifier : `LabelSpecifier`
345 Object containing the labels that describe how to subset a pipeline.
346 subsetCtrl : `PipelineSubsetCtrl`
347 Control object which decides how subsets with missing labels are
348 handled. Setting to `PipelineSubsetCtrl.DROP` (the default) will
349 cause any subsets that have labels which are not in the set of all
350 task labels to be dropped. Setting to `PipelineSubsetCtrl.EDIT`
351 will cause the subset to instead be edited to remove the
352 nonexistent label.
354 Returns
355 -------
356 pipeline : `Pipeline`
357 A new pipeline object that is a subset of the old pipeline
359 Raises
360 ------
361 ValueError
362 Raised if there is an issue with specified labels
364 Notes
365 -----
366 This method attempts to prune any contracts that contain labels which
367 are not in the declared subset of labels. This pruning is done using a
368 string based matching due to the nature of contracts and may prune more
369 than it should.
370 """
371 # Labels supplied as a set
372 if labelSpecifier.labels:
373 labelSet = labelSpecifier.labels
374 # Labels supplied as a range, first create a list of all the labels
375 # in the pipeline sorted according to task dependency. Then only
376 # keep labels that lie between the supplied bounds
377 else:
378 # Create a copy of the pipeline to use when assessing the label
379 # ordering. Use a dict for fast searching while preserving order.
380 # Remove contracts so they do not fail in the expansion step. This
381 # is needed because a user may only configure the tasks they intend
382 # to run, which may cause some contracts to fail if they will later
383 # be dropped
384 pipeline = copy.deepcopy(self)
385 pipeline._pipelineIR.contracts = []
386 labels = {taskdef.label: True for taskdef in pipeline.toExpandedPipeline()}
388 # Verify the bounds are in the labels
389 if labelSpecifier.begin is not None:
390 if labelSpecifier.begin not in labels:
391 raise ValueError(
392 f"Beginning of range subset, {labelSpecifier.begin}, not found in pipeline definition"
393 )
394 if labelSpecifier.end is not None:
395 if labelSpecifier.end not in labels:
396 raise ValueError(
397 f"End of range subset, {labelSpecifier.end}, not found in pipeline definition"
398 )
400 labelSet = set()
401 for label in labels:
402 if labelSpecifier.begin is not None:
403 if label != labelSpecifier.begin:
404 continue
405 else:
406 labelSpecifier.begin = None
407 labelSet.add(label)
408 if labelSpecifier.end is not None and label == labelSpecifier.end:
409 break
410 return Pipeline.fromIR(self._pipelineIR.subset_from_labels(labelSet, subsetCtrl))
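# Hedged sketch of in-memory subsetting (labels hypothetical). The subset
# control decides what happens to labeled subsets that reference tasks outside
# the kept range:
#
#     spec = LabelSpecifier(begin="isr", end="calibrate")
#     smaller = pipeline.subsetFromLabels(spec, Pipeline.PipelineSubsetCtrl.EDIT)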
412 @staticmethod
413 def _parse_file_specifier(uri: ResourcePathExpression) -> tuple[ResourcePath, LabelSpecifier | None]:
414 """Split appart a uri and any possible label subsets"""
415 if isinstance(uri, str):
416 # This is to support legacy pipelines during transition
417 uri, num_replace = re.subn("[:](?!\\/\\/)", "#", uri)
418 if num_replace:
419 raise ValueError(
420 f"The pipeline file {uri} seems to use the legacy :"
421 " to separate labels, please use # instead."
422 )
423 if uri.count("#") > 1:
424 raise ValueError("Only one set of labels is allowed when specifying a pipeline to load")
425 # Everything else can be converted directly to ResourcePath.
426 uri = ResourcePath(uri)
427 label_subset = uri.fragment or None
429 specifier: LabelSpecifier | None
430 if label_subset is not None:
431 label_subset = urllib.parse.unquote(label_subset)
432 args: dict[str, set[str] | str | None]
433 # labels supplied as a list
434 if "," in label_subset:
435 if ".." in label_subset:
436 raise ValueError(
437 "Can only specify a list of labels or a rangewhen loading a Pipline not both"
438 )
439 args = {"labels": set(label_subset.split(","))}
440 # labels supplied as a range
441 elif ".." in label_subset:
442 # Try to de-structure the labelSubset, this will fail if more
443 # than one range is specified
444 begin, end, *rest = label_subset.split("..")
445 if rest:
446 raise ValueError("Only one range can be specified when loading a pipeline")
447 args = {"begin": begin if begin else None, "end": end if end else None}
448 # Assume anything else is a single label
449 else:
450 args = {"labels": {label_subset}}
452 # MyPy doesn't like how cavalier kwarg construction is with types.
453 specifier = LabelSpecifier(**args) # type: ignore
454 else:
455 specifier = None
457 return uri, specifier
459 @classmethod
460 def fromString(cls, pipeline_string: str) -> Pipeline:
461 """Create a pipeline from string formatted as a pipeline document.
463 Parameters
464 ----------
465 pipeline_string : `str`
466 A string that is formatted like a pipeline document.
468 Returns
469 -------
470 pipeline: `Pipeline`
471 """
472 pipeline = cls.fromIR(pipelineIR.PipelineIR.from_string(pipeline_string))
473 return pipeline
475 @classmethod
476 def fromIR(cls, deserialized_pipeline: pipelineIR.PipelineIR) -> Pipeline:
477 """Create a pipeline from an already created `PipelineIR` object.
479 Parameters
480 ----------
481 deserialized_pipeline: `PipelineIR`
482 An already created pipeline intermediate representation object
484 Returns
485 -------
486 pipeline: `Pipeline`
487 """
488 pipeline = cls.__new__(cls)
489 pipeline._pipelineIR = deserialized_pipeline
490 return pipeline
492 @classmethod
493 def fromPipeline(cls, pipeline: Pipeline) -> Pipeline:
494 """Create a new pipeline by copying an already existing `Pipeline`.
496 Parameters
497 ----------
498 pipeline : `Pipeline`
499 An already created `Pipeline` object to copy.
501 Returns
502 -------
503 pipeline: `Pipeline`
504 """
505 return cls.fromIR(copy.deepcopy(pipeline._pipelineIR))
507 def __str__(self) -> str:
508 return str(self._pipelineIR)
510 def mergePipeline(self, pipeline: Pipeline) -> None:
511 """Merge another in-memory `Pipeline` object into this one.
513 This merges another pipeline into this object, as if it were declared
514 in the import block of the yaml definition of this pipeline. This
515 modifies this pipeline in place.
517 Parameters
518 ----------
519 pipeline : `Pipeline`
520 The `Pipeline` object that is to be merged into this object.
521 """
522 self._pipelineIR.merge_pipelines((pipeline._pipelineIR,))
524 def addLabelToSubset(self, subset: str, label: str) -> None:
525 """Add a task label from the specified subset.
527 Parameters
528 ----------
529 subset : `str`
530 The labeled subset to modify
531 label : `str`
532 The task label to add to the specified subset.
534 Raises
535 ------
536 ValueError
537 Raised if the specified subset does not exist within the pipeline.
538 Raised if the specified label does not exist within the pipeline.
539 """
540 if label not in self._pipelineIR.tasks:
541 raise ValueError(f"Label {label} does not appear within the pipeline")
542 if subset not in self._pipelineIR.labeled_subsets:
543 raise ValueError(f"Subset {subset} does not appear within the pipeline")
544 self._pipelineIR.labeled_subsets[subset].subset.add(label)
546 def removeLabelFromSubset(self, subset: str, label: str) -> None:
547 """Remove a task label from the specified subset.
549 Parameters
550 ----------
551 subset : `str`
552 The labeled subset to modify
553 label : `str`
554 The task label to remove from the specified subset.
556 Raises
557 ------
558 ValueError
559 Raised if the specified subset does not exist in the pipeline.
560 Raised if the specified label does not exist within the specified
561 subset.
562 """
563 if subset not in self._pipelineIR.labeled_subsets:
564 raise ValueError(f"Subset {subset} does not appear within the pipeline")
565 if label not in self._pipelineIR.labeled_subsets[subset].subset:
566 raise ValueError(f"Label {label} does not appear within the pipeline")
567 self._pipelineIR.labeled_subsets[subset].subset.remove(label)
569 def findSubsetsWithLabel(self, label: str) -> set[str]:
570 """Find any subsets which may contain the specified label.
572 This function returns the names of subsets which contain the specified
573 label. May return an empty set if there are no subsets, or no subsets
574 containing the specified label.
576 Parameters
577 ----------
578 label : `str`
579 The task label to use in membership check
581 Returns
582 -------
583 subsets : `set` of `str`
584 Returns a set (possibly empty) of subset names which contain the
585 specified label.
587 Raises
588 ------
589 ValueError
590 Raised if the specified label does not exist within this pipeline.
591 """
592 results = set()
593 if label not in self._pipelineIR.tasks:
594 raise ValueError(f"Label {label} does not appear within the pipeline")
595 for subset in self._pipelineIR.labeled_subsets.values():
596 if label in subset.subset:
597 results.add(subset.label)
598 return results
600 @property
601 def subsets(self) -> MappingProxyType[str, set]:
602 """Returns a `MappingProxyType` where the keys are the labels of
603 labeled subsets in the `Pipeline` and the values are the set of task
604 labels contained within that subset.
605 """
606 return MappingProxyType(
607 {label: subsetIr.subset for label, subsetIr in self._pipelineIR.labeled_subsets.items()}
608 )
610 def addLabeledSubset(self, label: str, description: str, taskLabels: set[str]) -> None:
611 """Add a new labeled subset to the `Pipeline`.
613 Parameters
614 ----------
615 label : `str`
616 The label to assign to the subset.
617 description : `str`
618 A description of what the subset is for.
619 taskLabels : `set` [`str`]
620 The set of task labels to be associated with the labeled subset.
622 Raises
623 ------
624 ValueError
625 Raised if label already exists in the `Pipeline`.
626 Raised if a task label is not found within the `Pipeline`.
627 """
628 if label in self._pipelineIR.labeled_subsets.keys():
629 raise ValueError(f"Subset label {label} is already found within the Pipeline")
630 if extra := (taskLabels - self._pipelineIR.tasks.keys()):
631 raise ValueError(f"Task labels {extra} were not found within the Pipeline")
632 self._pipelineIR.labeled_subsets[label] = pipelineIR.LabeledSubset(label, taskLabels, description)
634 def removeLabeledSubset(self, label: str) -> None:
635 """Remove a labeled subset from the `Pipeline`.
637 Parameters
638 ----------
639 label : `str`
640 The label of the subset to remove from the `Pipeline`
642 Raises
643 ------
644 ValueError
645 Raised if the label is not found within the `Pipeline`
646 """
647 if label not in self._pipelineIR.labeled_subsets.keys():
648 raise ValueError(f"Subset label {label} was not found in the pipeline")
649 self._pipelineIR.labeled_subsets.pop(label)
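# Hedged sketch of labeled-subset bookkeeping; the subset and task labels are
# hypothetical and assumed to exist in the pipeline:
#
#     pipeline.addLabeledSubset("nightly", "Tasks run nightly", {"isr", "calibrate"})
#     pipeline.addLabelToSubset("nightly", "characterizeImage")
#     pipeline.findSubsetsWithLabel("isr")  # -> {"nightly"}
#     pipeline.subsets["nightly"]           # set of task labels in the subset
#     pipeline.removeLabelFromSubset("nightly", "isr")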
651 def addInstrument(self, instrument: Instrument | str) -> None:
652 """Add an instrument to the pipeline, or replace an instrument that is
653 already defined.
655 Parameters
656 ----------
657 instrument : `~lsst.obs.base.Instrument` or `str`
658 Either a derived class object of `lsst.obs.base.Instrument` or
659 a string corresponding to a fully qualified
660 `lsst.obs.base.Instrument` name.
661 """
662 if isinstance(instrument, str):
663 pass
664 else:
665 # TODO: assume that this is a subclass of Instrument, no type
666 # checking
667 instrument = get_full_type_name(instrument)
668 self._pipelineIR.instrument = instrument
670 def getInstrument(self) -> str | None:
671 """Get the instrument from the pipeline.
673 Returns
674 -------
675 instrument : `str` or `None`
676 The fully qualified name of a `lsst.obs.base.Instrument` subclass,
677 or `None` if the pipeline does not have an instrument.
678 """
679 return self._pipelineIR.instrument
681 def get_data_id(self, universe: DimensionUniverse) -> DataCoordinate:
682 """Return a data ID with all dimension constraints embedded in the
683 pipeline.
685 Parameters
686 ----------
687 universe : `lsst.daf.butler.DimensionUniverse`
688 Object that defines all dimensions.
690 Returns
691 -------
692 data_id : `lsst.daf.butler.DataCoordinate`
693 Data ID with all dimension constraints embedded in the
694 pipeline.
695 """
696 instrument_class_name = self._pipelineIR.instrument
697 if instrument_class_name is not None:
698 instrument_class = cast(PipeBaseInstrument, doImportType(instrument_class_name))
699 if instrument_class is not None:
700 return DataCoordinate.standardize(instrument=instrument_class.getName(), universe=universe)
701 return DataCoordinate.make_empty(universe)
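# Hedged sketch: the returned data ID is empty unless the pipeline declares an
# instrument, in which case it carries an ``instrument`` value; ``registry`` is
# a hypothetical `~lsst.daf.butler.Registry`:
#
#     data_id = pipeline.get_data_id(registry.dimensions)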
703 def addTask(self, task: type[PipelineTask] | str, label: str) -> None:
704 """Add a new task to the pipeline, or replace a task that is already
705 associated with the supplied label.
707 Parameters
708 ----------
709 task: `PipelineTask` or `str`
710 Either a derived class object of a `PipelineTask` or a string
711 corresponding to a fully qualified `PipelineTask` name.
712 label: `str`
713 A label that is used to identify the `PipelineTask` being added
714 """
715 if isinstance(task, str):
716 taskName = task
717 elif issubclass(task, PipelineTask):
718 taskName = get_full_type_name(task)
719 else:
720 raise ValueError(
721 "task must be either a child class of PipelineTask or a string containing"
722 " a fully qualified name to one"
723 )
724 if not label:
725 # in some cases (with command line-generated pipeline) tasks can
726 # be defined without label which is not acceptable, use task
727 # _DefaultName in that case
728 if isinstance(task, str):
729 task_class = cast(PipelineTask, doImportType(task))
730 label = task_class._DefaultName
731 self._pipelineIR.tasks[label] = pipelineIR.TaskIR(label, taskName)
733 def removeTask(self, label: str) -> None:
734 """Remove a task from the pipeline.
736 Parameters
737 ----------
738 label : `str`
739 The label used to identify the task that is to be removed
741 Raises
742 ------
743 KeyError
744 If no task with that label exists in the pipeline
746 """
747 self._pipelineIR.tasks.pop(label)
749 def addConfigOverride(self, label: str, key: str, value: object) -> None:
750 """Apply single config override.
752 Parameters
753 ----------
754 label : `str`
755 Label of the task.
756 key: `str`
757 Fully-qualified field name.
758 value : object
759 Value to be given to a field.
760 """
761 self._addConfigImpl(label, pipelineIR.ConfigIR(rest={key: value}))
763 def addConfigFile(self, label: str, filename: str) -> None:
764 """Add overrides from a specified file.
766 Parameters
767 ----------
768 label : `str`
769 The label used to identify the task associated with config to
770 modify
771 filename : `str`
772 Path to the override file.
773 """
774 self._addConfigImpl(label, pipelineIR.ConfigIR(file=[filename]))
776 def addConfigPython(self, label: str, pythonString: str) -> None:
777 """Add Overrides by running a snippet of python code against a config.
779 Parameters
780 ----------
781 label : `str`
782 The label used to identify the task associated with the config to
783 modify.
784 pythonString : `str`
785 A string of valid Python code to be executed. This is done
786 with ``config`` as the only locally accessible value.
787 """
788 self._addConfigImpl(label, pipelineIR.ConfigIR(python=pythonString))
790 def _addConfigImpl(self, label: str, newConfig: pipelineIR.ConfigIR) -> None:
791 if label == "parameters":
792 self._pipelineIR.parameters.mapping.update(newConfig.rest)
793 if newConfig.file:
794 raise ValueError("Setting parameters section with config file is not supported")
795 if newConfig.python:
796 raise ValueError("Setting parameters section using a python block is unsupported")
797 return
798 if label not in self._pipelineIR.tasks:
799 raise LookupError(f"There are no tasks labeled '{label}' in the pipeline")
800 self._pipelineIR.tasks[label].add_or_update_config(newConfig)
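# Hedged sketch of building and configuring a pipeline programmatically; the
# task class path, label, config field, and file names are hypothetical:
#
#     pipeline = Pipeline("A small example pipeline")
#     pipeline.addTask("lsst.example.tasks.ExampleTask", "example")
#     pipeline.addConfigOverride("example", "someField", 42)
#     pipeline.addConfigFile("example", "overrides.py")
#     pipeline.addConfigPython("example", "config.someField = 43")
#     pipeline.write_to_uri("example_pipeline.yaml")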
802 def write_to_uri(self, uri: ResourcePathExpression) -> None:
803 """Write the pipeline to a file or directory.
805 Parameters
806 ----------
807 uri : convertible to `~lsst.resources.ResourcePath`
808 URI to write to; may have any scheme with
809 `~lsst.resources.ResourcePath` write support or no scheme for a
810 local file/directory. Should have a ``.yaml`` extension.
811 """
812 self._pipelineIR.write_to_uri(uri)
814 def to_graph(self, registry: Registry | None = None) -> pipeline_graph.PipelineGraph:
815 """Construct a pipeline graph from this pipeline.
817 Constructing a graph applies all configuration overrides, freezes all
818 configuration, checks all contracts, and checks for dataset type
819 consistency between tasks (as much as possible without access to a data
820 repository). It cannot be reversed.
822 Parameters
823 ----------
824 registry : `lsst.daf.butler.Registry`, optional
825 Data repository client. If provided, the graph's dataset types
826 and dimensions will be resolved (see `PipelineGraph.resolve`).
828 Returns
829 -------
830 graph : `pipeline_graph.PipelineGraph`
831 Representation of the pipeline as a graph.
832 """
833 instrument_class_name = self._pipelineIR.instrument
834 data_id = {}
835 if instrument_class_name is not None:
836 instrument_class: type[Instrument] = doImportType(instrument_class_name)
837 if instrument_class is not None:
838 data_id["instrument"] = instrument_class.getName()
839 graph = pipeline_graph.PipelineGraph(data_id=data_id)
840 graph.description = self._pipelineIR.description
841 for label in self._pipelineIR.tasks:
842 self._add_task_to_graph(label, graph)
843 if self._pipelineIR.contracts is not None:
844 label_to_config = {x.label: x.config for x in graph.tasks.values()}
845 for contract in self._pipelineIR.contracts:
846 # Execute this on its own line so it can raise a good error
847 # message if there were problems with the eval.
848 success = eval(contract.contract, None, label_to_config)
849 if not success:
850 extra_info = f": {contract.msg}" if contract.msg is not None else ""
851 raise pipelineIR.ContractError(
852 f"Contract(s) '{contract.contract}' were not satisfied{extra_info}"
853 )
854 for label, subset in self._pipelineIR.labeled_subsets.items():
855 graph.add_task_subset(
856 label, subset.subset, subset.description if subset.description is not None else ""
857 )
858 graph.sort()
859 if registry is not None:
860 graph.resolve(registry)
861 return graph
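# Hedged sketch: converting to a `PipelineGraph` applies config overrides,
# checks contracts, and, if a registry is supplied, resolves dataset types and
# dimensions; ``butler`` is a hypothetical `~lsst.daf.butler.Butler`:
#
#     graph = pipeline.to_graph()                 # unresolved graph
#     graph = pipeline.to_graph(butler.registry)  # resolved against a repo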
863 def toExpandedPipeline(self) -> Generator[TaskDef, None, None]:
864 r"""Return a generator of `TaskDef`\s which can be used to create
865 quantum graphs.
867 Returns
868 -------
869 generator : generator of `TaskDef`
870 The generator returned will be the sorted iterator of tasks which
871 are to be used in constructing a quantum graph.
873 Raises
874 ------
875 NotImplementedError
876 Raised if a dataId is supplied in a config block. This is in place
877 for future use.
878 """
879 yield from self.to_graph()._iter_task_defs()
881 def _add_task_to_graph(self, label: str, graph: pipeline_graph.PipelineGraph) -> None:
882 """Add a single task from this pipeline to a pipeline graph that is
883 under construction.
885 Parameters
886 ----------
887 label : `str`
888 Label for the task to be added.
889 graph : `pipeline_graph.PipelineGraph`
890 Graph to add the task to.
891 """
892 if (taskIR := self._pipelineIR.tasks.get(label)) is None:
893 raise NameError(f"Label {label} does not appear in this pipeline")
894 taskClass: type[PipelineTask] = doImportType(taskIR.klass)
895 config = taskClass.ConfigClass()
896 instrument: PipeBaseInstrument | None = None
897 if (instrumentName := self._pipelineIR.instrument) is not None:
898 instrument_cls: type = doImportType(instrumentName)
899 instrument = instrument_cls()
900 config.applyConfigOverrides(
901 instrument,
902 getattr(taskClass, "_DefaultName", ""),
903 taskIR.config,
904 self._pipelineIR.parameters,
905 label,
906 )
907 graph.add_task(label, taskClass, config)
909 def __iter__(self) -> Generator[TaskDef, None, None]:
910 return self.toExpandedPipeline()
912 def __getitem__(self, item: str) -> TaskDef:
913 # Making a whole graph and then making a TaskDef from that is pretty
914 # backwards, but I'm hoping to deprecate this method shortly in favor
915 # of making the graph explicitly and working with its node objects.
916 graph = pipeline_graph.PipelineGraph()
917 self._add_task_to_graph(item, graph)
918 (result,) = graph._iter_task_defs()
919 return result
921 def __len__(self) -> int:
922 return len(self._pipelineIR.tasks)
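# Hedged sketch of the container protocol (task label hypothetical):
#
#     len(pipeline)                  # number of task definitions
#     pipeline["isr"]                # a single `TaskDef`
#     [td.label for td in pipeline]  # dependency-sorted labels via expansion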
924 def __eq__(self, other: object) -> bool:
925 if not isinstance(other, Pipeline):
926 return False
927 elif self._pipelineIR == other._pipelineIR:
928 # Shortcut: if the IR is the same, the expanded pipeline must be
929 # the same as well. But the converse is not true.
930 return True
931 else:
932 self_expanded = {td.label: (td.taskClass,) for td in self}
933 other_expanded = {td.label: (td.taskClass,) for td in other}
934 if self_expanded != other_expanded:
935 return False
936 # After DM-27847, we should compare configuration here, or better,
937 # delegated to TaskDef.__eq__ after making that compare configurations.
938 raise NotImplementedError(
939 "Pipelines cannot be compared because config instances cannot be compared; see DM-27847."
940 )
943@dataclass(frozen=True)
944class TaskDatasetTypes:
945 """An immutable struct that extracts and classifies the dataset types used
946 by a `PipelineTask`
947 """
949 initInputs: NamedValueSet[DatasetType]
950 """Dataset types that are needed as inputs in order to construct this Task.
952 Task-level `initInputs` may be classified as either
953 `~PipelineDatasetTypes.initInputs` or
954 `~PipelineDatasetTypes.initIntermediates` at the Pipeline level.
955 """
957 initOutputs: NamedValueSet[DatasetType]
958 """Dataset types that may be written after constructing this Task.
960 Task-level `initOutputs` may be classified as either
961 `~PipelineDatasetTypes.initOutputs` or
962 `~PipelineDatasetTypes.initIntermediates` at the Pipeline level.
963 """
965 inputs: NamedValueSet[DatasetType]
966 """Dataset types that are regular inputs to this Task.
968 If an input dataset needed for a Quantum cannot be found in the input
969 collection(s) or produced by another Task in the Pipeline, that Quantum
970 (and all dependent Quanta) will not be produced.
972 Task-level `inputs` may be classified as either
973 `~PipelineDatasetTypes.inputs` or `~PipelineDatasetTypes.intermediates`
974 at the Pipeline level.
975 """
977 queryConstraints: NamedValueSet[DatasetType]
978 """Regular inputs that should not be used as constraints on the initial
979 QuantumGraph generation data ID query, according to their tasks
980 (`NamedValueSet`).
981 """
983 prerequisites: NamedValueSet[DatasetType]
984 """Dataset types that are prerequisite inputs to this Task.
986 Prerequisite inputs must exist in the input collection(s) before the
987 pipeline is run, but do not constrain the graph - if a prerequisite is
988 missing for a Quantum, `PrerequisiteMissingError` is raised.
990 Prerequisite inputs are not resolved until the second stage of
991 QuantumGraph generation.
992 """
994 outputs: NamedValueSet[DatasetType]
995 """Dataset types that are produced by this Task.
997 Task-level `outputs` may be classified as either
998 `~PipelineDatasetTypes.outputs` or `~PipelineDatasetTypes.intermediates`
999 at the Pipeline level.
1000 """
1002 @classmethod
1003 def fromTaskDef(
1004 cls,
1005 taskDef: TaskDef,
1006 *,
1007 registry: Registry,
1008 include_configs: bool = True,
1009 storage_class_mapping: Mapping[str, str] | None = None,
1010 ) -> TaskDatasetTypes:
1011 """Extract and classify the dataset types from a single `PipelineTask`.
1013 Parameters
1014 ----------
1015 taskDef: `TaskDef`
1016 An instance of a `TaskDef` class for a particular `PipelineTask`.
1017 registry: `Registry`
1018 Registry used to construct normalized
1019 `~lsst.daf.butler.DatasetType` objects and retrieve those that are
1020 incomplete.
1021 include_configs : `bool`, optional
1022 If `True` (default) include config dataset types as
1023 ``initOutputs``.
1024 storage_class_mapping : `~collections.abc.Mapping` of `str` to \
1025 `str` (storage class name), optional
1026 If a taskdef contains a component dataset type that is unknown
1027 to the registry, its parent `~lsst.daf.butler.StorageClass` will
1028 be looked up in this mapping if it is supplied. If the mapping does
1029 not contain the composite dataset type, or the mapping is not
1030 supplied, an exception will be raised.
1032 Returns
1033 -------
1034 types: `TaskDatasetTypes`
1035 The dataset types used by this task.
1037 Raises
1038 ------
1039 ValueError
1040 Raised if dataset type connection definition differs from
1041 registry definition.
1042 LookupError
1043 Raised if component parent StorageClass could not be determined
1044 and storage_class_mapping does not contain the composite type, or
1045 is set to None.
1046 """
1048 def makeDatasetTypesSet(
1049 connectionType: str,
1050 is_input: bool,
1051 freeze: bool = True,
1052 ) -> NamedValueSet[DatasetType]:
1053 """Construct a set of true `~lsst.daf.butler.DatasetType` objects.
1055 Parameters
1056 ----------
1057 connectionType : `str`
1058 Name of the connection type to produce a set for, corresponds
1059 to an attribute of type `list` on the connection class instance.
1060 is_input : `bool`
1061 If `True`, these are input dataset types; otherwise they are
1062 output dataset types.
1063 freeze : `bool`, optional
1064 If `True`, call `NamedValueSet.freeze` on the object returned.
1066 Returns
1067 -------
1068 datasetTypes : `NamedValueSet`
1069 A set of all dataset types which correspond to the
1070 connection type specified in the connection class of this
1071 `PipelineTask`.
1073 Raises
1074 ------
1075 ValueError
1076 Raised if dataset type connection definition differs from
1077 registry definition.
1078 LookupError
1079 Raised if component parent StorageClass could not be determined
1080 and storage_class_mapping does not contain the composite type,
1081 or is set to None.
1083 Notes
1084 -----
1085 This function is a closure over the variables ``registry``,
1086 ``taskDef``, and ``storage_class_mapping``.
1087 """
1088 datasetTypes = NamedValueSet[DatasetType]()
1089 for c in iterConnections(taskDef.connections, connectionType):
1090 dimensions = set(getattr(c, "dimensions", set()))
1091 if "skypix" in dimensions:
1092 try:
1093 datasetType = registry.getDatasetType(c.name)
1094 except LookupError as err:
1095 raise LookupError(
1096 f"DatasetType '{c.name}' referenced by "
1097 f"{type(taskDef.connections).__name__} uses 'skypix' as a dimension "
1098 "placeholder, but does not already exist in the registry. "
1099 "Note that reference catalog names are now used as the dataset "
1100 "type name instead of 'ref_cat'."
1101 ) from err
1102 rest1 = set(registry.dimensions.conform(dimensions - {"skypix"}).names)
1103 rest2 = datasetType.dimensions.names - datasetType.dimensions.skypix.names
1104 if rest1 != rest2:
1105 raise ValueError(
1106 f"Non-skypix dimensions for dataset type {c.name} declared in "
1107 f"connections ({rest1}) are inconsistent with those in "
1108 f"registry's version of this dataset ({rest2})."
1109 )
1110 else:
1111 # Component dataset types are not explicitly in the
1112 # registry. This complicates consistency checks with
1113 # registry and requires we work out the composite storage
1114 # class.
1115 registryDatasetType = None
1116 try:
1117 registryDatasetType = registry.getDatasetType(c.name)
1118 except KeyError:
1119 compositeName, componentName = DatasetType.splitDatasetTypeName(c.name)
1120 if componentName:
1121 if storage_class_mapping is None or compositeName not in storage_class_mapping:
1122 raise LookupError(
1123 "Component parent class cannot be determined, and "
1124 "composite name was not in storage class mapping, or no "
1125 "storage_class_mapping was supplied"
1126 ) from None
1127 else:
1128 parentStorageClass = storage_class_mapping[compositeName]
1129 else:
1130 parentStorageClass = None
1131 datasetType = c.makeDatasetType(
1132 registry.dimensions, parentStorageClass=parentStorageClass
1133 )
1134 registryDatasetType = datasetType
1135 else:
1136 datasetType = c.makeDatasetType(
1137 registry.dimensions, parentStorageClass=registryDatasetType.parentStorageClass
1138 )
1140 if registryDatasetType and datasetType != registryDatasetType:
1141 # The dataset types differ but first check to see if
1142 # they are compatible before raising.
1143 if is_input:
1144 # This DatasetType must be compatible on get.
1145 is_compatible = datasetType.is_compatible_with(registryDatasetType)
1146 else:
1147 # Has to be able to be converted to the expected type
1148 # on put.
1149 is_compatible = registryDatasetType.is_compatible_with(datasetType)
1150 if is_compatible:
1151 # For inputs we want the pipeline to use the
1152 # pipeline definition, for outputs it should use
1153 # the registry definition.
1154 if not is_input:
1155 datasetType = registryDatasetType
1156 _LOG.debug(
1157 "Dataset types differ (task %s != registry %s) but are compatible"
1158 " for %s in %s.",
1159 datasetType,
1160 registryDatasetType,
1161 "input" if is_input else "output",
1162 taskDef.label,
1163 )
1164 else:
1165 try:
1166 # Explicitly check for storage class just to
1167 # make more specific message.
1168 _ = datasetType.storageClass
1169 except KeyError:
1170 raise ValueError(
1171 "Storage class does not exist for supplied dataset type "
1172 f"{datasetType} for {taskDef.label}."
1173 ) from None
1174 raise ValueError(
1175 f"Supplied dataset type ({datasetType}) inconsistent with "
1176 f"registry definition ({registryDatasetType}) "
1177 f"for {taskDef.label}."
1178 )
1179 datasetTypes.add(datasetType)
1180 if freeze:
1181 datasetTypes.freeze()
1182 return datasetTypes
1184 # optionally add initOutput dataset for config
1185 initOutputs = makeDatasetTypesSet("initOutputs", is_input=False, freeze=False)
1186 if include_configs:
1187 initOutputs.add(
1188 DatasetType(
1189 taskDef.configDatasetName,
1190 registry.dimensions.empty,
1191 storageClass=acc.CONFIG_INIT_OUTPUT_STORAGE_CLASS,
1192 )
1193 )
1194 initOutputs.freeze()
1196 # optionally add output dataset for metadata
1197 outputs = makeDatasetTypesSet("outputs", is_input=False, freeze=False)
1199 # Metadata is supposed to be of the TaskMetadata type; its dimensions
1200 # correspond to a task quantum.
1201 dimensions = registry.dimensions.conform(taskDef.connections.dimensions)
1203 # Allow the storage class definition to be read from the existing
1204 # dataset type definition if present.
1205 try:
1206 current = registry.getDatasetType(taskDef.metadataDatasetName)
1207 except KeyError:
1208 # No previous definition so use the default.
1209 storageClass = acc.METADATA_OUTPUT_STORAGE_CLASS
1210 else:
1211 storageClass = current.storageClass.name
1212 outputs.update({DatasetType(taskDef.metadataDatasetName, dimensions, storageClass)})
1214 if taskDef.logOutputDatasetName is not None:
1215 # Log output dimensions correspond to a task quantum.
1216 dimensions = registry.dimensions.conform(taskDef.connections.dimensions)
1217 outputs.update(
1218 {
1219 DatasetType(
1220 taskDef.logOutputDatasetName,
1221 dimensions,
1222 acc.LOG_OUTPUT_STORAGE_CLASS,
1223 )
1224 }
1225 )
1227 outputs.freeze()
1229 inputs = makeDatasetTypesSet("inputs", is_input=True)
1230 queryConstraints = NamedValueSet(
1231 inputs[c.name]
1232 for c in cast(Iterable[Input], iterConnections(taskDef.connections, "inputs"))
1233 if not c.deferGraphConstraint
1234 )
1236 return cls(
1237 initInputs=makeDatasetTypesSet("initInputs", is_input=True),
1238 initOutputs=initOutputs,
1239 inputs=inputs,
1240 queryConstraints=queryConstraints,
1241 prerequisites=makeDatasetTypesSet("prerequisiteInputs", is_input=True),
1242 outputs=outputs,
1243 )
1246@dataclass(frozen=True)
1247class PipelineDatasetTypes:
1248 """An immutable struct that classifies the dataset types used in a
1249 `Pipeline`.
1250 """
1252 packagesDatasetName: ClassVar[str] = acc.PACKAGES_INIT_OUTPUT_NAME
1253 """Name of a dataset type used to save package versions.
1254 """
1256 initInputs: NamedValueSet[DatasetType]
1257 """Dataset types that are needed as inputs in order to construct the Tasks
1258 in this Pipeline.
1260 This does not include dataset types that are produced when constructing
1261 other Tasks in the Pipeline (these are classified as `initIntermediates`).
1262 """
1264 initOutputs: NamedValueSet[DatasetType]
1265 """Dataset types that may be written after constructing the Tasks in this
1266 Pipeline.
1268 This does not include dataset types that are also used as inputs when
1269 constructing other Tasks in the Pipeline (these are classified as
1270 `initIntermediates`).
1271 """
1273 initIntermediates: NamedValueSet[DatasetType]
1274 """Dataset types that are both used when constructing one or more Tasks
1275 in the Pipeline and produced as a side-effect of constructing another
1276 Task in the Pipeline.
1277 """
1279 inputs: NamedValueSet[DatasetType]
1280 """Dataset types that are regular inputs for the full pipeline.
1282 If an input dataset needed for a Quantum cannot be found in the input
1283 collection(s), that Quantum (and all dependent Quanta) will not be
1284 produced.
1285 """
1287 queryConstraints: NamedValueSet[DatasetType]
1288 """Regular inputs that should be used as constraints on the initial
1289 QuantumGraph generation data ID query, according to their tasks
1290 (`NamedValueSet`).
1291 """
1293 prerequisites: NamedValueSet[DatasetType]
1294 """Dataset types that are prerequisite inputs for the full Pipeline.
1296 Prerequisite inputs must exist in the input collection(s) before the
1297 pipeline is run, but do not constrain the graph - if a prerequisite is
1298 missing for a Quantum, `PrerequisiteMissingError` is raised.
1300 Prerequisite inputs are not resolved until the second stage of
1301 QuantumGraph generation.
1302 """
1304 intermediates: NamedValueSet[DatasetType]
1305 """Dataset types that are output by one Task in the Pipeline and consumed
1306 as inputs by one or more other Tasks in the Pipeline.
1307 """
1309 outputs: NamedValueSet[DatasetType]
1310 """Dataset types that are output by a Task in the Pipeline and not consumed
1311 by any other Task in the Pipeline.
1312 """
1314 byTask: Mapping[str, TaskDatasetTypes]
1315 """Per-Task dataset types, keyed by label in the `Pipeline`.
1317 This is guaranteed to be zip-iterable with the `Pipeline` itself (assuming
1318 neither has been modified since the dataset types were extracted, of
1319 course).
1320 """
1322 @classmethod
1323 def fromPipeline(
1324 cls,
1325 pipeline: Pipeline | Iterable[TaskDef],
1326 *,
1327 registry: Registry,
1328 include_configs: bool = True,
1329 include_packages: bool = True,
1330 ) -> PipelineDatasetTypes:
1331 """Extract and classify the dataset types from all tasks in a
1332 `Pipeline`.
1334 Parameters
1335 ----------
1336 pipeline: `Pipeline` or `~collections.abc.Iterable` [ `TaskDef` ]
1337 A collection of tasks that can be run together.
1338 registry: `Registry`
1339 Registry used to construct normalized
1340 `~lsst.daf.butler.DatasetType` objects and retrieve those that are
1341 incomplete.
1342 include_configs : `bool`, optional
1343 If `True` (default) include config dataset types as
1344 ``initOutputs``.
1345 include_packages : `bool`, optional
1346 If `True` (default) include the dataset type for software package
1347 versions in ``initOutputs``.
1349 Returns
1350 -------
1351 types: `PipelineDatasetTypes`
1352 The dataset types used by this `Pipeline`.
1354 Raises
1355 ------
1356 ValueError
1357 Raised if Tasks are inconsistent about which datasets are marked
1358 prerequisite. This indicates that the Tasks cannot be run as part
1359 of the same `Pipeline`.
1360 """
1361 allInputs = NamedValueSet[DatasetType]()
1362 allOutputs = NamedValueSet[DatasetType]()
1363 allInitInputs = NamedValueSet[DatasetType]()
1364 allInitOutputs = NamedValueSet[DatasetType]()
1365 prerequisites = NamedValueSet[DatasetType]()
1366 queryConstraints = NamedValueSet[DatasetType]()
1367 byTask = dict()
1368 if include_packages:
1369 allInitOutputs.add(
1370 DatasetType(
1371 cls.packagesDatasetName,
1372 registry.dimensions.empty,
1373 storageClass=acc.PACKAGES_INIT_OUTPUT_STORAGE_CLASS,
1374 )
1375 )
1376 # create a list of TaskDefs in case the input is a generator
1377 pipeline = list(pipeline)
1379 # collect all the output dataset types
1380 typeStorageclassMap: dict[str, str] = {}
1381 for taskDef in pipeline:
1382 for outConnection in iterConnections(taskDef.connections, "outputs"):
1383 typeStorageclassMap[outConnection.name] = outConnection.storageClass
1385 for taskDef in pipeline:
1386 thisTask = TaskDatasetTypes.fromTaskDef(
1387 taskDef,
1388 registry=registry,
1389 include_configs=include_configs,
1390 storage_class_mapping=typeStorageclassMap,
1391 )
1392 allInitInputs.update(thisTask.initInputs)
1393 allInitOutputs.update(thisTask.initOutputs)
1394 allInputs.update(thisTask.inputs)
1395 # Inputs are query constraints if any task considers them a query
1396 # constraint.
1397 queryConstraints.update(thisTask.queryConstraints)
1398 prerequisites.update(thisTask.prerequisites)
1399 allOutputs.update(thisTask.outputs)
1400 byTask[taskDef.label] = thisTask
1401 if not prerequisites.isdisjoint(allInputs):
1402 raise ValueError(
1403 "{} marked as both prerequisites and regular inputs".format(
1404 {dt.name for dt in allInputs & prerequisites}
1405 )
1406 )
1407 if not prerequisites.isdisjoint(allOutputs):
1408 raise ValueError(
1409 "{} marked as both prerequisites and outputs".format(
1410 {dt.name for dt in allOutputs & prerequisites}
1411 )
1412 )
1413 # Make sure that components which are marked as inputs get treated as
1414 # intermediates if there is an output which produces the composite
1415 # containing the component
1416 intermediateComponents = NamedValueSet[DatasetType]()
1417 intermediateComposites = NamedValueSet[DatasetType]()
1418 outputNameMapping = {dsType.name: dsType for dsType in allOutputs}
1419 for dsType in allInputs:
1420 # get the name of a possible component
1421 name, component = dsType.nameAndComponent()
1422 # if there is a component name, that means this is a component
1423 # DatasetType, if there is an output which produces the parent of
1424 # this component, treat this input as an intermediate
1425 if component is not None:
1426 # This needs to be in this if block, because someone might have
1427 # a composite that is a pure input from existing data
1428 if name in outputNameMapping:
1429 intermediateComponents.add(dsType)
1430 intermediateComposites.add(outputNameMapping[name])
1432 def checkConsistency(a: NamedValueSet, b: NamedValueSet) -> None:
1433 common = a.names & b.names
1434 for name in common:
1435 # Any compatibility is allowed. This function does not know
1436 # if a dataset type is to be used for input or output.
1437 if not (a[name].is_compatible_with(b[name]) or b[name].is_compatible_with(a[name])):
1438 raise ValueError(f"Conflicting definitions for dataset type: {a[name]} != {b[name]}.")
1440 checkConsistency(allInitInputs, allInitOutputs)
1441 checkConsistency(allInputs, allOutputs)
1442 checkConsistency(allInputs, intermediateComposites)
1443 checkConsistency(allOutputs, intermediateComposites)
1445 def frozen(s: Set[DatasetType]) -> NamedValueSet[DatasetType]:
1446 assert isinstance(s, NamedValueSet)
1447 s.freeze()
1448 return s
1450 inputs = frozen(allInputs - allOutputs - intermediateComponents)
1452 return cls(
1453 initInputs=frozen(allInitInputs - allInitOutputs),
1454 initIntermediates=frozen(allInitInputs & allInitOutputs),
1455 initOutputs=frozen(allInitOutputs - allInitInputs),
1456 inputs=inputs,
1457 queryConstraints=frozen(queryConstraints & inputs),
1458 # If there are storage class differences in inputs and outputs,
1459 # the intermediates have to choose priority. Here choose that
1460 # inputs to tasks must match the requested storage class by
1461 # applying the inputs over the top of the outputs.
1462 intermediates=frozen(allOutputs & allInputs | intermediateComponents),
1463 outputs=frozen(allOutputs - allInputs - intermediateComposites),
1464 prerequisites=frozen(prerequisites),
1465 byTask=MappingProxyType(byTask), # MappingProxyType -> frozen view of dict for immutability
1466 )
1468 @classmethod
1469 def initOutputNames(
1470 cls,
1471 pipeline: Pipeline | Iterable[TaskDef],
1472 *,
1473 include_configs: bool = True,
1474 include_packages: bool = True,
1475 ) -> Iterator[str]:
1476 """Return the names of dataset types ot task initOutputs, Configs,
1477 and package versions for a pipeline.
1479 Parameters
1480 ----------
1481 pipeline: `Pipeline` or `~collections.abc.Iterable` [ `TaskDef` ]
1482 A `Pipeline` instance or collection of `TaskDef` instances.
1483 include_configs : `bool`, optional
1484 If `True` (default) include config dataset types.
1485 include_packages : `bool`, optional
1486 If `True` (default) include the dataset type for package versions.
1488 Yields
1489 ------
1490 datasetTypeName : `str`
1491 Name of the dataset type.
1492 """
1493 if include_packages:
1494 # Package versions dataset type
1495 yield cls.packagesDatasetName
1497 if isinstance(pipeline, Pipeline):
1498 pipeline = pipeline.toExpandedPipeline()
1500 for taskDef in pipeline:
1501 # all task InitOutputs
1502 for name in taskDef.connections.initOutputs:
1503 attribute = getattr(taskDef.connections, name)
1504 yield attribute.name
1506 # config dataset name
1507 if include_configs:
1508 yield taskDef.configDatasetName
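# Hedged end-to-end sketch of the dataset-type classification helpers above;
# ``pipeline`` and ``butler`` are hypothetical:
#
#     dataset_types = PipelineDatasetTypes.fromPipeline(
#         pipeline, registry=butler.registry
#     )
#     dataset_types.inputs           # overall inputs of the pipeline
#     dataset_types.byTask["isr"]    # per-task classification (label hypothetical)
#     list(PipelineDatasetTypes.initOutputNames(pipeline))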