Coverage for python / lsst / pipe / base / pipeline_graph / _tasks.py: 34%
315 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-26 08:59 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-26 08:59 +0000
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
27from __future__ import annotations
29__all__ = ("TaskImportMode", "TaskInitNode", "TaskNode")
31import dataclasses
32import enum
33from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
34from typing import TYPE_CHECKING, Any, cast
36from lsst.daf.butler import (
37 DataCoordinate,
38 DatasetRef,
39 DatasetType,
40 DimensionGroup,
41 DimensionUniverse,
42 Registry,
43)
44from lsst.pex.config import FieldValidationError
45from lsst.utils.classes import immutable
46from lsst.utils.doImport import doImportType
47from lsst.utils.introspection import get_full_type_name
49from .. import automatic_connection_constants as acc
50from ..connections import PipelineTaskConnections
51from ..connectionTypes import BaseConnection, BaseInput, InitOutput, Output
52from ._edges import Edge, ReadEdge, WriteEdge
53from ._exceptions import TaskNotImportedError, UnresolvedGraphError
54from ._nodes import NodeKey, NodeType
56if TYPE_CHECKING:
57 from ..config import PipelineTaskConfig
58 from ..pipelineTask import PipelineTask
class TaskImportMode(enum.Enum):
    """Enumeration of the ways to handle importing tasks when reading a
    serialized PipelineGraph.
    """

    DO_NOT_IMPORT = enum.auto()
    """Do not import tasks or instantiate their configs and connections."""

    REQUIRE_CONSISTENT_EDGES = enum.auto()
    """Import tasks and instantiate their config and connection objects, and
    check that the connections still define the same edges.
    """

    ASSUME_CONSISTENT_EDGES = enum.auto()
    """Import tasks and instantiate their config and connection objects, but
    skip checking that the connections still define the same edges.

    This is safe only when the caller knows the task definition has not
    changed since the pipeline graph was persisted, such as when it was saved
    and loaded with the same pipeline version.
    """

    OVERRIDE_EDGES = enum.auto()
    """Import tasks and instantiate their config and connection objects, and
    allow the edges defined in those connections to override those in the
    persisted graph.

    This may cause dataset type nodes to be unresolved, since resolutions
    consistent with the original edges may be invalidated.
    """
