Coverage for python/lsst/pipe/base/pipeline_graph/_tasks.py: 41%
254 statements
« prev ^ index » next coverage.py v7.3.0, created at 2023-08-23 10:31 +0000
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ("TaskNode", "TaskInitNode", "TaskImportMode")
25import dataclasses
26import enum
27from collections.abc import Iterator, Mapping
28from typing import TYPE_CHECKING, Any, cast
30from lsst.daf.butler import DimensionGraph, DimensionUniverse
31from lsst.utils.classes import immutable
32from lsst.utils.doImport import doImportType
33from lsst.utils.introspection import get_full_type_name
35from .. import automatic_connection_constants as acc
36from ..connections import PipelineTaskConnections
37from ..connectionTypes import BaseConnection, InitOutput, Output
38from ._edges import Edge, ReadEdge, WriteEdge
39from ._exceptions import TaskNotImportedError, UnresolvedGraphError
40from ._nodes import NodeKey, NodeType
42if TYPE_CHECKING:
43 from ..config import PipelineTaskConfig
44 from ..pipelineTask import PipelineTask
class TaskImportMode(enum.Enum):
    """Enumeration of the ways to handle importing tasks when reading a
    serialized PipelineGraph.
    """

    DO_NOT_IMPORT = enum.auto()
    """Do not import tasks or instantiate their configs and connections."""

    REQUIRE_CONSISTENT_EDGES = enum.auto()
    """Import tasks and instantiate their config and connection objects, and
    check that the connections still define the same edges.
    """

    ASSUME_CONSISTENT_EDGES = enum.auto()
    """Import tasks and instantiate their config and connection objects, but do
    not check that the connections still define the same edges.

    This is safe only when the caller knows the task definition has not changed
    since the pipeline graph was persisted, such as when it was saved and
    loaded with the same pipeline version.
    """

    OVERRIDE_EDGES = enum.auto()
    """Import tasks and instantiate their config and connection objects, and
    allow the edges defined in those connections to override those in the
    persisted graph.

    This may cause dataset type nodes to be unresolved, since resolutions
    consistent with the original edges may be invalidated.
    """
@dataclasses.dataclass(frozen=True)
class _TaskNodeImportedData:
    """An internal struct that holds `TaskNode` and `TaskInitNode` state that
    requires task classes to be imported.
    """

    task_class: type[PipelineTask]
    """Type object for the task."""

    config: PipelineTaskConfig
    """Configuration object for the task."""

    connection_map: dict[str, BaseConnection]
    """Mapping from connection name to connection.

    In addition to ``connections.allConnections``, this also holds the
    "automatic" config, log, and metadata connections using the names defined
    in the `.automatic_connection_constants` module.
    """

    connections: PipelineTaskConnections
    """Configured connections object for the task."""

    @classmethod
    def configure(
        cls,
        label: str,
        task_class: type[PipelineTask],
        config: PipelineTaskConfig,
        connections: PipelineTaskConnections | None = None,
    ) -> _TaskNodeImportedData:
        """Construct while creating a `PipelineTaskConnections` instance if
        necessary.

        Parameters
        ----------
        label : `str`
            Label for the task in the pipeline.  Only used in error messages.
        task_class : `type` [ `.PipelineTask` ]
            Pipeline task `type` object.
        config : `.PipelineTaskConfig`
            Configuration for the task.
        connections : `.PipelineTaskConnections`, optional
            Object that describes the dataset types used by the task.  If not
            provided, one will be constructed from the given configuration.
            If provided, it is assumed that ``config`` has already been
            validated and frozen.

        Returns
        -------
        data : `_TaskNodeImportedData`
            Instance of this struct.

        Raises
        ------
        ValueError
            Raised if configuration validation fails, or if metadata output
            has been disabled in ``config``.
        """
        if connections is None:
            # A missing connections object implies the config has not been
            # validated or frozen yet; do both before building one.
            try:
                config.validate()
            except Exception as err:
                raise ValueError(
                    f"Configuration validation failed for task {label!r} (see chained exception)."
                ) from err
            config.freeze()
            connections = task_class.ConfigClass.ConnectionsClass(config=config)
        # Start from the task-declared connections, then layer on the
        # "automatic" ones (config, metadata, and optionally log).
        cmap: dict[str, BaseConnection] = dict(connections.allConnections)
        cmap[acc.CONFIG_INIT_OUTPUT_CONNECTION_NAME] = InitOutput(
            acc.CONFIG_INIT_OUTPUT_TEMPLATE.format(label=label),
            acc.CONFIG_INIT_OUTPUT_STORAGE_CLASS,
        )
        if not config.saveMetadata:
            raise ValueError(f"Metadata for task {label} cannot be disabled.")
        cmap[acc.METADATA_OUTPUT_CONNECTION_NAME] = Output(
            acc.METADATA_OUTPUT_TEMPLATE.format(label=label),
            acc.METADATA_OUTPUT_STORAGE_CLASS,
            dimensions=set(connections.dimensions),
        )
        if config.saveLogOutput:
            cmap[acc.LOG_OUTPUT_CONNECTION_NAME] = Output(
                acc.LOG_OUTPUT_TEMPLATE.format(label=label),
                acc.LOG_OUTPUT_STORAGE_CLASS,
                dimensions=set(connections.dimensions),
            )
        return cls(
            task_class=task_class,
            config=config,
            connection_map=cmap,
            connections=connections,
        )
