Coverage for python/lsst/pipe/base/pipeline.py: 23%
468 statements
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
28"""Module defining Pipeline class and related methods.
29"""
31from __future__ import annotations
33__all__ = ["Pipeline", "TaskDef", "TaskDatasetTypes", "PipelineDatasetTypes", "LabelSpecifier"]
35import copy
36import logging
37import re
38import urllib.parse
40# -------------------------------
41# Imports of standard modules --
42# -------------------------------
43from collections.abc import Callable, Generator, Iterable, Iterator, Mapping, Set
44from dataclasses import dataclass
45from types import MappingProxyType
46from typing import TYPE_CHECKING, ClassVar, cast
48# -----------------------------
49# Imports for other modules --
50from lsst.daf.butler import DataCoordinate, DatasetType, DimensionUniverse, NamedValueSet, Registry
51from lsst.resources import ResourcePath, ResourcePathExpression
52from lsst.utils import doImportType
53from lsst.utils.introspection import get_full_type_name
55from . import automatic_connection_constants as acc
56from . import pipeline_graph, pipelineIR
57from ._instrument import Instrument as PipeBaseInstrument
58from .config import PipelineTaskConfig
59from .connections import PipelineTaskConnections, iterConnections
60from .connectionTypes import Input
61from .pipelineTask import PipelineTask
63if TYPE_CHECKING: # Imports needed only for type annotations; may be circular.
64 from lsst.obs.base import Instrument
65 from lsst.pex.config import Config
67# ----------------------------------
68# Local non-exported definitions --
69# ----------------------------------
71_LOG = logging.getLogger(__name__)
73# ------------------------
74# Exported definitions --
75# ------------------------
78@dataclass
79class LabelSpecifier:
80 """A structure to specify a subset of labels to load.
82 This structure may contain a set of labels to be used in subsetting a
83 pipeline, or a beginning and end point. Beginning or end may be empty,
84 in which case the range will be a half-open interval. Unlike Python
85 iteration bounds, end bounds are *INCLUDED*. Note that range-based
86 selection is not well defined for pipelines that are not linear in
87 nature; correct behavior is not guaranteed and may vary from run to run.
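Examples
--------
A minimal sketch of both forms of specifier; the task labels are
hypothetical and used only for illustration::

    # Subset by an explicit set of labels.
    spec = LabelSpecifier(labels={"isr", "characterizeImage"})

    # Subset by an (inclusive) range of labels.
    spec = LabelSpecifier(begin="isr", end="calibrate")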
88 """
90 labels: set[str] | None = None
91 begin: str | None = None
92 end: str | None = None
94 def __post_init__(self) -> None:
95 if self.labels is not None and (self.begin or self.end):
96 raise ValueError(
97 "This struct can only be initialized with a labels set or a begin (and/or) end specifier"
98 )
101class TaskDef:
102 """TaskDef is a collection of information about task needed by Pipeline.
104 The information includes task name, configuration object and optional
105 task class. This class is just a collection of attributes and it exposes
106 all of them so that attributes could potentially be modified in place
107 (e.g. if configuration needs extra overrides).
109 Parameters
110 ----------
111 taskName : `str`, optional
112 The fully-qualified `PipelineTask` class name. If not provided,
113 ``taskClass`` must be.
114 config : `lsst.pipe.base.config.PipelineTaskConfig`, optional
115 Instance of the configuration class corresponding to this task class,
116 usually with all overrides applied. This config will be frozen. If
117 not provided, ``taskClass`` must be provided and
118 ``taskClass.ConfigClass()`` will be used.
119 taskClass : `type`, optional
120 `PipelineTask` class object; if provided and ``taskName`` is as well,
121 the caller guarantees that they are consistent. If not provided,
122 ``taskName`` is used to import the type.
123 label : `str`, optional
124 Task label, usually a short string unique in a pipeline. If not
125 provided, ``taskClass`` must be, and ``taskClass._DefaultName`` will
126 be used.
127 connections : `PipelineTaskConnections`, optional
128 Object that describes the dataset types used by the task. If not
129 provided, one will be constructed from the given configuration. If
130 provided, it is assumed that ``config`` has already been validated
131 and frozen.
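Examples
--------
A minimal sketch, assuming ``MyTask`` is an already-imported
`PipelineTask` subclass (a hypothetical name used only for
illustration)::

    task_def = TaskDef(taskClass=MyTask, label="myTask")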
132 """
134 def __init__(
135 self,
136 taskName: str | None = None,
137 config: PipelineTaskConfig | None = None,
138 taskClass: type[PipelineTask] | None = None,
139 label: str | None = None,
140 connections: PipelineTaskConnections | None = None,
141 ):
142 if taskName is None:
143 if taskClass is None:
144 raise ValueError("At least one of `taskName` and `taskClass` must be provided.")
145 taskName = get_full_type_name(taskClass)
146 elif taskClass is None:
147 taskClass = doImportType(taskName)
148 if config is None:
149 if taskClass is None:
150 raise ValueError("`taskClass` must be provided if `config` is not.")
151 config = taskClass.ConfigClass()
152 if label is None:
153 if taskClass is None:
154 raise ValueError("`taskClass` must be provided if `label` is not.")
155 label = taskClass._DefaultName
156 self.taskName = taskName
157 if connections is None:
158 # If we don't have connections yet, assume the config hasn't been
159 # validated yet.
160 try:
161 config.validate()
162 except Exception:
163 _LOG.error("Configuration validation failed for task %s (%s)", label, taskName)
164 raise
165 config.freeze()
166 connections = config.connections.ConnectionsClass(config=config)
167 self.config = config
168 self.taskClass = taskClass
169 self.label = label
170 self.connections = connections
172 @property
173 def configDatasetName(self) -> str:
174 """Name of a dataset type for configuration of this task (`str`)."""
175 return acc.CONFIG_INIT_OUTPUT_TEMPLATE.format(label=self.label)
177 @property
178 def metadataDatasetName(self) -> str:
179 """Name of a dataset type for metadata of this task (`str`)."""
180 return self.makeMetadataDatasetName(self.label)
182 @classmethod
183 def makeMetadataDatasetName(cls, label: str) -> str:
184 """Construct the name of the dataset type for metadata for a task.
186 Parameters
187 ----------
188 label : `str`
189 Label for the task within its pipeline.
191 Returns
192 -------
193 name : `str`
194 Name of the task's metadata dataset type.
195 """
196 return acc.METADATA_OUTPUT_TEMPLATE.format(label=label)
198 @property
199 def logOutputDatasetName(self) -> str | None:
200 """Name of a dataset type for log output from this task, `None` if
201 logs are not to be saved (`str`).
202 """
203 if self.config.saveLogOutput:
204 return acc.LOG_OUTPUT_TEMPLATE.format(label=self.label)
205 else:
206 return None
208 def __str__(self) -> str:
209 rep = "TaskDef(" + self.taskName
210 if self.label:
211 rep += ", label=" + self.label
212 rep += ")"
213 return rep
215 def __eq__(self, other: object) -> bool:
216 if not isinstance(other, TaskDef):
217 return False
218 # This does not consider equality of configs when determining equality
219 # as config equality is a difficult thing to define. Should be updated
220 # after DM-27847
221 return self.taskClass == other.taskClass and self.label == other.label
223 def __hash__(self) -> int:
224 return hash((self.taskClass, self.label))
226 @classmethod
227 def _unreduce(cls, taskName: str, config: PipelineTaskConfig, label: str) -> TaskDef:
228 """Unpickle pickle. Custom callable for unpickling.
230 All arguments are forwarded directly to the constructor; this
231 trampoline is only needed because ``__reduce__`` callables can't be
232 called with keyword arguments.
233 """
234 return cls(taskName=taskName, config=config, label=label)
236 def __reduce__(self) -> tuple[Callable[[str, PipelineTaskConfig, str], TaskDef], tuple[str, Config, str]]:
237 return (self._unreduce, (self.taskName, self.config, self.label))
240class Pipeline:
241 """A `Pipeline` is a representation of a series of tasks to run, and the
242 configuration for those tasks.
244 Parameters
245 ----------
246 description : `str`
247 A description of what this pipeline does.
248 """
250 PipelineSubsetCtrl = pipelineIR.PipelineSubsetCtrl
252 def __init__(self, description: str):
253 pipeline_dict = {"description": description, "tasks": {}}
254 self._pipelineIR = pipelineIR.PipelineIR(pipeline_dict)
256 @classmethod
257 def fromFile(cls, filename: str) -> Pipeline:
258 """Load a pipeline defined in a pipeline yaml file.
260 Parameters
261 ----------
262 filename : `str`
263 A path that points to a pipeline defined in yaml format. This
264 filename may also supply additional labels to be used in
265 subsetting the loaded Pipeline. These labels are separated from
266 the path by a ``#``, and may be specified as a comma separated
267 list, or a range denoted as beginning..end. Beginning or end may
268 be empty, in which case the range will be a half open interval.
269 Unlike python iteration bounds, end bounds are *INCLUDED*. Note
270 that range based selection is not well defined for pipelines that
271 are not linear in nature, and correct behavior is not guaranteed,
272 or may vary from run to run.
274 Returns
275 -------
276 pipeline: `Pipeline`
277 The pipeline loaded from specified location with appropriate (if
278 any) subsetting.
280 Notes
281 -----
282 This method attempts to prune any contracts that contain labels which
283 are not in the declared subset of labels. This pruning is done using a
284 string based matching due to the nature of contracts and may prune more
285 than it should.
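Examples
--------
A minimal sketch; the file name and task labels are hypothetical::

    # Load a full pipeline from a yaml file.
    pipeline = Pipeline.fromFile("my_pipeline.yaml")

    # Load only a subset of tasks, using a label list or a range.
    subset = Pipeline.fromFile("my_pipeline.yaml#isr,characterizeImage")
    ranged = Pipeline.fromFile("my_pipeline.yaml#isr..calibrate")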
286 """
287 return cls.from_uri(filename)
289 @classmethod
290 def from_uri(cls, uri: ResourcePathExpression) -> Pipeline:
291 """Load a pipeline defined in a pipeline yaml file at a location
292 specified by a URI.
294 Parameters
295 ----------
296 uri : convertible to `~lsst.resources.ResourcePath`
297 If a string is supplied this should be a URI path that points to a
298 pipeline defined in yaml format, either as a direct path to the
299 yaml file, or as a directory containing a ``pipeline.yaml`` file
300 (the form used by `write_to_uri` with ``expand=True``). This URI may
301 also supply additional labels to be used in subsetting the loaded
302 `Pipeline`. These labels are separated from the path by a ``#``,
303 and may be specified as a comma separated list, or a range denoted
304 as beginning..end. Beginning or end may be empty, in which case the
305 range will be a half open interval. Unlike python iteration bounds,
306 end bounds are *INCLUDED*. Note that range based selection is not
307 well defined for pipelines that are not linear in nature, and
308 correct behavior is not guaranteed, or may vary from run to run.
309 The same specifiers can be used with a
310 `~lsst.resources.ResourcePath` object, by being the sole contents
311 in the fragments attribute.
313 Returns
314 -------
315 pipeline : `Pipeline`
316 The pipeline loaded from specified location with appropriate (if
317 any) subsetting.
319 Notes
320 -----
321 This method attempts to prune any contracts that contain labels which
322 are not in the declared subset of labels. This pruning is done using a
323 string based matching due to the nature of contracts and may prune more
324 than it should.
325 """
326 # Split up the uri and any labels that were supplied
327 uri, label_specifier = cls._parse_file_specifier(uri)
328 pipeline: Pipeline = cls.fromIR(pipelineIR.PipelineIR.from_uri(uri))
330 # If there are labels supplied, only keep those
331 if label_specifier is not None:
332 pipeline = pipeline.subsetFromLabels(label_specifier)
333 return pipeline
335 def subsetFromLabels(
336 self,
337 labelSpecifier: LabelSpecifier,
338 subsetCtrl: pipelineIR.PipelineSubsetCtrl = PipelineSubsetCtrl.DROP,
339 ) -> Pipeline:
340 """Subset a pipeline to contain only labels specified in
341 ``labelSpecifier``.
343 Parameters
344 ----------
345 labelSpecifier : `LabelSpecifier`
346 Object containing labels that describes how to subset a pipeline.
347 subsetCtrl : `PipelineSubsetCtrl`
348 Control object which decides how subsets with missing labels are
349 handled. Setting to `PipelineSubsetCtrl.DROP` (the default) will
350 cause any subsets that have labels which are not in the set of all
351 task labels to be dropped. Setting to `PipelineSubsetCtrl.EDIT`
352 will cause the subset to instead be edited to remove the
353 nonexistent label.
355 Returns
356 -------
357 pipeline : `Pipeline`
358 A new pipeline object that is a subset of the old pipeline.
360 Raises
361 ------
362 ValueError
363 Raised if there is an issue with the specified labels.
365 Notes
366 -----
367 This method attempts to prune any contracts that contain labels which
368 are not in the declared subset of labels. This pruning is done using a
369 string based matching due to the nature of contracts and may prune more
370 than it should.
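Examples
--------
A minimal sketch, using hypothetical task labels::

    spec = LabelSpecifier(labels={"isr", "characterizeImage"})
    subset_pipeline = pipeline.subsetFromLabels(spec)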
371 """
372 # Labels supplied as a set
373 if labelSpecifier.labels:
374 labelSet = labelSpecifier.labels
375 # Labels supplied as a range, first create a list of all the labels
376 # in the pipeline sorted according to task dependency. Then only
377 # keep labels that lie between the supplied bounds
378 else:
379 # Create a copy of the pipeline to use when assessing the label
380 # ordering. Use a dict for fast searching while preserving order.
381 # Remove contracts so they do not fail in the expansion step. This
382 # is needed because a user may only configure the tasks they intend
383 # to run, which may cause some contracts to fail if they will later
384 # be dropped
385 pipeline = copy.deepcopy(self)
386 pipeline._pipelineIR.contracts = []
387 labels = {taskdef.label: True for taskdef in pipeline.toExpandedPipeline()}
389 # Verify the bounds are in the labels
390 if labelSpecifier.begin is not None:
391 if labelSpecifier.begin not in labels:
392 raise ValueError(
393 f"Beginning of range subset, {labelSpecifier.begin}, not found in pipeline definition"
394 )
395 if labelSpecifier.end is not None:
396 if labelSpecifier.end not in labels:
397 raise ValueError(
398 f"End of range subset, {labelSpecifier.end}, not found in pipeline definition"
399 )
401 labelSet = set()
402 for label in labels:
403 if labelSpecifier.begin is not None:
404 if label != labelSpecifier.begin:
405 continue
406 else:
407 labelSpecifier.begin = None
408 labelSet.add(label)
409 if labelSpecifier.end is not None and label == labelSpecifier.end:
410 break
411 return Pipeline.fromIR(self._pipelineIR.subset_from_labels(labelSet, subsetCtrl))
413 @staticmethod
414 def _parse_file_specifier(uri: ResourcePathExpression) -> tuple[ResourcePath, LabelSpecifier | None]:
415 """Split appart a uri and any possible label subsets"""
416 if isinstance(uri, str):
417 # This is to support legacy pipelines during transition
418 uri, num_replace = re.subn("[:](?!\\/\\/)", "#", uri)
419 if num_replace:
420 raise ValueError(
421 f"The pipeline file {uri} seems to use the legacy :"
422 " to separate labels, please use # instead."
423 )
424 if uri.count("#") > 1:
425 raise ValueError("Only one set of labels is allowed when specifying a pipeline to load")
426 # Everything else can be converted directly to ResourcePath.
427 uri = ResourcePath(uri)
428 label_subset = uri.fragment or None
430 specifier: LabelSpecifier | None
431 if label_subset is not None:
432 label_subset = urllib.parse.unquote(label_subset)
433 args: dict[str, set[str] | str | None]
434 # labels supplied as a list
435 if "," in label_subset:
436 if ".." in label_subset:
437 raise ValueError(
438 "Can only specify a list of labels or a rangewhen loading a Pipline not both"
439 )
440 args = {"labels": set(label_subset.split(","))}
441 # labels supplied as a range
442 elif ".." in label_subset:
443 # Try to de-structure the labelSubset, this will fail if more
444 # than one range is specified
445 begin, end, *rest = label_subset.split("..")
446 if rest:
447 raise ValueError("Only one range can be specified when loading a pipeline")
448 args = {"begin": begin if begin else None, "end": end if end else None}
449 # Assume anything else is a single label
450 else:
451 args = {"labels": {label_subset}}
453 # MyPy doesn't like how cavalier kwarg construction is with types.
454 specifier = LabelSpecifier(**args) # type: ignore
455 else:
456 specifier = None
458 return uri, specifier
460 @classmethod
461 def fromString(cls, pipeline_string: str) -> Pipeline:
462 """Create a pipeline from string formatted as a pipeline document.
464 Parameters
465 ----------
466 pipeline_string : `str`
467 A string formatted as a pipeline document.
469 Returns
470 -------
471 pipeline: `Pipeline`
472 The new pipeline.
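Examples
--------
A minimal sketch; the task class name is hypothetical and the yaml
layout follows the pipeline document format::

    pipeline = Pipeline.fromString(
        "description: A tiny example pipeline\n"
        "tasks:\n"
        "  myTask:\n"
        "    class: my.package.MyTask\n"
    )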
473 """
474 pipeline = cls.fromIR(pipelineIR.PipelineIR.from_string(pipeline_string))
475 return pipeline
477 @classmethod
478 def fromIR(cls, deserialized_pipeline: pipelineIR.PipelineIR) -> Pipeline:
479 """Create a pipeline from an already created `PipelineIR` object.
481 Parameters
482 ----------
483 deserialized_pipeline : `PipelineIR`
484 An already created pipeline intermediate representation object.
486 Returns
487 -------
488 pipeline: `Pipeline`
489 The new pipeline.
490 """
491 pipeline = cls.__new__(cls)
492 pipeline._pipelineIR = deserialized_pipeline
493 return pipeline
495 @classmethod
496 def fromPipeline(cls, pipeline: Pipeline) -> Pipeline:
497 """Create a new pipeline by copying an already existing `Pipeline`.
499 Parameters
500 ----------
501 pipeline : `Pipeline`
502 An already created pipeline intermediate representation object.
504 Returns
505 -------
506 pipeline: `Pipeline`
507 The new pipeline.
508 """
509 return cls.fromIR(copy.deepcopy(pipeline._pipelineIR))
511 def __str__(self) -> str:
512 return str(self._pipelineIR)
514 def mergePipeline(self, pipeline: Pipeline) -> None:
515 """Merge another in-memory `Pipeline` object into this one.
517 This merges another pipeline into this object, as if it were declared
518 in the import block of the yaml definition of this pipeline. This
519 modifies this pipeline in place.
521 Parameters
522 ----------
523 pipeline : `Pipeline`
524 The `Pipeline` object that is to be merged into this object.
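Examples
--------
A minimal sketch; the file name is hypothetical::

    other = Pipeline.from_uri("other_pipeline.yaml")
    pipeline.mergePipeline(other)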
525 """
526 self._pipelineIR.merge_pipelines((pipeline._pipelineIR,))
528 def addLabelToSubset(self, subset: str, label: str) -> None:
529 """Add a task label from the specified subset.
531 Parameters
532 ----------
533 subset : `str`
534 The labeled subset to modify.
535 label : `str`
536 The task label to add to the specified subset.
538 Raises
539 ------
540 ValueError
541 Raised if the specified subset does not exist within the pipeline.
542 Raised if the specified label does not exist within the pipeline.
543 """
544 if label not in self._pipelineIR.tasks:
545 raise ValueError(f"Label {label} does not appear within the pipeline")
546 if subset not in self._pipelineIR.labeled_subsets:
547 raise ValueError(f"Subset {subset} does not appear within the pipeline")
548 self._pipelineIR.labeled_subsets[subset].subset.add(label)
550 def removeLabelFromSubset(self, subset: str, label: str) -> None:
551 """Remove a task label from the specified subset.
553 Parameters
554 ----------
555 subset : `str`
556 The labeled subset to modify.
557 label : `str`
558 The task label to remove from the specified subset.
560 Raises
561 ------
562 ValueError
563 Raised if the specified subset does not exist in the pipeline.
564 Raised if the specified label does not exist within the specified
565 subset.
566 """
567 if subset not in self._pipelineIR.labeled_subsets:
568 raise ValueError(f"Subset {subset} does not appear within the pipeline")
569 if label not in self._pipelineIR.labeled_subsets[subset].subset:
570 raise ValueError(f"Label {label} does not appear within the pipeline")
571 self._pipelineIR.labeled_subsets[subset].subset.remove(label)
573 def findSubsetsWithLabel(self, label: str) -> set[str]:
574 """Find any subsets which may contain the specified label.
576 This function returns the names of subsets which contain the specified
577 label. It may return an empty set if there are no subsets, or no
578 subsets containing the specified label.
580 Parameters
581 ----------
582 label : `str`
583 The task label to use in membership check.
585 Returns
586 -------
587 subsets : `set` of `str`
588 Returns a set (possibly empty) of subset names which contain the
589 specified label.
591 Raises
592 ------
593 ValueError
594 Raised if the specified label does not exist within this pipeline.
595 """
596 results = set()
597 if label not in self._pipelineIR.tasks:
598 raise ValueError(f"Label {label} does not appear within the pipeline")
599 for subset in self._pipelineIR.labeled_subsets.values():
600 if label in subset.subset:
601 results.add(subset.label)
602 return results
604 @property
605 def subsets(self) -> MappingProxyType[str, set]:
606 """Returns a `MappingProxyType` where the keys are the labels of
607 labeled subsets in the `Pipeline` and the values are the set of task
608 labels contained within that subset.
609 """
610 return MappingProxyType(
611 {label: subsetIr.subset for label, subsetIr in self._pipelineIR.labeled_subsets.items()}
612 )
614 def addLabeledSubset(self, label: str, description: str, taskLabels: set[str]) -> None:
615 """Add a new labeled subset to the `Pipeline`.
617 Parameters
618 ----------
619 label : `str`
620 The label to assign to the subset.
621 description : `str`
622 A description of what the subset is for.
623 taskLabels : `set` [`str`]
624 The set of task labels to be associated with the labeled subset.
626 Raises
627 ------
628 ValueError
629 Raised if label already exists in the `Pipeline`.
630 Raised if a task label is not found within the `Pipeline`.
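Examples
--------
A minimal sketch, using hypothetical task labels that are assumed to
already be present in the pipeline::

    pipeline.addLabeledSubset(
        "step1", "Initial processing tasks", {"isr", "characterizeImage"}
    )
    print(pipeline.subsets["step1"])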
631 """
632 if label in self._pipelineIR.labeled_subsets.keys():
633 raise ValueError(f"Subset label {label} is already found within the Pipeline")
634 if extra := (taskLabels - self._pipelineIR.tasks.keys()):
635 raise ValueError(f"Task labels {extra} were not found within the Pipeline")
636 self._pipelineIR.labeled_subsets[label] = pipelineIR.LabeledSubset(label, taskLabels, description)
638 def removeLabeledSubset(self, label: str) -> None:
639 """Remove a labeled subset from the `Pipeline`.
641 Parameters
642 ----------
643 label : `str`
644 The label of the subset to remove from the `Pipeline`.
646 Raises
647 ------
648 ValueError
649 Raised if the label is not found within the `Pipeline`.
650 """
651 if label not in self._pipelineIR.labeled_subsets.keys():
652 raise ValueError(f"Subset label {label} was not found in the pipeline")
653 self._pipelineIR.labeled_subsets.pop(label)
655 def addInstrument(self, instrument: Instrument | str) -> None:
656 """Add an instrument to the pipeline, or replace an instrument that is
657 already defined.
659 Parameters
660 ----------
661 instrument : `~lsst.obs.base.Instrument` or `str`
662 Either an `lsst.obs.base.Instrument` subclass object or
663 a string corresponding to a fully qualified
664 `lsst.obs.base.Instrument` subclass name.
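Examples
--------
A minimal sketch; the instrument class name below is hypothetical and
only stands in for a real `~lsst.obs.base.Instrument` subclass::

    pipeline.addInstrument("my.obs.package.MyInstrument")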
665 """
666 if isinstance(instrument, str):
667 pass
668 else:
669 # TODO: assume that this is a subclass of Instrument, no type
670 # checking
671 instrument = get_full_type_name(instrument)
672 self._pipelineIR.instrument = instrument
674 def getInstrument(self) -> str | None:
675 """Get the instrument from the pipeline.
677 Returns
678 -------
679 instrument : `str` or `None`
680 The fully qualified name of a `lsst.obs.base.Instrument` subclass,
681 or `None` if the pipeline does not have an instrument.
682 """
683 return self._pipelineIR.instrument
685 def get_data_id(self, universe: DimensionUniverse) -> DataCoordinate:
686 """Return a data ID with all dimension constraints embedded in the
687 pipeline.
689 Parameters
690 ----------
691 universe : `lsst.daf.butler.DimensionUniverse`
692 Object that defines all dimensions.
694 Returns
695 -------
696 data_id : `lsst.daf.butler.DataCoordinate`
697 Data ID with all dimension constraints embedded in the
698 pipeline.
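Examples
--------
A minimal sketch; ``butler`` is assumed to be an existing
`~lsst.daf.butler.Butler` instance::

    data_id = pipeline.get_data_id(butler.dimensions)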
699 """
700 instrument_class_name = self._pipelineIR.instrument
701 if instrument_class_name is not None:
702 instrument_class = cast(PipeBaseInstrument, doImportType(instrument_class_name))
703 if instrument_class is not None:
704 return DataCoordinate.standardize(instrument=instrument_class.getName(), universe=universe)
705 return DataCoordinate.make_empty(universe)
707 def addTask(self, task: type[PipelineTask] | str, label: str) -> None:
708 """Add a new task to the pipeline, or replace a task that is already
709 associated with the supplied label.
711 Parameters
712 ----------
713 task : `PipelineTask` or `str`
714 Either a class object derived from `PipelineTask` or a string
715 corresponding to a fully qualified `PipelineTask` name.
716 label : `str`
717 A label that is used to identify the `PipelineTask` being added.
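Examples
--------
A minimal sketch; the task class name is hypothetical::

    pipeline = Pipeline("An example pipeline")
    pipeline.addTask("my.package.MyTask", "myTask")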
718 """
719 if isinstance(task, str):
720 taskName = task
721 elif issubclass(task, PipelineTask):
722 taskName = get_full_type_name(task)
723 else:
724 raise ValueError(
725 "task must be either a child class of PipelineTask or a string containing"
726 " a fully qualified name to one"
727 )
728 if not label:
729 # In some cases (with a command-line-generated pipeline) tasks can
730 # be defined without a label, which is not acceptable; use the task's
731 # _DefaultName in that case.
732 if isinstance(task, str):
733 task_class = cast(PipelineTask, doImportType(task))
734 label = task_class._DefaultName
735 self._pipelineIR.tasks[label] = pipelineIR.TaskIR(label, taskName)
737 def removeTask(self, label: str) -> None:
738 """Remove a task from the pipeline.
740 Parameters
741 ----------
742 label : `str`
743 The label used to identify the task that is to be removed.
745 Raises
746 ------
747 KeyError
748 If no task with that label exists in the pipeline.
749 """
750 self._pipelineIR.tasks.pop(label)
752 def addConfigOverride(self, label: str, key: str, value: object) -> None:
753 """Apply single config override.
755 Parameters
756 ----------
757 label : `str`
758 Label of the task.
759 key : `str`
760 Fully-qualified field name.
761 value : object
762 Value to be given to a field.
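Examples
--------
A minimal sketch; the task label and config field name are hypothetical::

    pipeline.addConfigOverride("myTask", "someField", 42)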
763 """
764 self._addConfigImpl(label, pipelineIR.ConfigIR(rest={key: value}))
766 def addConfigFile(self, label: str, filename: str) -> None:
767 """Add overrides from a specified file.
769 Parameters
770 ----------
771 label : `str`
772 The label used to identify the task associated with config to
773 modify.
774 filename : `str`
775 Path to the override file.
776 """
777 self._addConfigImpl(label, pipelineIR.ConfigIR(file=[filename]))
779 def addConfigPython(self, label: str, pythonString: str) -> None:
780 """Add Overrides by running a snippet of python code against a config.
782 Parameters
783 ----------
784 label : `str`
785 The label used to identify the task associated with config to
786 modify.
787 pythonString : `str`
788 A string of valid python code to be executed. This is done
789 with ``config`` as the only locally accessible value.
790 """
791 self._addConfigImpl(label, pipelineIR.ConfigIR(python=pythonString))
793 def _addConfigImpl(self, label: str, newConfig: pipelineIR.ConfigIR) -> None:
794 if label == "parameters":
795 self._pipelineIR.parameters.mapping.update(newConfig.rest)
796 if newConfig.file:
797 raise ValueError("Setting parameters section with config file is not supported")
798 if newConfig.python:
799 raise ValueError("Setting parameters section using python block in unsupported")
800 return
801 if label not in self._pipelineIR.tasks:
802 raise LookupError(f"There are no tasks labeled '{label}' in the pipeline")
803 self._pipelineIR.tasks[label].add_or_update_config(newConfig)
805 def write_to_uri(self, uri: ResourcePathExpression) -> None:
806 """Write the pipeline to a file or directory.
808 Parameters
809 ----------
810 uri : convertible to `~lsst.resources.ResourcePath`
811 URI to write to; may have any scheme with
812 `~lsst.resources.ResourcePath` write support or no scheme for a
813 local file/directory. Should have a ``.yaml`` extension.
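Examples
--------
A minimal sketch; the output path is hypothetical::

    pipeline.write_to_uri("my_pipeline.yaml")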
814 """
815 self._pipelineIR.write_to_uri(uri)
817 def to_graph(self, registry: Registry | None = None) -> pipeline_graph.PipelineGraph:
818 """Construct a pipeline graph from this pipeline.
820 Constructing a graph applies all configuration overrides, freezes all
821 configuration, checks all contracts, and checks for dataset type
822 consistency between tasks (as much as possible without access to a data
823 repository). It cannot be reversed.
825 Parameters
826 ----------
827 registry : `lsst.daf.butler.Registry`, optional
828 Data repository client. If provided, the graph's dataset types
829 and dimensions will be resolved (see `PipelineGraph.resolve`).
831 Returns
832 -------
833 graph : `pipeline_graph.PipelineGraph`
834 Representation of the pipeline as a graph.
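Examples
--------
A minimal sketch; the butler repository path is hypothetical, and
passing a registry is optional::

    from lsst.daf.butler import Butler

    butler = Butler("/path/to/repo")
    graph = pipeline.to_graph(registry=butler.registry)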
835 """
836 instrument_class_name = self._pipelineIR.instrument
837 data_id = {}
838 if instrument_class_name is not None:
839 instrument_class: type[Instrument] = doImportType(instrument_class_name)
840 if instrument_class is not None:
841 data_id["instrument"] = instrument_class.getName()
842 graph = pipeline_graph.PipelineGraph(data_id=data_id)
843 graph.description = self._pipelineIR.description
844 for label in self._pipelineIR.tasks:
845 self._add_task_to_graph(label, graph)
846 if self._pipelineIR.contracts is not None:
847 label_to_config = {x.label: x.config for x in graph.tasks.values()}
848 for contract in self._pipelineIR.contracts:
849 # Execute this on its own line so it can raise a good error
850 # message if there were problems with the eval.
851 success = eval(contract.contract, None, label_to_config)
852 if not success:
853 extra_info = f": {contract.msg}" if contract.msg is not None else ""
854 raise pipelineIR.ContractError(
855 f"Contract(s) '{contract.contract}' were not satisfied{extra_info}"
856 )
857 for label, subset in self._pipelineIR.labeled_subsets.items():
858 graph.add_task_subset(
859 label, subset.subset, subset.description if subset.description is not None else ""
860 )
861 graph.sort()
862 if registry is not None:
863 graph.resolve(registry)
864 return graph
866 def toExpandedPipeline(self) -> Generator[TaskDef, None, None]:
867 r"""Return a generator of `TaskDef`\s which can be used to create
868 quantum graphs.
870 Returns
871 -------
872 generator : generator of `TaskDef`
873 The generator returned will be the sorted iterator of tasks which
874 are to be used in constructing a quantum graph.
876 Raises
877 ------
878 NotImplementedError
879 If a dataId is supplied in a config block. This is in place for
880 future use.
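Examples
--------
A minimal sketch, printing the label of every task in the expanded
pipeline (iterating over the pipeline itself is equivalent)::

    for task_def in pipeline.toExpandedPipeline():
        print(task_def.label)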
881 """
882 yield from self.to_graph()._iter_task_defs()
884 def _add_task_to_graph(self, label: str, graph: pipeline_graph.PipelineGraph) -> None:
885 """Add a single task from this pipeline to a pipeline graph that is
886 under construction.
888 Parameters
889 ----------
890 label : `str`
891 Label for the task to be added.
892 graph : `pipeline_graph.PipelineGraph`
893 Graph to add the task to.
894 """
895 if (taskIR := self._pipelineIR.tasks.get(label)) is None:
896 raise NameError(f"Label {label} does not appear in this pipeline")
897 taskClass: type[PipelineTask] = doImportType(taskIR.klass)
898 config = taskClass.ConfigClass()
899 instrument: PipeBaseInstrument | None = None
900 if (instrumentName := self._pipelineIR.instrument) is not None:
901 instrument_cls: type = doImportType(instrumentName)
902 instrument = instrument_cls()
903 config.applyConfigOverrides(
904 instrument,
905 getattr(taskClass, "_DefaultName", ""),
906 taskIR.config,
907 self._pipelineIR.parameters,
908 label,
909 )
910 graph.add_task(label, taskClass, config)
912 def __iter__(self) -> Generator[TaskDef, None, None]:
913 return self.toExpandedPipeline()
915 def __getitem__(self, item: str) -> TaskDef:
916 # Making a whole graph and then making a TaskDef from that is pretty
917 # backwards, but I'm hoping to deprecate this method shortly in favor
918 # of making the graph explicitly and working with its node objects.
919 graph = pipeline_graph.PipelineGraph()
920 self._add_task_to_graph(item, graph)
921 (result,) = graph._iter_task_defs()
922 return result
924 def __len__(self) -> int:
925 return len(self._pipelineIR.tasks)
927 def __eq__(self, other: object) -> bool:
928 if not isinstance(other, Pipeline):
929 return False
930 elif self._pipelineIR == other._pipelineIR:
931 # Shortcut: if the IR is the same, the expanded pipeline must be
932 # the same as well. But the converse is not true.
933 return True
934 else:
935 self_expanded = {td.label: (td.taskClass,) for td in self}
936 other_expanded = {td.label: (td.taskClass,) for td in other}
937 if self_expanded != other_expanded:
938 return False
939 # After DM-27847, we should compare configuration here, or better,
940 # delegated to TaskDef.__eq__ after making that compare configurations.
941 raise NotImplementedError(
942 "Pipelines cannot be compared because config instances cannot be compared; see DM-27847."
943 )
946@dataclass(frozen=True)
947class TaskDatasetTypes:
948 """An immutable struct that extracts and classifies the dataset types used
949 by a `PipelineTask`.
950 """
952 initInputs: NamedValueSet[DatasetType]
953 """Dataset types that are needed as inputs in order to construct this Task.
955 Task-level `initInputs` may be classified as either
956 `~PipelineDatasetTypes.initInputs` or
957 `~PipelineDatasetTypes.initIntermediates` at the Pipeline level.
958 """
960 initOutputs: NamedValueSet[DatasetType]
961 """Dataset types that may be written after constructing this Task.
963 Task-level `initOutputs` may be classified as either
964 `~PipelineDatasetTypes.initOutputs` or
965 `~PipelineDatasetTypes.initIntermediates` at the Pipeline level.
966 """
968 inputs: NamedValueSet[DatasetType]
969 """Dataset types that are regular inputs to this Task.
971 If an input dataset needed for a Quantum cannot be found in the input
972 collection(s) or produced by another Task in the Pipeline, that Quantum
973 (and all dependent Quanta) will not be produced.
975 Task-level `inputs` may be classified as either
976 `~PipelineDatasetTypes.inputs` or `~PipelineDatasetTypes.intermediates`
977 at the Pipeline level.
978 """
980 queryConstraints: NamedValueSet[DatasetType]
981 """Regular inputs that should not be used as constraints on the initial
982 QuantumGraph generation data ID query, according to their tasks
983 (`NamedValueSet`).
984 """
986 prerequisites: NamedValueSet[DatasetType]
987 """Dataset types that are prerequisite inputs to this Task.
989 Prerequisite inputs must exist in the input collection(s) before the
990 pipeline is run, but do not constrain the graph - if a prerequisite is
991 missing for a Quantum, `PrerequisiteMissingError` is raised.
993 Prerequisite inputs are not resolved until the second stage of
994 QuantumGraph generation.
995 """
997 outputs: NamedValueSet[DatasetType]
998 """Dataset types that are produced by this Task.
1000 Task-level `outputs` may be classified as either
1001 `~PipelineDatasetTypes.outputs` or `~PipelineDatasetTypes.intermediates`
1002 at the Pipeline level.
1003 """
1005 @classmethod
1006 def fromTaskDef(
1007 cls,
1008 taskDef: TaskDef,
1009 *,
1010 registry: Registry,
1011 include_configs: bool = True,
1012 storage_class_mapping: Mapping[str, str] | None = None,
1013 ) -> TaskDatasetTypes:
1014 """Extract and classify the dataset types from a single `PipelineTask`.
1016 Parameters
1017 ----------
1018 taskDef : `TaskDef`
1019 An instance of a `TaskDef` class for a particular `PipelineTask`.
1020 registry : `Registry`
1021 Registry used to construct normalized
1022 `~lsst.daf.butler.DatasetType` objects and retrieve those that are
1023 incomplete.
1024 include_configs : `bool`, optional
1025 If `True` (default) include config dataset types as
1026 ``initOutputs``.
1027 storage_class_mapping : `~collections.abc.Mapping` of `str` to \
1028 `~lsst.daf.butler.StorageClass`, optional
1029 If a taskdef contains a component dataset type that is unknown
1030 to the registry, its parent `~lsst.daf.butler.StorageClass` will
1031 be looked up in this mapping if it is supplied. If the mapping does
1032 not contain the composite dataset type, or the mapping is not
1033 supplied, an exception will be raised.
1035 Returns
1036 -------
1037 types: `TaskDatasetTypes`
1038 The dataset types used by this task.
1040 Raises
1041 ------
1042 ValueError
1043 Raised if dataset type connection definition differs from
1044 registry definition.
1045 LookupError
1046 Raised if component parent StorageClass could not be determined
1047 and storage_class_mapping does not contain the composite type, or
1048 is set to None.
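Examples
--------
A minimal sketch; the butler repository path is hypothetical and
``task_def`` is assumed to be an existing `TaskDef` instance::

    from lsst.daf.butler import Butler

    butler = Butler("/path/to/repo")
    dataset_types = TaskDatasetTypes.fromTaskDef(task_def, registry=butler.registry)
    print(dataset_types.inputs.names)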
1049 """
1051 def makeDatasetTypesSet(
1052 connectionType: str,
1053 is_input: bool,
1054 freeze: bool = True,
1055 ) -> NamedValueSet[DatasetType]:
1056 """Construct a set of true `~lsst.daf.butler.DatasetType` objects.
1058 Parameters
1059 ----------
1060 connectionType : `str`
1061 Name of the connection type to produce a set for, corresponds
1062 to an attribute of type `list` on the connection class
1063 instance.
1064 is_input : `bool`
1065 If `True`, these are input dataset types; otherwise they are
1066 output dataset types.
1067 freeze : `bool`, optional
1068 If `True`, call `NamedValueSet.freeze` on the object returned.
1070 Returns
1071 -------
1072 datasetTypes : `NamedValueSet`
1073 A set of all datasetTypes which correspond to the input
1074 connection type specified in the connection class of this
1075 `PipelineTask`.
1077 Raises
1078 ------
1079 ValueError
1080 Raised if dataset type connection definition differs from
1081 registry definition.
1082 LookupError
1083 Raised if component parent StorageClass could not be determined
1084 and storage_class_mapping does not contain the composite type,
1085 or is set to None.
1087 Notes
1088 -----
1089 This function is a closure over the variables ``registry``,
1090 ``taskDef``, and ``storage_class_mapping``.
1091 """
1092 datasetTypes = NamedValueSet[DatasetType]()
1093 for c in iterConnections(taskDef.connections, connectionType):
1094 dimensions = set(getattr(c, "dimensions", set()))
1095 if "skypix" in dimensions:
1096 try:
1097 datasetType = registry.getDatasetType(c.name)
1098 except LookupError as err:
1099 raise LookupError(
1100 f"DatasetType '{c.name}' referenced by "
1101 f"{type(taskDef.connections).__name__} uses 'skypix' as a dimension "
1102 "placeholder, but does not already exist in the registry. "
1103 "Note that reference catalog names are now used as the dataset "
1104 "type name instead of 'ref_cat'."
1105 ) from err
1106 rest1 = set(registry.dimensions.conform(dimensions - {"skypix"}).names)
1107 rest2 = datasetType.dimensions.names - datasetType.dimensions.skypix.names
1108 if rest1 != rest2:
1109 raise ValueError(
1110 f"Non-skypix dimensions for dataset type {c.name} declared in "
1111 f"connections ({rest1}) are inconsistent with those in "
1112 f"registry's version of this dataset ({rest2})."
1113 )
1114 else:
1115 # Component dataset types are not explicitly in the
1116 # registry. This complicates consistency checks with
1117 # registry and requires we work out the composite storage
1118 # class.
1119 registryDatasetType = None
1120 try:
1121 registryDatasetType = registry.getDatasetType(c.name)
1122 except KeyError:
1123 compositeName, componentName = DatasetType.splitDatasetTypeName(c.name)
1124 if componentName:
1125 if storage_class_mapping is None or compositeName not in storage_class_mapping:
1126 raise LookupError(
1127 "Component parent class cannot be determined, and "
1128 "composite name was not in storage class mapping, or no "
1129 "storage_class_mapping was supplied"
1130 ) from None
1131 else:
1132 parentStorageClass = storage_class_mapping[compositeName]
1133 else:
1134 parentStorageClass = None
1135 datasetType = c.makeDatasetType(
1136 registry.dimensions, parentStorageClass=parentStorageClass
1137 )
1138 registryDatasetType = datasetType
1139 else:
1140 datasetType = c.makeDatasetType(
1141 registry.dimensions, parentStorageClass=registryDatasetType.parentStorageClass
1142 )
1144 if registryDatasetType and datasetType != registryDatasetType:
1145 # The dataset types differ but first check to see if
1146 # they are compatible before raising.
1147 if is_input:
1148 # This DatasetType must be compatible on get.
1149 is_compatible = datasetType.is_compatible_with(registryDatasetType)
1150 else:
1151 # Has to be able to be converted to the expected
1152 # type on put.
1153 is_compatible = registryDatasetType.is_compatible_with(datasetType)
1154 if is_compatible:
1155 # For inputs we want the pipeline to use the
1156 # pipeline definition, for outputs it should use
1157 # the registry definition.
1158 if not is_input:
1159 datasetType = registryDatasetType
1160 _LOG.debug(
1161 "Dataset types differ (task %s != registry %s) but are compatible"
1162 " for %s in %s.",
1163 datasetType,
1164 registryDatasetType,
1165 "input" if is_input else "output",
1166 taskDef.label,
1167 )
1168 else:
1169 try:
1170 # Explicitly check for storage class just to
1171 # make more specific message.
1172 _ = datasetType.storageClass
1173 except KeyError:
1174 raise ValueError(
1175 "Storage class does not exist for supplied dataset type "
1176 f"{datasetType} for {taskDef.label}."
1177 ) from None
1178 raise ValueError(
1179 f"Supplied dataset type ({datasetType}) inconsistent with "
1180 f"registry definition ({registryDatasetType}) "
1181 f"for {taskDef.label}."
1182 )
1183 datasetTypes.add(datasetType)
1184 if freeze:
1185 datasetTypes.freeze()
1186 return datasetTypes
1188 # optionally add initOutput dataset for config
1189 initOutputs = makeDatasetTypesSet("initOutputs", is_input=False, freeze=False)
1190 if include_configs:
1191 initOutputs.add(
1192 DatasetType(
1193 taskDef.configDatasetName,
1194 registry.dimensions.empty,
1195 storageClass=acc.CONFIG_INIT_OUTPUT_STORAGE_CLASS,
1196 )
1197 )
1198 initOutputs.freeze()
1200 # optionally add output dataset for metadata
1201 outputs = makeDatasetTypesSet("outputs", is_input=False, freeze=False)
1203 # Metadata is supposed to be of the TaskMetadata type; its dimensions
1204 # correspond to a task quantum.
1205 dimensions = registry.dimensions.conform(taskDef.connections.dimensions)
1207 # Allow the storage class definition to be read from the existing
1208 # dataset type definition if present.
1209 try:
1210 current = registry.getDatasetType(taskDef.metadataDatasetName)
1211 except KeyError:
1212 # No previous definition so use the default.
1213 storageClass = acc.METADATA_OUTPUT_STORAGE_CLASS
1214 else:
1215 storageClass = current.storageClass.name
1216 outputs.update({DatasetType(taskDef.metadataDatasetName, dimensions, storageClass)})
1218 if taskDef.logOutputDatasetName is not None:
1219 # Log output dimensions correspond to a task quantum.
1220 dimensions = registry.dimensions.conform(taskDef.connections.dimensions)
1221 outputs.update(
1222 {
1223 DatasetType(
1224 taskDef.logOutputDatasetName,
1225 dimensions,
1226 acc.LOG_OUTPUT_STORAGE_CLASS,
1227 )
1228 }
1229 )
1231 outputs.freeze()
1233 inputs = makeDatasetTypesSet("inputs", is_input=True)
1234 queryConstraints = NamedValueSet(
1235 inputs[c.name]
1236 for c in cast(Iterable[Input], iterConnections(taskDef.connections, "inputs"))
1237 if not c.deferGraphConstraint
1238 )
1240 return cls(
1241 initInputs=makeDatasetTypesSet("initInputs", is_input=True),
1242 initOutputs=initOutputs,
1243 inputs=inputs,
1244 queryConstraints=queryConstraints,
1245 prerequisites=makeDatasetTypesSet("prerequisiteInputs", is_input=True),
1246 outputs=outputs,
1247 )
1250@dataclass(frozen=True)
1251class PipelineDatasetTypes:
1252 """An immutable struct that classifies the dataset types used in a
1253 `Pipeline`.
1254 """
1256 packagesDatasetName: ClassVar[str] = acc.PACKAGES_INIT_OUTPUT_NAME
1257 """Name of a dataset type used to save package versions.
1258 """
1260 initInputs: NamedValueSet[DatasetType]
1261 """Dataset types that are needed as inputs in order to construct the Tasks
1262 in this Pipeline.
1264 This does not include dataset types that are produced when constructing
1265 other Tasks in the Pipeline (these are classified as `initIntermediates`).
1266 """
1268 initOutputs: NamedValueSet[DatasetType]
1269 """Dataset types that may be written after constructing the Tasks in this
1270 Pipeline.
1272 This does not include dataset types that are also used as inputs when
1273 constructing other Tasks in the Pipeline (these are classified as
1274 `initIntermediates`).
1275 """
1277 initIntermediates: NamedValueSet[DatasetType]
1278 """Dataset types that are both used when constructing one or more Tasks
1279 in the Pipeline and produced as a side-effect of constructing another
1280 Task in the Pipeline.
1281 """
1283 inputs: NamedValueSet[DatasetType]
1284 """Dataset types that are regular inputs for the full pipeline.
1286 If an input dataset needed for a Quantum cannot be found in the input
1287 collection(s), that Quantum (and all dependent Quanta) will not be
1288 produced.
1289 """
1291 queryConstraints: NamedValueSet[DatasetType]
1292 """Regular inputs that should be used as constraints on the initial
1293 QuantumGraph generation data ID query, according to their tasks
1294 (`NamedValueSet`).
1295 """
1297 prerequisites: NamedValueSet[DatasetType]
1298 """Dataset types that are prerequisite inputs for the full Pipeline.
1300 Prerequisite inputs must exist in the input collection(s) before the
1301 pipeline is run, but do not constrain the graph - if a prerequisite is
1302 missing for a Quantum, `PrerequisiteMissingError` is raised.
1304 Prerequisite inputs are not resolved until the second stage of
1305 QuantumGraph generation.
1306 """
1308 intermediates: NamedValueSet[DatasetType]
1309 """Dataset types that are output by one Task in the Pipeline and consumed
1310 as inputs by one or more other Tasks in the Pipeline.
1311 """
1313 outputs: NamedValueSet[DatasetType]
1314 """Dataset types that are output by a Task in the Pipeline and not consumed
1315 by any other Task in the Pipeline.
1316 """
1318 byTask: Mapping[str, TaskDatasetTypes]
1319 """Per-Task dataset types, keyed by label in the `Pipeline`.
1321 This is guaranteed to be zip-iterable with the `Pipeline` itself (assuming
1322 neither has been modified since the dataset types were extracted, of
1323 course).
1324 """
1326 @classmethod
1327 def fromPipeline(
1328 cls,
1329 pipeline: Pipeline | Iterable[TaskDef],
1330 *,
1331 registry: Registry,
1332 include_configs: bool = True,
1333 include_packages: bool = True,
1334 ) -> PipelineDatasetTypes:
1335 """Extract and classify the dataset types from all tasks in a
1336 `Pipeline`.
1338 Parameters
1339 ----------
1340 pipeline : `Pipeline` or `~collections.abc.Iterable` [ `TaskDef` ]
1341 A collection of tasks that can be run together.
1342 registry : `Registry`
1343 Registry used to construct normalized
1344 `~lsst.daf.butler.DatasetType` objects and retrieve those that are
1345 incomplete.
1346 include_configs : `bool`, optional
1347 If `True` (default) include config dataset types as
1348 ``initOutputs``.
1349 include_packages : `bool`, optional
1350 If `True` (default) include the dataset type for software package
1351 versions in ``initOutputs``.
1353 Returns
1354 -------
1355 types: `PipelineDatasetTypes`
1356 The dataset types used by this `Pipeline`.
1358 Raises
1359 ------
1360 ValueError
1361 Raised if Tasks are inconsistent about which datasets are marked
1362 prerequisite. This indicates that the Tasks cannot be run as part
1363 of the same `Pipeline`.
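Examples
--------
A minimal sketch; the pipeline file and butler repository path are
hypothetical::

    from lsst.daf.butler import Butler

    butler = Butler("/path/to/repo")
    pipeline = Pipeline.from_uri("my_pipeline.yaml")
    dataset_types = PipelineDatasetTypes.fromPipeline(pipeline, registry=butler.registry)
    print(dataset_types.inputs.names)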
1364 """
1365 allInputs = NamedValueSet[DatasetType]()
1366 allOutputs = NamedValueSet[DatasetType]()
1367 allInitInputs = NamedValueSet[DatasetType]()
1368 allInitOutputs = NamedValueSet[DatasetType]()
1369 prerequisites = NamedValueSet[DatasetType]()
1370 queryConstraints = NamedValueSet[DatasetType]()
1371 byTask = dict()
1372 if include_packages:
1373 allInitOutputs.add(
1374 DatasetType(
1375 cls.packagesDatasetName,
1376 registry.dimensions.empty,
1377 storageClass=acc.PACKAGES_INIT_OUTPUT_STORAGE_CLASS,
1378 )
1379 )
1380 # create a list of TaskDefs in case the input is a generator
1381 pipeline = list(pipeline)
1383 # collect all the output dataset types
1384 typeStorageclassMap: dict[str, str] = {}
1385 for taskDef in pipeline:
1386 for outConnection in iterConnections(taskDef.connections, "outputs"):
1387 typeStorageclassMap[outConnection.name] = outConnection.storageClass
1389 for taskDef in pipeline:
1390 thisTask = TaskDatasetTypes.fromTaskDef(
1391 taskDef,
1392 registry=registry,
1393 include_configs=include_configs,
1394 storage_class_mapping=typeStorageclassMap,
1395 )
1396 allInitInputs.update(thisTask.initInputs)
1397 allInitOutputs.update(thisTask.initOutputs)
1398 allInputs.update(thisTask.inputs)
1399 # Inputs are query constraints if any task considers them a query
1400 # constraint.
1401 queryConstraints.update(thisTask.queryConstraints)
1402 prerequisites.update(thisTask.prerequisites)
1403 allOutputs.update(thisTask.outputs)
1404 byTask[taskDef.label] = thisTask
1405 if not prerequisites.isdisjoint(allInputs):
1406 raise ValueError(
1407 "{} marked as both prerequisites and regular inputs".format(
1408 {dt.name for dt in allInputs & prerequisites}
1409 )
1410 )
1411 if not prerequisites.isdisjoint(allOutputs):
1412 raise ValueError(
1413 "{} marked as both prerequisites and outputs".format(
1414 {dt.name for dt in allOutputs & prerequisites}
1415 )
1416 )
1417 # Make sure that components which are marked as inputs get treated as
1418 # intermediates if there is an output which produces the composite
1419 # containing the component
1420 intermediateComponents = NamedValueSet[DatasetType]()
1421 intermediateComposites = NamedValueSet[DatasetType]()
1422 outputNameMapping = {dsType.name: dsType for dsType in allOutputs}
1423 for dsType in allInputs:
1424 # get the name of a possible component
1425 name, component = dsType.nameAndComponent()
1426 # if there is a component name, that means this is a component
1427 # DatasetType, if there is an output which produces the parent of
1428 # this component, treat this input as an intermediate
1429 if component is not None:
1430 # This needs to be in this if block, because someone might have
1431 # a composite that is a pure input from existing data
1432 if name in outputNameMapping:
1433 intermediateComponents.add(dsType)
1434 intermediateComposites.add(outputNameMapping[name])
1436 def checkConsistency(a: NamedValueSet, b: NamedValueSet) -> None:
1437 common = a.names & b.names
1438 for name in common:
1439 # Any compatibility is allowed. This function does not know
1440 # if a dataset type is to be used for input or output.
1441 if not (a[name].is_compatible_with(b[name]) or b[name].is_compatible_with(a[name])):
1442 raise ValueError(f"Conflicting definitions for dataset type: {a[name]} != {b[name]}.")
1444 checkConsistency(allInitInputs, allInitOutputs)
1445 checkConsistency(allInputs, allOutputs)
1446 checkConsistency(allInputs, intermediateComposites)
1447 checkConsistency(allOutputs, intermediateComposites)
1449 def frozen(s: Set[DatasetType]) -> NamedValueSet[DatasetType]:
1450 assert isinstance(s, NamedValueSet)
1451 s.freeze()
1452 return s
1454 inputs = frozen(allInputs - allOutputs - intermediateComponents)
1456 return cls(
1457 initInputs=frozen(allInitInputs - allInitOutputs),
1458 initIntermediates=frozen(allInitInputs & allInitOutputs),
1459 initOutputs=frozen(allInitOutputs - allInitInputs),
1460 inputs=inputs,
1461 queryConstraints=frozen(queryConstraints & inputs),
1462 # If there are storage class differences in inputs and outputs
1463 # the intermediates have to choose priority. Here choose that
1464 # inputs to tasks must match the requested storage class by
1465 # applying the inputs over the top of the outputs.
1466 intermediates=frozen(allOutputs & allInputs | intermediateComponents),
1467 outputs=frozen(allOutputs - allInputs - intermediateComposites),
1468 prerequisites=frozen(prerequisites),
1469 byTask=MappingProxyType(byTask), # MappingProxyType -> frozen view of dict for immutability
1470 )
1472 @classmethod
1473 def initOutputNames(
1474 cls,
1475 pipeline: Pipeline | Iterable[TaskDef],
1476 *,
1477 include_configs: bool = True,
1478 include_packages: bool = True,
1479 ) -> Iterator[str]:
1480 """Return the names of dataset types ot task initOutputs, Configs,
1481 and package versions for a pipeline.
1483 Parameters
1484 ----------
1485 pipeline : `Pipeline` or `~collections.abc.Iterable` [ `TaskDef` ]
1486 A `Pipeline` instance or collection of `TaskDef` instances.
1487 include_configs : `bool`, optional
1488 If `True` (default) include config dataset types.
1489 include_packages : `bool`, optional
1490 If `True` (default) include the dataset type for package versions.
1492 Yields
1493 ------
1494 datasetTypeName : `str`
1495 Name of the dataset type.
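Examples
--------
A minimal sketch; ``pipeline`` is assumed to be an existing `Pipeline`
instance::

    for name in PipelineDatasetTypes.initOutputNames(pipeline):
        print(name)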
1496 """
1497 if include_packages:
1498 # Package versions dataset type
1499 yield cls.packagesDatasetName
1501 if isinstance(pipeline, Pipeline):
1502 pipeline = pipeline.toExpandedPipeline()
1504 for taskDef in pipeline:
1505 # all task InitOutputs
1506 for name in taskDef.connections.initOutputs:
1507 attribute = getattr(taskDef.connections, name)
1508 yield attribute.name
1510 # config dataset name
1511 if include_configs:
1512 yield taskDef.configDatasetName