93@dataclasses.dataclass(frozen=True)
94class _TaskNodeImportedData:
95 """An internal struct that holds `TaskNode` and `TaskInitNode` state that
96 requires task classes to be imported.
97 """
99 task_class: type[PipelineTask]
100 """Type object for the task."""
102 config: PipelineTaskConfig
103 """Configuration object for the task."""
105 connection_map: dict[str, BaseConnection]
106 """Mapping from connection name to connection.
108 In addition to ``connections.allConnections``, this also holds the
109 "automatic" config, log, and metadata connections using the names defined
110 in the `.automatic_connection_constants` module.
111 """
113 connections: PipelineTaskConnections
114 """Configured connections object for the task."""
116 @classmethod
117 def configure(
118 cls,
119 label: str,
120 task_class: type[PipelineTask],
121 config: PipelineTaskConfig,
122 connections: PipelineTaskConnections | None = None,
123 ) -> _TaskNodeImportedData:
124 """Construct while creating a `PipelineTaskConnections` instance if
125 necessary.
127 Parameters
128 ----------
129 label : `str`
130 Label for the task in the pipeline. Only used in error messages.
131 task_class : `type` [ `.PipelineTask` ]
132 Pipeline task `type` object.
133 config : `.PipelineTaskConfig`
134 Configuration for the task.
135 connections : `.PipelineTaskConnections`, optional
136 Object that describes the dataset types used by the task. If not
137 provided, one will be constructed from the given configuration. If
138 provided, it is assumed that ``config`` has already been validated
139 and frozen.
141 Returns
142 -------
143 data : `_TaskNodeImportedData`
144 Instance of this struct.
145 """
146 if connections is None:
147 # If we don't have connections yet, assume the config hasn't been
148 # validated yet.
149 try:
150 config.validate()
151 except FieldValidationError as err:
152 err.fullname = f"{label}: {err.fullname}"
153 raise err
154 except Exception as err:
155 raise ValueError(
156 f"Configuration validation failed for task {label!r} (see chained exception)."
157 ) from err
158 config.freeze()
159 # MyPy doesn't see the metaclass attribute defined for this.
160 connections = config.ConnectionsClass(config=config) # type: ignore
161 connection_map = dict(connections.allConnections)
162 connection_map[acc.CONFIG_INIT_OUTPUT_CONNECTION_NAME] = InitOutput(
163 acc.CONFIG_INIT_OUTPUT_TEMPLATE.format(label=label),
164 acc.CONFIG_INIT_OUTPUT_STORAGE_CLASS,
165 )
166 connection_map[acc.METADATA_OUTPUT_CONNECTION_NAME] = Output(
167 acc.METADATA_OUTPUT_TEMPLATE.format(label=label),
168 acc.METADATA_OUTPUT_STORAGE_CLASS,
169 dimensions=set(connections.dimensions),
170 )
171 if config.saveLogOutput:
172 connection_map[acc.LOG_OUTPUT_CONNECTION_NAME] = Output(
173 acc.LOG_OUTPUT_TEMPLATE.format(label=label),
174 acc.LOG_OUTPUT_STORAGE_CLASS,
175 dimensions=set(connections.dimensions),
176 )
177 return cls(task_class, config, connection_map, connections)
@immutable
class TaskInitNode:
    """A node in a pipeline graph that represents the construction of a
    `PipelineTask`.

    Parameters
    ----------
    key : `NodeKey`
        Key that identifies this node in internal and exported networkx
        graphs.
    inputs : `~collections.abc.Mapping` [ `str`, `ReadEdge` ]
        Graph edges that represent inputs required just to construct an
        instance of this task, keyed by connection name.
    outputs : `~collections.abc.Mapping` [ `str`, `WriteEdge` ]
        Graph edges that represent outputs of this task that are available
        after just constructing it, keyed by connection name.

        This does not include the special `config_output` edge; use
        `iter_all_outputs` to include that, too.
    config_output : `WriteEdge`
        The special init output edge that persists the task's configuration.
    imported_data : `_TaskNodeImportedData`, optional
        Internal struct that holds information that requires the task class
        to have been imported.
    task_class_name : `str`, optional
        Fully-qualified name of the task class.  Must be provided if
        ``imported_data`` is not.
    config_str : `str`, optional
        Configuration for the task as a string of override statements.  Must
        be provided if ``imported_data`` is not.

    Notes
    -----
    When included in an exported `networkx` graph (e.g.
    `PipelineGraph.make_xgraph`), task initialization nodes set the following
    node attributes:

    - ``task_class_name``
    - ``bipartite`` (see `NodeType.bipartite`)
    - ``task_class`` (only if `is_imported` is `True`)
    - ``config`` (only if `is_imported` is `True`)
    """

    def __init__(
        self,
        key: NodeKey,
        *,
        inputs: Mapping[str, ReadEdge],
        outputs: Mapping[str, WriteEdge],
        config_output: WriteEdge,
        imported_data: _TaskNodeImportedData | None = None,
        task_class_name: str | None = None,
        config_str: str | None = None,
    ):
        self.key = key
        self.inputs = inputs
        self.outputs = outputs
        self.config_output = config_output
        # Instead of setting attributes to None, we do not set them at all;
        # this works better with the @immutable decorator, which supports
        # deferred initialization but not reassignment.
        if task_class_name is not None:
            self._task_class_name = task_class_name
        if config_str is not None:
            self._config_str = config_str
        if imported_data is not None:
            self._imported_data = imported_data
        else:
            # Check the parameters rather than the attributes: the attributes
            # are unset (not None) when absent, so reading them here would
            # raise AttributeError instead of this AssertionError.
            assert task_class_name is not None and config_str is not None, (
                "If imported_data is not present, task_class_name and config_str must be."
            )

    key: NodeKey
    """Key that identifies this node in internal and exported networkx graphs.
    """

    inputs: Mapping[str, ReadEdge]
    """Graph edges that represent inputs required just to construct an instance
    of this task, keyed by connection name.
    """

    outputs: Mapping[str, WriteEdge]
    """Graph edges that represent outputs of this task that are available after
    just constructing it, keyed by connection name.

    This does not include the special `config_output` edge; use
    `iter_all_outputs` to include that, too.
    """

    config_output: WriteEdge
    """The special output edge that persists the task's configuration.
    """

    @property
    def label(self) -> str:
        """Label of this configuration of a task in the pipeline."""
        return str(self.key)

    @property
    def is_imported(self) -> bool:
        """Whether the task type for this node has been imported and its
        configuration overrides applied.

        If this is `False`, the `task_class` and `config` attributes may not
        be accessed.
        """
        return hasattr(self, "_imported_data")

    @property
    def task_class(self) -> type[PipelineTask]:
        """Type object for the task.

        Accessing this attribute when `is_imported` is `False` will raise
        `TaskNotImportedError`, but accessing `task_class_name` will not.
        """
        return self._get_imported_data().task_class

    @property
    def task_class_name(self) -> str:
        """The fully-qualified string name of the task class."""
        try:
            return self._task_class_name
        except AttributeError:
            pass
        # Lazily compute and cache the name from the imported task class.
        self._task_class_name = get_full_type_name(self.task_class)
        return self._task_class_name

    @property
    def config(self) -> PipelineTaskConfig:
        """Configuration for the task.

        This is always frozen.

        Accessing this attribute when `is_imported` is `False` will raise
        `TaskNotImportedError`, but calling `get_config_str` will not.
        """
        return self._get_imported_data().config

    def __repr__(self) -> str:
        return f"{self.label} [init] ({self.task_class_name})"

    def get_config_str(self) -> str:
        """Return the configuration for this task as a string of override
        statements.

        Returns
        -------
        config_str : `str`
            String containing configuration-overload statements.
        """
        try:
            return self._config_str
        except AttributeError:
            pass
        # Lazily serialize and cache; requires the task to be imported.
        self._config_str = self.config.saveToString()
        return self._config_str

    def iter_all_inputs(self) -> Iterator[ReadEdge]:
        """Iterate over all inputs required for construction.

        This is the same as iteration over ``inputs.values()``, but it will be
        updated to include any automatic init-input connections added in the
        future, while `inputs` will continue to hold only task-defined init
        inputs.

        Yields
        ------
        `ReadEdge`
            All the inputs required for construction.
        """
        return iter(self.inputs.values())

    def iter_all_outputs(self) -> Iterator[WriteEdge]:
        """Iterate over all outputs available after construction, including
        special ones.

        Yields
        ------
        `WriteEdge`
            All the outputs available after construction.
        """
        yield from self.outputs.values()
        yield self.config_output

    def get_input_edge(self, connection_name: str) -> ReadEdge:
        """Look up an input edge by connection name.

        Parameters
        ----------
        connection_name : `str`
            Name of the connection.

        Returns
        -------
        edge : `ReadEdge`
            Input edge.
        """
        return self.inputs[connection_name]

    def get_output_edge(self, connection_name: str) -> WriteEdge:
        """Look up an output edge by connection name.

        Parameters
        ----------
        connection_name : `str`
            Name of the connection.

        Returns
        -------
        edge : `WriteEdge`
            Output edge.
        """
        # The config init-output is stored separately from regular outputs.
        if connection_name == acc.CONFIG_INIT_OUTPUT_CONNECTION_NAME:
            return self.config_output
        return self.outputs[connection_name]

    def get_edge(self, connection_name: str) -> Edge:
        """Look up an edge by connection name.

        Parameters
        ----------
        connection_name : `str`
            Name of the connection.

        Returns
        -------
        edge : `Edge`
            Edge.
        """
        try:
            return self.get_input_edge(connection_name)
        except KeyError:
            pass
        return self.get_output_edge(connection_name)

    def diff_edges(self, other: TaskInitNode) -> list[str]:
        """Compare the edges of this task initialization node to those from
        the same task label in a different pipeline.

        Parameters
        ----------
        other : `TaskInitNode`
            Other node to compare to.  Must have the same task label, but need
            not have the same configuration or even the same task class.

        Returns
        -------
        differences : `list` [ `str` ]
            List of string messages describing differences between ``self``
            and ``other``.  Will be empty if the two nodes have the same
            edges.  Messages will use 'A' to refer to ``self`` and 'B' to
            refer to ``other``.
        """
        result = []
        # Fixed: compare our init inputs against *other*'s init inputs; this
        # previously compared self.inputs to itself, so init-input
        # differences were never reported.
        result += _diff_edge_mapping(self.inputs, other.inputs, self.label, "init input")
        result += _diff_edge_mapping(self.outputs, other.outputs, self.label, "init output")
        result += self.config_output.diff(other.config_output, "config init output")
        return result

    def _to_xgraph_state(self) -> dict[str, Any]:
        """Convert this node's attributes into a dictionary suitable for use
        in exported networkx graphs.
        """
        result = {"task_class_name": self.task_class_name, "bipartite": NodeType.TASK_INIT.bipartite}
        if hasattr(self, "_imported_data"):
            result["task_class"] = self.task_class
            result["config"] = self.config
        return result

    def _get_imported_data(self) -> _TaskNodeImportedData:
        """Return the imported data struct.

        Returns
        -------
        imported_data : `_TaskNodeImportedData`
            Internal structure holding state that requires the task class to
            have been imported.

        Raises
        ------
        TaskNotImportedError
            Raised if `is_imported` is `False`.
        """
        try:
            return self._imported_data
        except AttributeError:
            raise TaskNotImportedError(
                f"Task class {self.task_class_name!r} for label {self.label!r} has not been imported "
                "(see PipelineGraph.import_and_configure)."
            ) from None

    @staticmethod
    def _unreduce(kwargs: dict[str, Any]) -> TaskInitNode:
        """Unpickle a `TaskInitNode` instance."""
        # Connections classes are not pickleable, so we can't use the
        # dataclass-provided pickle implementation of _TaskNodeImportedData,
        # and it's easier to just call its `configure` method than to fix it.
        if (imported_data_args := kwargs.pop("imported_data_args", None)) is not None:
            imported_data = _TaskNodeImportedData.configure(*imported_data_args)
        else:
            imported_data = None
        return TaskInitNode(imported_data=imported_data, **kwargs)

    def __reduce__(self) -> tuple[Callable[[dict[str, Any]], TaskInitNode], tuple[dict[str, Any]]]:
        kwargs = dict(
            key=self.key,
            inputs=self.inputs,
            outputs=self.outputs,
            config_output=self.config_output,
            task_class_name=getattr(self, "_task_class_name", None),
            config_str=getattr(self, "_config_str", None),
        )
        if hasattr(self, "_imported_data"):
            kwargs["imported_data_args"] = (
                self.label,
                self.task_class,
                self.config,
            )
        return (self._unreduce, (kwargs,))