@immutable
class TaskInitNode:
    """A node in a pipeline graph that represents the construction of a
    `PipelineTask`.

    Parameters
    ----------
    inputs : `~collections.abc.Mapping` [ `str`, `ReadEdge` ]
        Graph edges that represent inputs required just to construct an
        instance of this task, keyed by connection name.
    outputs : `~collections.abc.Mapping` [ `str`, `WriteEdge` ]
        Graph edges that represent outputs of this task that are available
        after just constructing it, keyed by connection name.

        This does not include the special `config_output` edge; use
        `iter_all_outputs` to include that, too.
    config_output : `WriteEdge`
        The special init output edge that persists the task's configuration.
    imported_data : `_TaskNodeImportedData`, optional
        Internal struct that holds information that requires the task class to
        have been imported.
    task_class_name : `str`, optional
        Fully-qualified name of the task class.  Must be provided if
        ``imported_data`` is not.
    config_str : `str`, optional
        Configuration for the task as a string of override statements.  Must
        be provided if ``imported_data`` is not.

    Notes
    -----
    When included in an exported `networkx` graph (e.g.
    `PipelineGraph.make_xgraph`), task initialization nodes set the following
    node attributes:

    - ``task_class_name``
    - ``bipartite`` (see `NodeType.bipartite`)
    - ``task_class`` (only if `is_imported` is `True`)
    - ``config`` (only if `is_imported` is `True`)
    """

    def __init__(
        self,
        key: NodeKey,
        *,
        inputs: Mapping[str, ReadEdge],
        outputs: Mapping[str, WriteEdge],
        config_output: WriteEdge,
        imported_data: _TaskNodeImportedData | None = None,
        task_class_name: str | None = None,
        config_str: str | None = None,
    ):
        self.key = key
        self.inputs = inputs
        self.outputs = outputs
        self.config_output = config_output
        # Instead of setting attributes to None, we do not set them at all;
        # this works better with the @immutable decorator, which supports
        # deferred initialization but not reassignment.
        if task_class_name is not None:
            self._task_class_name = task_class_name
        if config_str is not None:
            self._config_str = config_str
        if imported_data is not None:
            self._imported_data = imported_data
        else:
            # Check the arguments rather than the attributes: if they were
            # not provided the attributes are unset, and reading them would
            # raise AttributeError before the assertion could fire.
            assert (
                task_class_name is not None and config_str is not None
            ), "If imported_data is not present, task_class_name and config_str must be."

    key: NodeKey
    """Key that identifies this node in internal and exported networkx graphs.
    """

    inputs: Mapping[str, ReadEdge]
    """Graph edges that represent inputs required just to construct an instance
    of this task, keyed by connection name.
    """

    outputs: Mapping[str, WriteEdge]
    """Graph edges that represent outputs of this task that are available after
    just constructing it, keyed by connection name.

    This does not include the special `config_output` edge; use
    `iter_all_outputs` to include that, too.
    """

    config_output: WriteEdge
    """The special output edge that persists the task's configuration.
    """

    @property
    def label(self) -> str:
        """Label of this configuration of a task in the pipeline."""
        return str(self.key)

    @property
    def is_imported(self) -> bool:
        """Whether the task type for this node has been imported and
        its configuration overrides applied.

        If this is `False`, the `task_class` and `config` attributes may not
        be accessed.
        """
        return hasattr(self, "_imported_data")

    @property
    def task_class(self) -> type[PipelineTask]:
        """Type object for the task.

        Accessing this attribute when `is_imported` is `False` will raise
        `TaskNotImportedError`, but accessing `task_class_name` will not.
        """
        return self._get_imported_data().task_class

    @property
    def task_class_name(self) -> str:
        """The fully-qualified string name of the task class."""
        try:
            return self._task_class_name
        except AttributeError:
            pass
        # Cache the computed name; @immutable permits deferred initialization.
        self._task_class_name = get_full_type_name(self.task_class)
        return self._task_class_name

    @property
    def config(self) -> PipelineTaskConfig:
        """Configuration for the task.

        This is always frozen.

        Accessing this attribute when `is_imported` is `False` will raise
        `TaskNotImportedError`, but calling `get_config_str` will not.
        """
        return self._get_imported_data().config

    def get_config_str(self) -> str:
        """Return the configuration for this task as a string of override
        statements.

        Returns
        -------
        config_str : `str`
            String containing configuration-overload statements.
        """
        try:
            return self._config_str
        except AttributeError:
            pass
        # Cache the serialized config; @immutable permits deferred
        # initialization.
        self._config_str = self.config.saveToString()
        return self._config_str

    def iter_all_inputs(self) -> Iterator[ReadEdge]:
        """Iterate over all inputs required for construction.

        This is the same as iteration over ``inputs.values()``, but it will be
        updated to include any automatic init-input connections added in the
        future, while `inputs` will continue to hold only task-defined init
        inputs.
        """
        return iter(self.inputs.values())

    def iter_all_outputs(self) -> Iterator[WriteEdge]:
        """Iterate over all outputs available after construction, including
        special ones.
        """
        yield from self.outputs.values()
        yield self.config_output

    def diff_edges(self, other: TaskInitNode) -> list[str]:
        """Compare the edges of this task initialization node to those from
        the same task label in a different pipeline.

        Parameters
        ----------
        other : `TaskInitNode`
            Other node to compare to.  Must have the same task label, but need
            not have the same configuration or even the same task class.

        Returns
        -------
        differences : `list` [ `str` ]
            List of string messages describing differences between ``self``
            and ``other``.  Will be empty if the two nodes have the same
            edges.  Messages will use 'A' to refer to ``self`` and 'B' to
            refer to ``other``.
        """
        result = []
        # Compare self's init inputs against other's (the original compared
        # self against itself, silently hiding all init-input differences).
        result += _diff_edge_mapping(self.inputs, other.inputs, self.label, "init input")
        result += _diff_edge_mapping(self.outputs, other.outputs, self.label, "init output")
        result += self.config_output.diff(other.config_output, "config init output")
        return result

    def _to_xgraph_state(self) -> dict[str, Any]:
        """Convert this node's attributes into a dictionary suitable for use
        in exported networkx graphs.
        """
        result = {"task_class_name": self.task_class_name, "bipartite": NodeType.TASK_INIT.bipartite}
        if hasattr(self, "_imported_data"):
            result["task_class"] = self.task_class
            result["config"] = self.config
        return result

    def _get_imported_data(self) -> _TaskNodeImportedData:
        """Return the imported data struct.

        Returns
        -------
        imported_data : `_TaskNodeImportedData`
            Internal structure holding state that requires the task class to
            have been imported.

        Raises
        ------
        TaskNotImportedError
            Raised if `is_imported` is `False`.
        """
        try:
            return self._imported_data
        except AttributeError:
            raise TaskNotImportedError(
                f"Task class {self.task_class_name!r} for label {self.label!r} has not been imported "
                "(see PipelineGraph.import_and_configure)."
            ) from None