500@immutable
501class TaskNode:
502 """A node in a pipeline graph that represents a labeled configuration of a
503 `PipelineTask`.
505 Parameters
506 ----------
507 key : `NodeKey`
508 Identifier for this node in networkx graphs.
509 init : `TaskInitNode`
510 Node representing the initialization of this task.
511 prerequisite_inputs : `~collections.abc.Mapping` [ `str`, `ReadEdge` ]
512 Graph edges that represent prerequisite inputs to this task, keyed by
513 connection name.
515 Prerequisite inputs must already exist in the data repository when a
516 `QuantumGraph` is built, but have more flexibility in how they are
517 looked up than regular inputs.
518 inputs : `~collections.abc.Mapping` [ `str`, `ReadEdge` ]
519 Graph edges that represent regular runtime inputs to this task, keyed
520 by connection name.
521 outputs : ~collections.abc.Mapping` [ `str`, `WriteEdge` ]
522 Graph edges that represent regular runtime outputs of this task, keyed
523 by connection name.
525 This does not include the special `log_output` and `metadata_output`
526 edges; use `iter_all_outputs` to include that, too.
527 log_output : `WriteEdge` or `None`
528 The special runtime output that persists the task's logs.
529 metadata_output : `WriteEdge`
530 The special runtime output that persists the task's metadata.
531 dimensions : `lsst.daf.butler.DimensionGroup` or `frozenset` [ `str` ]
532 Dimensions of the task. If a `frozenset`, the dimensions have not been
533 resolved by a `~lsst.daf.butler.DimensionUniverse` and cannot be safely
534 compared to other sets of dimensions.
536 Notes
537 -----
538 Task nodes are intentionally not equality comparable, since there are many
539 different (and useful) ways to compare these objects with no clear winner
540 as the most obvious behavior.
542 When included in an exported `networkx` graph (e.g.
543 `PipelineGraph.make_xgraph`), task nodes set the following node attributes:
545 - ``task_class_name``
546 - ``bipartite`` (see `NodeType.bipartite`)
547 - ``task_class`` (only if `is_imported` is `True`)
548 - ``config`` (only if `is_imported` is `True`)
549 """
    def __init__(
        self,
        key: NodeKey,
        init: TaskInitNode,
        *,
        prerequisite_inputs: Mapping[str, ReadEdge],
        inputs: Mapping[str, ReadEdge],
        outputs: Mapping[str, WriteEdge],
        log_output: WriteEdge | None,
        metadata_output: WriteEdge,
        dimensions: DimensionGroup | frozenset[str],
    ):
        # See the class docstring for parameter descriptions.
        self.key = key
        self.init = init
        self.prerequisite_inputs = prerequisite_inputs
        self.inputs = inputs
        self.outputs = outputs
        self.log_output = log_output
        self.metadata_output = metadata_output
        # May be a raw frozenset of dimension names until resolved by a
        # DimensionUniverse; see has_resolved_dimensions / raw_dimensions.
        self._dimensions = dimensions
572 @staticmethod
573 def _from_imported_data(
574 key: NodeKey,
575 init_key: NodeKey,
576 data: _TaskNodeImportedData,
577 universe: DimensionUniverse | None,
578 ) -> TaskNode:
579 """Construct from a `PipelineTask` type and its configuration.
581 Parameters
582 ----------
583 key : `NodeKey`
584 Identifier for this node in networkx graphs.
585 init_key : `TaskInitNode`
586 Node representing the initialization of this task.
587 data : `_TaskNodeImportedData`
588 Internal struct that holds information that requires the task class
589 to have been be imported.
590 universe : `lsst.daf.butler.DimensionUniverse` or `None`
591 Definitions of all dimensions.
593 Returns
594 -------
595 node : `TaskNode`
596 New task node.
598 Raises
599 ------
600 ValueError
601 Raised if configuration validation failed when constructing
602 ``connections``.
603 """
604 init_inputs = {
605 name: ReadEdge._from_connection_map(init_key, name, data.connection_map)
606 for name in data.connections.initInputs
607 }
608 prerequisite_inputs = {
609 name: ReadEdge._from_connection_map(key, name, data.connection_map, is_prerequisite=True)
610 for name in data.connections.prerequisiteInputs
611 }
612 inputs = {
613 name: ReadEdge._from_connection_map(key, name, data.connection_map)
614 for name in data.connections.inputs
615 if not getattr(data.connections, name).deferBinding
616 }
617 init_outputs = {
618 name: WriteEdge._from_connection_map(init_key, name, data.connection_map)
619 for name in data.connections.initOutputs
620 }
621 outputs = {
622 name: WriteEdge._from_connection_map(key, name, data.connection_map)
623 for name in data.connections.outputs
624 }
625 init = TaskInitNode(
626 key=init_key,
627 inputs=init_inputs,
628 outputs=init_outputs,
629 config_output=WriteEdge._from_connection_map(
630 init_key, acc.CONFIG_INIT_OUTPUT_CONNECTION_NAME, data.connection_map
631 ),
632 imported_data=data,
633 )
634 instance = TaskNode(
635 key=key,
636 init=init,
637 prerequisite_inputs=prerequisite_inputs,
638 inputs=inputs,
639 outputs=outputs,
640 log_output=(
641 WriteEdge._from_connection_map(key, acc.LOG_OUTPUT_CONNECTION_NAME, data.connection_map)
642 if data.config.saveLogOutput
643 else None
644 ),
645 metadata_output=WriteEdge._from_connection_map(
646 key, acc.METADATA_OUTPUT_CONNECTION_NAME, data.connection_map
647 ),
648 dimensions=(
649 frozenset(data.connections.dimensions)
650 if universe is None
651 else universe.conform(data.connections.dimensions)
652 ),
653 )
654 return instance
656 key: NodeKey
657 """Key that identifies this node in internal and exported networkx graphs.
658 """
660 prerequisite_inputs: Mapping[str, ReadEdge]
661 """Graph edges that represent prerequisite inputs to this task.
663 Prerequisite inputs must already exist in the data repository when a
664 `QuantumGraph` is built, but have more flexibility in how they are looked
665 up than regular inputs.
666 """
668 inputs: Mapping[str, ReadEdge]
669 """Graph edges that represent regular runtime inputs to this task.
670 """
672 outputs: Mapping[str, WriteEdge]
673 """Graph edges that represent regular runtime outputs of this task.
675 This does not include the special `log_output` and `metadata_output` edges;
676 use `iter_all_outputs` to include that, too.
677 """
679 log_output: WriteEdge | None
680 """The special runtime output that persists the task's logs.
681 """
683 metadata_output: WriteEdge
684 """The special runtime output that persists the task's metadata.
685 """
687 @property
688 def label(self) -> str:
689 """Label of this configuration of a task in the pipeline."""
690 return self.key.name
692 @property
693 def is_imported(self) -> bool:
694 """Whether this the task type for this node has been imported and
695 its configuration overrides applied.
697 If this is `False`, the `task_class` and `config` attributes may not
698 be accessed.
699 """
700 return self.init.is_imported
702 @property
703 def task_class(self) -> type[PipelineTask]:
704 """Type object for the task.
706 Accessing this attribute when `is_imported` is `False` will raise
707 `TaskNotImportedError`, but accessing `task_class_name` will not.
708 """
709 return self.init.task_class
711 @property
712 def task_class_name(self) -> str:
713 """The fully-qualified string name of the task class."""
714 return self.init.task_class_name
716 @property
717 def config(self) -> PipelineTaskConfig:
718 """Configuration for the task.
720 This is always frozen.
722 Accessing this attribute when `is_imported` is `False` will raise
723 `TaskNotImportedError`, but calling `get_config_str` will not.
724 """
725 return self.init.config
727 @property
728 def has_resolved_dimensions(self) -> bool:
729 """Whether the `dimensions` attribute may be accessed.
731 If `False`, the `raw_dimensions` attribute may be used to obtain a
732 set of dimension names that has not been resolved by a
733 `~lsst.daf.butler.DimensionsUniverse`.
734 """
735 return type(self._dimensions) is DimensionGroup
737 @property
738 def dimensions(self) -> DimensionGroup:
739 """Standardized dimensions of the task."""
740 if not self.has_resolved_dimensions:
741 raise UnresolvedGraphError(f"Dimensions for task {self.label!r} have not been resolved.")
742 return cast(DimensionGroup, self._dimensions)
744 @property
745 def raw_dimensions(self) -> frozenset[str]:
746 """Raw dimensions of the task, with standardization by a
747 `~lsst.daf.butler.DimensionUniverse` not guaranteed.
748 """
749 if self.has_resolved_dimensions:
750 return frozenset(cast(DimensionGroup, self._dimensions).names)
751 else:
752 return cast(frozenset[str], self._dimensions)
754 def __repr__(self) -> str:
755 if self.has_resolved_dimensions:
756 return f"{self.label} ({self.task_class_name}, {self.dimensions})"
757 else:
758 return f"{self.label} ({self.task_class_name})"
760 def get_config_str(self) -> str:
761 """Return the configuration for this task as a string of override
762 statements.
764 Returns
765 -------
766 config_str : `str`
767 String containing configuration-overload statements.
768 """
769 return self.init.get_config_str()
771 def iter_all_inputs(self) -> Iterator[ReadEdge]:
772 """Iterate over all runtime inputs, including both regular inputs and
773 prerequisites.
775 Yields
776 ------
777 `ReadEdge`
778 All the runtime inputs.
779 """
780 yield from self.prerequisite_inputs.values()
781 yield from self.inputs.values()
783 def iter_all_outputs(self) -> Iterator[WriteEdge]:
784 """Iterate over all runtime outputs, including special ones.
786 Yields
787 ------
788 `ReadEdge`
789 All the runtime outputs.
790 """
791 yield from self.outputs.values()
792 yield self.metadata_output
793 if self.log_output is not None:
794 yield self.log_output
796 def get_input_edge(self, connection_name: str) -> ReadEdge:
797 """Look up an input edge by connection name.
799 Parameters
800 ----------
801 connection_name : `str`
802 Name of the connection.
804 Returns
805 -------
806 edge : `ReadEdge`
807 Input edge.
808 """
809 if (edge := self.inputs.get(connection_name)) is not None:
810 return edge
811 return self.prerequisite_inputs[connection_name]
813 def get_output_edge(self, connection_name: str) -> WriteEdge:
814 """Look up an output edge by connection name.
816 Parameters
817 ----------
818 connection_name : `str`
819 Name of the connection.
821 Returns
822 -------
823 edge : `WriteEdge`
824 Output edge.
825 """
826 if connection_name == acc.METADATA_OUTPUT_CONNECTION_NAME:
827 return self.metadata_output
828 if connection_name == acc.LOG_OUTPUT_CONNECTION_NAME:
829 if self.log_output is None:
830 raise KeyError(connection_name)
831 return self.log_output
832 return self.outputs[connection_name]
834 def get_edge(self, connection_name: str) -> Edge:
835 """Look up an edge by connection name.
837 Parameters
838 ----------
839 connection_name : `str`
840 Name of the connection.
842 Returns
843 -------
844 edge : `Edge`
845 Edge.
846 """
847 try:
848 return self.get_input_edge(connection_name)
849 except KeyError:
850 pass
851 return self.get_output_edge(connection_name)
853 def diff_edges(self, other: TaskNode) -> list[str]:
854 """Compare the edges of this task node to those from the same task
855 label in a different pipeline.
857 This also calls `TaskInitNode.diff_edges`.
859 Parameters
860 ----------
861 other : `TaskInitNode`
862 Other node to compare to. Must have the same task label, but need
863 not have the same configuration or even the same task class.
865 Returns
866 -------
867 differences : `list` [ `str` ]
868 List of string messages describing differences between ``self`` and
869 ``other``. Will be empty if the two nodes have the same edges.
870 Messages will use 'A' to refer to ``self`` and 'B' to refer to
871 ``other``.
872 """
873 result = self.init.diff_edges(other.init)
874 result += _diff_edge_mapping(
875 self.prerequisite_inputs, other.prerequisite_inputs, self.label, "prerequisite input"
876 )
877 result += _diff_edge_mapping(self.inputs, other.inputs, self.label, "input")
878 result += _diff_edge_mapping(self.outputs, other.outputs, self.label, "output")
879 if self.log_output is not None:
880 if other.log_output is not None:
881 result += self.log_output.diff(other.log_output, "log output")
882 else:
883 result.append("Log output is present in A, but not in B.")
884 elif other.log_output is not None:
885 result.append("Log output is present in B, but not in A.")
886 result += self.metadata_output.diff(other.metadata_output, "metadata output")
887 return result
889 def get_lookup_function(
890 self, connection_name: str
891 ) -> Callable[[DatasetType, Registry, DataCoordinate, Sequence[str]], Iterable[DatasetRef]] | None:
892 """Return the custom dataset query function for an edge, if one exists.
894 Parameters
895 ----------
896 connection_name : `str`
897 Name of the connection.
899 Returns
900 -------
901 lookup_function : `~collections.abc.Callable` or `None`
902 Callable that takes a dataset type, a butler registry, a data
903 coordinate (the quantum data ID), and an ordered list of
904 collections to search, and returns an iterable of
905 `~lsst.daf.butler.DatasetRef`.
906 """
907 return getattr(self._get_imported_data().connection_map[connection_name], "lookupFunction", None)
909 def is_optional(self, connection_name: str) -> bool:
910 """Check whether the given connection has ``minimum==0``.
912 Parameters
913 ----------
914 connection_name : `str`
915 Name of the connection.
917 Returns
918 -------
919 optional : `bool`
920 Whether this task can run without any datasets for the given
921 connection.
922 """
923 connection = getattr(self.get_connections(), connection_name)
924 return isinstance(connection, BaseInput) and connection.minimum == 0
926 def get_connections(self) -> PipelineTaskConnections:
927 """Return the connections class instance for this task.
929 Returns
930 -------
931 connections : `.PipelineTaskConnections`
932 Task-provided object that defines inputs and outputs from
933 configuration.
934 """
935 return self._get_imported_data().connections
937 def get_spatial_bounds_connections(self) -> frozenset[str]:
938 """Return the names of connections whose data IDs should be included
939 in the calculation of the spatial bounds for this task's quanta.
941 Returns
942 -------
943 connection_names : `frozenset` [ `str` ]
944 Names of connections with spatial dimensions.
945 """
946 return frozenset(self._get_imported_data().connections.getSpatialBoundsConnections())
948 def get_temporal_bounds_connections(self) -> frozenset[str]:
949 """Return the names of connections whose data IDs should be included
950 in the calculation of the temporal bounds for this task's quanta.
952 Returns
953 -------
954 connection_names : `frozenset` [ `str` ]
955 Names of connections with temporal dimensions.
956 """
957 return frozenset(self._get_imported_data().connections.getTemporalBoundsConnections())
959 def _imported_and_configured(self, rebuild: bool) -> TaskNode:
960 """Import the task class and use it to construct a new instance.
962 Parameters
963 ----------
964 rebuild : `bool`
965 If `True`, import the task class and configure its connections to
966 generate new edges that may differ from the current ones. If
967 `False`, import the task class but just update the `task_class` and
968 `config` attributes, and assume the edges have not changed.
970 Returns
971 -------
972 node : `TaskNode`
973 Task node instance for which `is_imported` is `True`. Will be
974 ``self`` if this is the case already.
975 """
976 from ..pipelineTask import PipelineTask
978 if self.is_imported:
979 return self
980 task_class = doImportType(self.task_class_name)
981 if not issubclass(task_class, PipelineTask):
982 raise TypeError(f"{self.task_class_name!r} is not a PipelineTask subclass.")
983 config = task_class.ConfigClass()
984 config.loadFromString(self.get_config_str())
985 return self._reconfigured(config, rebuild=rebuild, task_class=task_class)
    def _reconfigured(
        self,
        config: PipelineTaskConfig,
        rebuild: bool,
        task_class: type[PipelineTask] | None = None,
    ) -> TaskNode:
        """Return a version of this node with new configuration.

        Parameters
        ----------
        config : `.PipelineTaskConfig`
            New configuration for the task.
        rebuild : `bool`
            If `True`, use the configured connections to generate new edges
            that may differ from the current ones. If `False`, just update the
            `task_class` and `config` attributes, and assume the edges have not
            changed.
        task_class : `type` [ `PipelineTask` ], optional
            Subclass of `PipelineTask`. This defaults to ``self.task_class``,
            but may be passed as an argument if that is not available because
            the task class was not imported when ``self`` was constructed.

        Returns
        -------
        node : `TaskNode`
            Task node instance with the new config.
        """
        if task_class is None:
            task_class = self.task_class
        imported_data = _TaskNodeImportedData.configure(self.key.name, task_class, config)
        if rebuild:
            # Regenerate edges from the newly configured connections; only
            # pass a universe if our dimensions are already resolved (i.e.
            # held as a DimensionGroup rather than raw names).
            return self._from_imported_data(
                self.key,
                self.init.key,
                imported_data,
                universe=self._dimensions.universe if type(self._dimensions) is DimensionGroup else None,
            )
        else:
            # Keep all existing edges; just attach the new imported data to a
            # rebuilt init node and wrap it in a new TaskNode.
            return TaskNode(
                self.key,
                TaskInitNode(
                    self.init.key,
                    inputs=self.init.inputs,
                    outputs=self.init.outputs,
                    config_output=self.init.config_output,
                    imported_data=imported_data,
                ),
                prerequisite_inputs=self.prerequisite_inputs,
                inputs=self.inputs,
                outputs=self.outputs,
                log_output=self.log_output,
                metadata_output=self.metadata_output,
                dimensions=self._dimensions,
            )