@immutable
class TaskNode:
    """A node in a pipeline graph that represents a labeled configuration of a
    `PipelineTask`.

    Parameters
    ----------
    key : `NodeKey`
        Identifier for this node in networkx graphs.
    init : `TaskInitNode`
        Node representing the initialization of this task.
    prerequisite_inputs : `~collections.abc.Mapping` [ `str`, `ReadEdge` ]
        Graph edges that represent prerequisite inputs to this task, keyed by
        connection name.

        Prerequisite inputs must already exist in the data repository when a
        `QuantumGraph` is built, but have more flexibility in how they are
        looked up than regular inputs.
    inputs : `~collections.abc.Mapping` [ `str`, `ReadEdge` ]
        Graph edges that represent regular runtime inputs to this task, keyed
        by connection name.
    outputs : `~collections.abc.Mapping` [ `str`, `WriteEdge` ]
        Graph edges that represent regular runtime outputs of this task, keyed
        by connection name.

        This does not include the special `log_output` and `metadata_output`
        edges; use `iter_all_outputs` to include that, too.
    log_output : `WriteEdge` or `None`
        The special runtime output that persists the task's logs.
    metadata_output : `WriteEdge`
        The special runtime output that persists the task's metadata.
    dimensions : `lsst.daf.butler.DimensionGraph` or `frozenset`
        Dimensions of the task.  If a `frozenset`, the dimensions have not
        been resolved by a `~lsst.daf.butler.DimensionUniverse` and cannot be
        safely compared to other sets of dimensions.

    Notes
    -----
    Task nodes are intentionally not equality comparable, since there are many
    different (and useful) ways to compare these objects with no clear winner
    as the most obvious behavior.

    When included in an exported `networkx` graph (e.g.
    `PipelineGraph.make_xgraph`), task nodes set the following node
    attributes:

    - ``task_class_name``
    - ``bipartite`` (see `NodeType.bipartite`)
    - ``task_class`` (only if `is_imported` is `True`)
    - ``config`` (only if `is_imported` is `True`)
    """

    def __init__(
        self,
        key: NodeKey,
        init: TaskInitNode,
        *,
        prerequisite_inputs: Mapping[str, ReadEdge],
        inputs: Mapping[str, ReadEdge],
        outputs: Mapping[str, WriteEdge],
        log_output: WriteEdge | None,
        metadata_output: WriteEdge,
        dimensions: DimensionGraph | frozenset,
    ):
        self.key = key
        self.init = init
        self.prerequisite_inputs = prerequisite_inputs
        self.inputs = inputs
        self.outputs = outputs
        self.log_output = log_output
        self.metadata_output = metadata_output
        # May be either resolved (DimensionGraph) or raw (frozenset of str);
        # see has_resolved_dimensions.
        self._dimensions = dimensions

    @staticmethod
    def _from_imported_data(
        key: NodeKey,
        init_key: NodeKey,
        data: _TaskNodeImportedData,
        universe: DimensionUniverse | None,
    ) -> TaskNode:
        """Construct from a `PipelineTask` type and its configuration.

        Parameters
        ----------
        key : `NodeKey`
            Identifier for this node in networkx graphs.
        init_key : `NodeKey`
            Identifier for the node representing this task's initialization.
        data : `_TaskNodeImportedData`
            Internal struct that holds information that requires the task
            class to have been imported.
        universe : `lsst.daf.butler.DimensionUniverse` or `None`
            Definitions of all dimensions.

        Returns
        -------
        node : `TaskNode`
            New task node.

        Raises
        ------
        ValueError
            Raised if configuration validation failed when constructing
            ``connections``.
        """
        init_inputs = {
            name: ReadEdge._from_connection_map(init_key, name, data.connection_map)
            for name in data.connections.initInputs
        }
        prerequisite_inputs = {
            name: ReadEdge._from_connection_map(key, name, data.connection_map, is_prerequisite=True)
            for name in data.connections.prerequisiteInputs
        }
        inputs = {
            name: ReadEdge._from_connection_map(key, name, data.connection_map)
            for name in data.connections.inputs
        }
        init_outputs = {
            name: WriteEdge._from_connection_map(init_key, name, data.connection_map)
            for name in data.connections.initOutputs
        }
        outputs = {
            name: WriteEdge._from_connection_map(key, name, data.connection_map)
            for name in data.connections.outputs
        }
        init = TaskInitNode(
            key=init_key,
            inputs=init_inputs,
            outputs=init_outputs,
            config_output=WriteEdge._from_connection_map(
                init_key, acc.CONFIG_INIT_OUTPUT_CONNECTION_NAME, data.connection_map
            ),
            imported_data=data,
        )
        instance = TaskNode(
            key=key,
            init=init,
            prerequisite_inputs=prerequisite_inputs,
            inputs=inputs,
            outputs=outputs,
            # The log output edge exists only when enabled in the config.
            log_output=(
                WriteEdge._from_connection_map(key, acc.LOG_OUTPUT_CONNECTION_NAME, data.connection_map)
                if data.config.saveLogOutput
                else None
            ),
            metadata_output=WriteEdge._from_connection_map(
                key, acc.METADATA_OUTPUT_CONNECTION_NAME, data.connection_map
            ),
            # Without a universe, keep the raw (unresolved) dimension names.
            dimensions=(
                frozenset(data.connections.dimensions)
                if universe is None
                else universe.extract(data.connections.dimensions)
            ),
        )
        return instance

    key: NodeKey
    """Key that identifies this node in internal and exported networkx graphs.
    """

    prerequisite_inputs: Mapping[str, ReadEdge]
    """Graph edges that represent prerequisite inputs to this task.

    Prerequisite inputs must already exist in the data repository when a
    `QuantumGraph` is built, but have more flexibility in how they are looked
    up than regular inputs.
    """

    inputs: Mapping[str, ReadEdge]
    """Graph edges that represent regular runtime inputs to this task.
    """

    outputs: Mapping[str, WriteEdge]
    """Graph edges that represent regular runtime outputs of this task.

    This does not include the special `log_output` and `metadata_output` edges;
    use `iter_all_outputs` to include that, too.
    """

    log_output: WriteEdge | None
    """The special runtime output that persists the task's logs.
    """

    metadata_output: WriteEdge
    """The special runtime output that persists the task's metadata.
    """

    @property
    def label(self) -> str:
        """Label of this configuration of a task in the pipeline."""
        return self.key.name

    @property
    def is_imported(self) -> bool:
        """Whether the task type for this node has been imported and
        its configuration overrides applied.

        If this is `False`, the `task_class` and `config` attributes may not
        be accessed.
        """
        return self.init.is_imported

    @property
    def task_class(self) -> type[PipelineTask]:
        """Type object for the task.

        Accessing this attribute when `is_imported` is `False` will raise
        `TaskNotImportedError`, but accessing `task_class_name` will not.
        """
        return self.init.task_class

    @property
    def task_class_name(self) -> str:
        """The fully-qualified string name of the task class."""
        return self.init.task_class_name

    @property
    def config(self) -> PipelineTaskConfig:
        """Configuration for the task.

        This is always frozen.

        Accessing this attribute when `is_imported` is `False` will raise
        `TaskNotImportedError`, but calling `get_config_str` will not.
        """
        return self.init.config

    @property
    def has_resolved_dimensions(self) -> bool:
        """Whether the `dimensions` attribute may be accessed.

        If `False`, the `raw_dimensions` attribute may be used to obtain a
        set of dimension names that has not been resolved by a
        `~lsst.daf.butler.DimensionUniverse`.
        """
        return type(self._dimensions) is DimensionGraph

    @property
    def dimensions(self) -> DimensionGraph:
        """Standardized dimensions of the task."""
        if not self.has_resolved_dimensions:
            raise UnresolvedGraphError(f"Dimensions for task {self.label!r} have not been resolved.")
        return cast(DimensionGraph, self._dimensions)

    @property
    def raw_dimensions(self) -> frozenset[str]:
        """Raw dimensions of the task, with standardization by a
        `~lsst.daf.butler.DimensionUniverse` not guaranteed.
        """
        if self.has_resolved_dimensions:
            return frozenset(cast(DimensionGraph, self._dimensions).names)
        else:
            return cast(frozenset[str], self._dimensions)

    def __repr__(self) -> str:
        if self.has_resolved_dimensions:
            return f"{self.label} ({self.task_class_name}, {self.dimensions})"
        else:
            return f"{self.label} ({self.task_class_name})"

    def get_config_str(self) -> str:
        """Return the configuration for this task as a string of override
        statements.

        Returns
        -------
        config_str : `str`
            String containing configuration-overload statements.
        """
        return self.init.get_config_str()

    def iter_all_inputs(self) -> Iterator[ReadEdge]:
        """Iterate over all runtime inputs, including both regular inputs and
        prerequisites.
        """
        yield from self.prerequisite_inputs.values()
        yield from self.inputs.values()

    def iter_all_outputs(self) -> Iterator[WriteEdge]:
        """Iterate over all runtime outputs, including special ones."""
        yield from self.outputs.values()
        yield self.metadata_output
        if self.log_output is not None:
            yield self.log_output

    def diff_edges(self, other: TaskNode) -> list[str]:
        """Compare the edges of this task node to those from the same task
        label in a different pipeline.

        This also calls `TaskInitNode.diff_edges`.

        Parameters
        ----------
        other : `TaskNode`
            Other node to compare to.  Must have the same task label, but need
            not have the same configuration or even the same task class.

        Returns
        -------
        differences : `list` [ `str` ]
            List of string messages describing differences between ``self``
            and ``other``.  Will be empty if the two nodes have the same
            edges.  Messages will use 'A' to refer to ``self`` and 'B' to
            refer to ``other``.
        """
        result = self.init.diff_edges(other.init)
        result += _diff_edge_mapping(
            self.prerequisite_inputs, other.prerequisite_inputs, self.label, "prerequisite input"
        )
        result += _diff_edge_mapping(self.inputs, other.inputs, self.label, "input")
        result += _diff_edge_mapping(self.outputs, other.outputs, self.label, "output")
        # The log output edge is optional on both sides; report asymmetric
        # presence as a difference.
        if self.log_output is not None:
            if other.log_output is not None:
                result += self.log_output.diff(other.log_output, "log output")
            else:
                result.append("Log output is present in A, but not in B.")
        elif other.log_output is not None:
            result.append("Log output is present in B, but not in A.")
        result += self.metadata_output.diff(other.metadata_output, "metadata output")
        return result

    def _imported_and_configured(self, rebuild: bool) -> TaskNode:
        """Import the task class and use it to construct a new instance.

        Parameters
        ----------
        rebuild : `bool`
            If `True`, import the task class and configure its connections to
            generate new edges that may differ from the current ones.  If
            `False`, import the task class but just update the `task_class`
            and `config` attributes, and assume the edges have not changed.

        Returns
        -------
        node : `TaskNode`
            Task node instance for which `is_imported` is `True`.  Will be
            ``self`` if this is the case already.
        """
        # Local import to avoid a circular dependency at module level.
        from ..pipelineTask import PipelineTask

        if self.is_imported:
            return self
        task_class = doImportType(self.task_class_name)
        if not issubclass(task_class, PipelineTask):
            raise TypeError(f"{self.task_class_name!r} is not a PipelineTask subclass.")
        config = task_class.ConfigClass()
        config.loadFromString(self.get_config_str())
        return self._reconfigured(config, rebuild=rebuild, task_class=task_class)

    def _reconfigured(
        self,
        config: PipelineTaskConfig,
        rebuild: bool,
        task_class: type[PipelineTask] | None = None,
    ) -> TaskNode:
        """Return a version of this node with new configuration.

        Parameters
        ----------
        config : `.PipelineTaskConfig`
            New configuration for the task.
        rebuild : `bool`
            If `True`, use the configured connections to generate new edges
            that may differ from the current ones.  If `False`, just update
            the `task_class` and `config` attributes, and assume the edges
            have not changed.
        task_class : `type` [ `PipelineTask` ], optional
            Subclass of `PipelineTask`.  This defaults to ``self.task_class``,
            but may be passed as an argument if that is not available because
            the task class was not imported when ``self`` was constructed.

        Returns
        -------
        node : `TaskNode`
            Task node instance with the new config.
        """
        if task_class is None:
            task_class = self.task_class
        imported_data = _TaskNodeImportedData.configure(self.key.name, task_class, config)
        if rebuild:
            return self._from_imported_data(
                self.key,
                self.init.key,
                imported_data,
                # Only a resolved DimensionGraph carries a universe to reuse.
                universe=self._dimensions.universe if type(self._dimensions) is DimensionGraph else None,
            )
        else:
            return TaskNode(
                self.key,
                TaskInitNode(
                    self.init.key,
                    inputs=self.init.inputs,
                    outputs=self.init.outputs,
                    config_output=self.init.config_output,
                    imported_data=imported_data,
                ),
                prerequisite_inputs=self.prerequisite_inputs,
                inputs=self.inputs,
                outputs=self.outputs,
                log_output=self.log_output,
                metadata_output=self.metadata_output,
                dimensions=self._dimensions,
            )

    def _resolved(self, universe: DimensionUniverse | None) -> TaskNode:
        """Return an otherwise-equivalent task node with resolved dimensions.

        Parameters
        ----------
        universe : `lsst.daf.butler.DimensionUniverse` or `None`
            Definitions for all dimensions.

        Returns
        -------
        node : `TaskNode`
            Task node instance with `dimensions` resolved by the given
            universe.  Will be ``self`` if this is the case already.
        """
        if self.has_resolved_dimensions:
            if cast(DimensionGraph, self._dimensions).universe is universe:
                return self
        elif universe is None:
            return self
        return TaskNode(
            key=self.key,
            init=self.init,
            prerequisite_inputs=self.prerequisite_inputs,
            inputs=self.inputs,
            outputs=self.outputs,
            log_output=self.log_output,
            metadata_output=self.metadata_output,
            dimensions=(
                universe.extract(self.raw_dimensions) if universe is not None else self.raw_dimensions
            ),
        )

    def _to_xgraph_state(self) -> dict[str, Any]:
        """Convert this node's attributes into a dictionary suitable for use
        in exported networkx graphs.
        """
        result = self.init._to_xgraph_state()
        if self.has_resolved_dimensions:
            result["dimensions"] = self._dimensions
        result["raw_dimensions"] = self.raw_dimensions
        return result

    def _get_imported_data(self) -> _TaskNodeImportedData:
        """Return the imported data struct.

        Returns
        -------
        imported_data : `_TaskNodeImportedData`
            Internal structure holding state that requires the task class to
            have been imported.

        Raises
        ------
        TaskNotImportedError
            Raised if `is_imported` is `False`.
        """
        return self.init._get_imported_data()