1042 def _resolved(self, universe: DimensionUniverse | None) -> TaskNode:
1043 """Return an otherwise-equivalent task node with resolved dimensions.
1045 Parameters
1046 ----------
1047 universe : `lsst.daf.butler.DimensionUniverse` or `None`
1048 Definitions for all dimensions.
1050 Returns
1051 -------
1052 node : `TaskNode`
1053 Task node instance with `dimensions` resolved by the given
1054 universe. Will be ``self`` if this is the case already.
1055 """
1056 if self.has_resolved_dimensions:
1057 if cast(DimensionGroup, self._dimensions).universe is universe:
1058 return self
1059 elif universe is None:
1060 return self
1061 return TaskNode(
1062 key=self.key,
1063 init=self.init,
1064 prerequisite_inputs=self.prerequisite_inputs,
1065 inputs=self.inputs,
1066 outputs=self.outputs,
1067 log_output=self.log_output,
1068 metadata_output=self.metadata_output,
1069 dimensions=(
1070 universe.conform(self.raw_dimensions) if universe is not None else self.raw_dimensions
1071 ),
1072 )
1074 def _to_xgraph_state(self) -> dict[str, Any]:
1075 """Convert this nodes's attributes into a dictionary suitable for use
1076 in exported networkx graphs.
1077 """
1078 result = self.init._to_xgraph_state()
1079 if self.has_resolved_dimensions:
1080 result["dimensions"] = self._dimensions
1081 result["raw_dimensions"] = self.raw_dimensions
1082 return result
1084 def _get_imported_data(self) -> _TaskNodeImportedData:
1085 """Return the imported data struct.
1087 Returns
1088 -------
1089 imported_data : `_TaskNodeImportedData`
1090 Internal structure holding state that requires the task class to
1091 have been imported.
1093 Raises
1094 ------
1095 TaskNotImportedError
1096 Raised if `is_imported` is `False`.
1097 """
1098 return self.init._get_imported_data()
1100 @staticmethod
1101 def _unreduce(kwargs: dict[str, Any]) -> TaskNode:
1102 """Unpickle a `TaskNode` instance."""
1103 return TaskNode(**kwargs)
1105 def __reduce__(self) -> tuple[Callable[[dict[str, Any]], TaskNode], tuple[dict[str, Any]]]:
1106 return (
1107 self._unreduce,
1108 (
1109 dict(
1110 key=self.key,
1111 init=self.init,
1112 prerequisite_inputs=self.prerequisite_inputs,
1113 inputs=self.inputs,
1114 outputs=self.outputs,
1115 log_output=self.log_output,
1116 metadata_output=self.metadata_output,
1117 dimensions=self._dimensions,
1118 ),
1119 ),
1120 )
1123def _diff_edge_mapping(
1124 a_mapping: Mapping[str, Edge], b_mapping: Mapping[str, Edge], task_label: str, connection_type: str
1125) -> list[str]:
1126 """Compare a pair of mappings of edges.
1128 Parameters
1129 ----------
1130 a_mapping : `~collections.abc.Mapping` [ `str`, `Edge` ]
1131 First mapping to compare. Expected to have connection names as keys.
1132 b_mapping : `~collections.abc.Mapping` [ `str`, `Edge` ]
1133 First mapping to compare. If keys differ from those of ``a_mapping``,
1134 this will be reported as a difference (in addition to element-wise
1135 comparisons).
1136 task_label : `str`
1137 Task label associated with both mappings.
1138 connection_type : `str`
1139 Type of connection (e.g. "input" or "init output") associated with both
1140 connections. This is a human-readable string to include in difference
1141 messages.
1143 Returns
1144 -------
1145 differences : `list` [ `str` ]
1146 List of string messages describing differences between the two
1147 mappings. Will be empty if the two mappings have the same edges.
1148 Messages will include "A" and "B", and are expected to be a preceded
1149 by a message describing what "A" and "B" are in the context in which
1150 this method is called.
1152 Notes
1153 -----
1154 This is expected to be used to compare one edge-holding mapping attribute
1155 of a task or task init node to the same attribute on another task or task
1156 init node (i.e. any of `TaskNode.inputs`, `TaskNode.outputs`,
1157 `TaskNode.prerequisite_inputs`, `TaskInitNode.inputs`,
1158 `TaskInitNode.outputs`).
1159 """
1160 results = []
1161 b_to_do = set(b_mapping.keys())
1162 for connection_name, a_edge in a_mapping.items():
1163 if (b_edge := b_mapping.get(connection_name)) is None:
1164 results.append(
1165 f"{connection_type.capitalize()} {connection_name!r} of task "
1166 f"{task_label!r} exists in A, but not in B (or it may have a different connection type)."
1167 )
1168 else:
1169 results.extend(a_edge.diff(b_edge, connection_type))
1170 b_to_do.discard(connection_name)
1171 for connection_name in b_to_do:
1172 results.append(
1173 f"{connection_type.capitalize()} {connection_name!r} of task "
1174 f"{task_label!r} exists in A, but not in B (or it may have a different connection type)."
1175 )
1176 return results