851def _diff_edge_mapping(
852 a_mapping: Mapping[str, Edge], b_mapping: Mapping[str, Edge], task_label: str, connection_type: str
853) -> list[str]:
854 """Compare a pair of mappings of edges.
856 Parameters
857 ----------
858 a_mapping : `~collections.abc.Mapping` [ `str`, `Edge` ]
859 First mapping to compare. Expected to have connection names as keys.
860 b_mapping : `~collections.abc.Mapping` [ `str`, `Edge` ]
861 First mapping to compare. If keys differ from those of ``a_mapping``,
862 this will be reported as a difference (in addition to element-wise
863 comparisons).
864 task_label : `str`
865 Task label associated with both mappings.
866 connection_type : `str`
867 Type of connection (e.g. "input" or "init output") associated with both
868 connections. This is a human-readable string to include in difference
869 messages.
871 Returns
872 -------
873 differences : `list` [ `str` ]
874 List of string messages describing differences between the two
875 mappings. Will be empty if the two mappings have the same edges.
876 Messages will include "A" and "B", and are expected to be a preceded
877 by a message describing what "A" and "B" are in the context in which
878 this method is called.
880 Notes
881 -----
882 This is expected to be used to compare one edge-holding mapping attribute
883 of a task or task init node to the same attribute on another task or task
884 init node (i.e. any of `TaskNode.inputs`, `TaskNode.outputs`,
885 `TaskNode.prerequisite_inputs`, `TaskInitNode.inputs`,
886 `TaskInitNode.outputs`).
887 """
888 results = []
889 b_to_do = set(b_mapping.keys())
890 for connection_name, a_edge in a_mapping.items():
891 if (b_edge := b_mapping.get(connection_name)) is None:
892 results.append(
893 f"{connection_type.capitalize()} {connection_name!r} of task "
894 f"{task_label!r} exists in A, but not in B (or it may have a different connection type)."
895 )
896 else:
897 results.extend(a_edge.diff(b_edge, connection_type))
898 b_to_do.discard(connection_name)
899 for connection_name in b_to_do:
900 results.append(
901 f"{connection_type.capitalize()} {connection_name!r} of task "
902 f"{task_label!r} exists in A, but not in B (or it may have a different connection type)."
903 )
904 return results