Coverage for python/lsst/pipe/base/pipeline_graph/_tasks.py: 41%
262 statements
« prev ^ index » next — coverage.py v7.3.2, created at 2023-10-11 09:32 +0000
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
27from __future__ import annotations
29__all__ = ("TaskNode", "TaskInitNode", "TaskImportMode")
31import dataclasses
32import enum
33from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
34from typing import TYPE_CHECKING, Any, cast
36from lsst.daf.butler import (
37 DataCoordinate,
38 DatasetRef,
39 DatasetType,
40 DimensionGraph,
41 DimensionUniverse,
42 Registry,
43)
44from lsst.utils.classes import immutable
45from lsst.utils.doImport import doImportType
46from lsst.utils.introspection import get_full_type_name
48from .. import automatic_connection_constants as acc
49from ..connections import PipelineTaskConnections
50from ..connectionTypes import BaseConnection, InitOutput, Output
51from ._edges import Edge, ReadEdge, WriteEdge
52from ._exceptions import TaskNotImportedError, UnresolvedGraphError
53from ._nodes import NodeKey, NodeType
55if TYPE_CHECKING:
56 from ..config import PipelineTaskConfig
57 from ..pipelineTask import PipelineTask
class TaskImportMode(enum.Enum):
    """How to handle task imports when reading a serialized PipelineGraph."""

    DO_NOT_IMPORT = enum.auto()
    """Leave tasks unimported; do not build config or connection objects."""

    REQUIRE_CONSISTENT_EDGES = enum.auto()
    """Import each task, build its config and connection objects, and verify
    that the connections still produce the same edges as the persisted graph.
    """

    ASSUME_CONSISTENT_EDGES = enum.auto()
    """Import each task and build its config and connection objects, trusting
    (without checking) that the connections still define the same edges.

    This is safe only when the caller knows the task definition has not changed
    since the pipeline graph was persisted, such as when it was saved and
    loaded with the same pipeline version.
    """

    OVERRIDE_EDGES = enum.auto()
    """Import each task, build its config and connection objects, and let the
    edges defined by those connections replace those in the persisted graph.

    This may cause dataset type nodes to become unresolved, since resolutions
    consistent with the original edges may be invalidated.
    """
92@dataclasses.dataclass(frozen=True)
93class _TaskNodeImportedData:
94 """An internal struct that holds `TaskNode` and `TaskInitNode` state that
95 requires task classes to be imported.
96 """
98 task_class: type[PipelineTask]
99 """Type object for the task."""
101 config: PipelineTaskConfig
102 """Configuration object for the task."""
104 connection_map: dict[str, BaseConnection]
105 """Mapping from connection name to connection.
107 In addition to ``connections.allConnections``, this also holds the
108 "automatic" config, log, and metadata connections using the names defined
109 in the `.automatic_connection_constants` module.
110 """
112 connections: PipelineTaskConnections
113 """Configured connections object for the task."""
115 @classmethod
116 def configure(
117 cls,
118 label: str,
119 task_class: type[PipelineTask],
120 config: PipelineTaskConfig,
121 connections: PipelineTaskConnections | None = None,
122 ) -> _TaskNodeImportedData:
123 """Construct while creating a `PipelineTaskConnections` instance if
124 necessary.
126 Parameters
127 ----------
128 label : `str`
129 Label for the task in the pipeline. Only used in error messages.
130 task_class : `type` [ `.PipelineTask` ]
131 Pipeline task `type` object.
132 config : `.PipelineTaskConfig`
133 Configuration for the task.
134 connections : `.PipelineTaskConnections`, optional
135 Object that describes the dataset types used by the task. If not
136 provided, one will be constructed from the given configuration. If
137 provided, it is assumed that ``config`` has already been validated
138 and frozen.
140 Returns
141 -------
142 data : `_TaskNodeImportedData`
143 Instance of this struct.
144 """
145 if connections is None:
146 # If we don't have connections yet, assume the config hasn't been
147 # validated yet.
148 try:
149 config.validate()
150 except Exception as err:
151 raise ValueError(
152 f"Configuration validation failed for task {label!r} (see chained exception)."
153 ) from err
154 config.freeze()
155 connections = task_class.ConfigClass.ConnectionsClass(config=config)
156 connection_map = dict(connections.allConnections)
157 connection_map[acc.CONFIG_INIT_OUTPUT_CONNECTION_NAME] = InitOutput(
158 acc.CONFIG_INIT_OUTPUT_TEMPLATE.format(label=label),
159 acc.CONFIG_INIT_OUTPUT_STORAGE_CLASS,
160 )
161 if not config.saveMetadata:
162 raise ValueError(f"Metadata for task {label} cannot be disabled.")
163 connection_map[acc.METADATA_OUTPUT_CONNECTION_NAME] = Output(
164 acc.METADATA_OUTPUT_TEMPLATE.format(label=label),
165 acc.METADATA_OUTPUT_STORAGE_CLASS,
166 dimensions=set(connections.dimensions),
167 )
168 if config.saveLogOutput:
169 connection_map[acc.LOG_OUTPUT_CONNECTION_NAME] = Output(
170 acc.LOG_OUTPUT_TEMPLATE.format(label=label),
171 acc.LOG_OUTPUT_STORAGE_CLASS,
172 dimensions=set(connections.dimensions),
173 )
174 return cls(task_class, config, connection_map, connections)
@immutable
class TaskInitNode:
    """A node in a pipeline graph that represents the construction of a
    `PipelineTask`.

    Parameters
    ----------
    key : `NodeKey`
        Identifier for this node in networkx graphs.
    inputs : `~collections.abc.Mapping` [ `str`, `ReadEdge` ]
        Graph edges that represent inputs required just to construct an
        instance of this task, keyed by connection name.
    outputs : `~collections.abc.Mapping` [ `str`, `WriteEdge` ]
        Graph edges that represent outputs of this task that are available
        after just constructing it, keyed by connection name.

        This does not include the special `config_output` edge; use
        `iter_all_outputs` to include that, too.
    config_output : `WriteEdge`
        The special init output edge that persists the task's configuration.
    imported_data : `_TaskNodeImportedData`, optional
        Internal struct that holds information that requires the task class to
        have been imported.
    task_class_name : `str`, optional
        Fully-qualified name of the task class.  Must be provided if
        ``imported_data`` is not.
    config_str : `str`, optional
        Configuration for the task as a string of override statements.  Must
        be provided if ``imported_data`` is not.

    Notes
    -----
    When included in an exported `networkx` graph (e.g.
    `PipelineGraph.make_xgraph`), task initialization nodes set the following
    node attributes:

    - ``task_class_name``
    - ``bipartite`` (see `NodeType.bipartite`)
    - ``task_class`` (only if `is_imported` is `True`)
    - ``config`` (only if `is_imported` is `True`)
    """

    def __init__(
        self,
        key: NodeKey,
        *,
        inputs: Mapping[str, ReadEdge],
        outputs: Mapping[str, WriteEdge],
        config_output: WriteEdge,
        imported_data: _TaskNodeImportedData | None = None,
        task_class_name: str | None = None,
        config_str: str | None = None,
    ):
        self.key = key
        self.inputs = inputs
        self.outputs = outputs
        self.config_output = config_output
        # Instead of setting attributes to None, we do not set them at all;
        # this works better with the @immutable decorator, which supports
        # deferred initialization but not reassignment.
        if task_class_name is not None:
            self._task_class_name = task_class_name
        if config_str is not None:
            self._config_str = config_str
        if imported_data is not None:
            self._imported_data = imported_data
        else:
            # Check the parameters rather than the attributes: when either was
            # not provided the corresponding attribute was never set, and
            # reading it here would raise AttributeError instead of this
            # clearer AssertionError.
            assert (
                task_class_name is not None and config_str is not None
            ), "If imported_data is not present, task_class_name and config_str must be."

    key: NodeKey
    """Key that identifies this node in internal and exported networkx graphs.
    """

    inputs: Mapping[str, ReadEdge]
    """Graph edges that represent inputs required just to construct an instance
    of this task, keyed by connection name.
    """

    outputs: Mapping[str, WriteEdge]
    """Graph edges that represent outputs of this task that are available after
    just constructing it, keyed by connection name.

    This does not include the special `config_output` edge; use
    `iter_all_outputs` to include that, too.
    """

    config_output: WriteEdge
    """The special output edge that persists the task's configuration.
    """

    @property
    def label(self) -> str:
        """Label of this configuration of a task in the pipeline."""
        return str(self.key)

    @property
    def is_imported(self) -> bool:
        """Whether the task type for this node has been imported and its
        configuration overrides applied.

        If this is `False`, the `task_class` and `config` attributes may not
        be accessed.
        """
        return hasattr(self, "_imported_data")

    @property
    def task_class(self) -> type[PipelineTask]:
        """Type object for the task.

        Accessing this attribute when `is_imported` is `False` will raise
        `TaskNotImportedError`, but accessing `task_class_name` will not.
        """
        return self._get_imported_data().task_class

    @property
    def task_class_name(self) -> str:
        """The fully-qualified string name of the task class."""
        # Lazily computed from the imported task class and cached; EAFP here
        # because the attribute may legitimately be unset.
        try:
            return self._task_class_name
        except AttributeError:
            pass
        self._task_class_name = get_full_type_name(self.task_class)
        return self._task_class_name

    @property
    def config(self) -> PipelineTaskConfig:
        """Configuration for the task.

        This is always frozen.

        Accessing this attribute when `is_imported` is `False` will raise
        `TaskNotImportedError`, but calling `get_config_str` will not.
        """
        return self._get_imported_data().config

    def __repr__(self) -> str:
        return f"{self.label} [init] ({self.task_class_name})"

    def get_config_str(self) -> str:
        """Return the configuration for this task as a string of override
        statements.

        Returns
        -------
        config_str : `str`
            String containing configuration-overload statements.
        """
        # Lazily serialized from the imported config and cached.
        try:
            return self._config_str
        except AttributeError:
            pass
        self._config_str = self.config.saveToString()
        return self._config_str

    def iter_all_inputs(self) -> Iterator[ReadEdge]:
        """Iterate over all inputs required for construction.

        This is the same as iteration over ``inputs.values()``, but it will be
        updated to include any automatic init-input connections added in the
        future, while `inputs` will continue to hold only task-defined init
        inputs.
        """
        return iter(self.inputs.values())

    def iter_all_outputs(self) -> Iterator[WriteEdge]:
        """Iterate over all outputs available after construction, including
        special ones.
        """
        yield from self.outputs.values()
        yield self.config_output

    def diff_edges(self, other: TaskInitNode) -> list[str]:
        """Compare the edges of this task initialization node to those from the
        same task label in a different pipeline.

        Parameters
        ----------
        other : `TaskInitNode`
            Other node to compare to.  Must have the same task label, but need
            not have the same configuration or even the same task class.

        Returns
        -------
        differences : `list` [ `str` ]
            List of string messages describing differences between ``self``
            and ``other``.  Will be empty if the two nodes have the same
            edges.  Messages will use 'A' to refer to ``self`` and 'B' to
            refer to ``other``.
        """
        result = []
        # Compare our init inputs against *other*'s; the previous code passed
        # self.inputs for both sides, which could never report a difference.
        result += _diff_edge_mapping(self.inputs, other.inputs, self.label, "init input")
        result += _diff_edge_mapping(self.outputs, other.outputs, self.label, "init output")
        result += self.config_output.diff(other.config_output, "config init output")
        return result

    def _to_xgraph_state(self) -> dict[str, Any]:
        """Convert this node's attributes into a dictionary suitable for use
        in exported networkx graphs.
        """
        result = {"task_class_name": self.task_class_name, "bipartite": NodeType.TASK_INIT.bipartite}
        if hasattr(self, "_imported_data"):
            result["task_class"] = self.task_class
            result["config"] = self.config
        return result

    def _get_imported_data(self) -> _TaskNodeImportedData:
        """Return the imported data struct.

        Returns
        -------
        imported_data : `_TaskNodeImportedData`
            Internal structure holding state that requires the task class to
            have been imported.

        Raises
        ------
        TaskNotImportedError
            Raised if `is_imported` is `False`.
        """
        try:
            return self._imported_data
        except AttributeError:
            raise TaskNotImportedError(
                f"Task class {self.task_class_name!r} for label {self.label!r} has not been imported "
                "(see PipelineGraph.import_and_configure)."
            ) from None
@immutable
class TaskNode:
    """A node in a pipeline graph that represents a labeled configuration of a
    `PipelineTask`.

    Parameters
    ----------
    key : `NodeKey`
        Identifier for this node in networkx graphs.
    init : `TaskInitNode`
        Node representing the initialization of this task.
    prerequisite_inputs : `~collections.abc.Mapping` [ `str`, `ReadEdge` ]
        Graph edges that represent prerequisite inputs to this task, keyed by
        connection name.

        Prerequisite inputs must already exist in the data repository when a
        `QuantumGraph` is built, but have more flexibility in how they are
        looked up than regular inputs.
    inputs : `~collections.abc.Mapping` [ `str`, `ReadEdge` ]
        Graph edges that represent regular runtime inputs to this task, keyed
        by connection name.
    outputs : `~collections.abc.Mapping` [ `str`, `WriteEdge` ]
        Graph edges that represent regular runtime outputs of this task, keyed
        by connection name.

        This does not include the special `log_output` and `metadata_output`
        edges; use `iter_all_outputs` to include those, too.
    log_output : `WriteEdge` or `None`
        The special runtime output that persists the task's logs.
    metadata_output : `WriteEdge`
        The special runtime output that persists the task's metadata.
    dimensions : `lsst.daf.butler.DimensionGraph` or `frozenset`
        Dimensions of the task.  If a `frozenset`, the dimensions have not been
        resolved by a `~lsst.daf.butler.DimensionUniverse` and cannot be safely
        compared to other sets of dimensions.

    Notes
    -----
    Task nodes are intentionally not equality comparable, since there are many
    different (and useful) ways to compare these objects with no clear winner
    as the most obvious behavior.

    When included in an exported `networkx` graph (e.g.
    `PipelineGraph.make_xgraph`), task nodes set the following node attributes:

    - ``task_class_name``
    - ``bipartite`` (see `NodeType.bipartite`)
    - ``task_class`` (only if `is_imported` is `True`)
    - ``config`` (only if `is_imported` is `True`)
    """

    def __init__(
        self,
        key: NodeKey,
        init: TaskInitNode,
        *,
        prerequisite_inputs: Mapping[str, ReadEdge],
        inputs: Mapping[str, ReadEdge],
        outputs: Mapping[str, WriteEdge],
        log_output: WriteEdge | None,
        metadata_output: WriteEdge,
        dimensions: DimensionGraph | frozenset,
    ):
        self.key = key
        self.init = init
        self.prerequisite_inputs = prerequisite_inputs
        self.inputs = inputs
        self.outputs = outputs
        self.log_output = log_output
        self.metadata_output = metadata_output
        # Either resolved (DimensionGraph) or raw (frozenset of str names);
        # has_resolved_dimensions distinguishes the two states.
        self._dimensions = dimensions

    @staticmethod
    def _from_imported_data(
        key: NodeKey,
        init_key: NodeKey,
        data: _TaskNodeImportedData,
        universe: DimensionUniverse | None,
    ) -> TaskNode:
        """Construct from a `PipelineTask` type and its configuration.

        Parameters
        ----------
        key : `NodeKey`
            Identifier for this node in networkx graphs.
        init_key : `NodeKey`
            Identifier for the node representing this task's initialization in
            networkx graphs.
        data : `_TaskNodeImportedData`
            Internal struct that holds information that requires the task
            class to have been imported.
        universe : `lsst.daf.butler.DimensionUniverse` or `None`
            Definitions of all dimensions.

        Returns
        -------
        node : `TaskNode`
            New task node.

        Raises
        ------
        ValueError
            Raised if configuration validation failed when constructing
            ``connections``.
        """
        init_inputs = {
            name: ReadEdge._from_connection_map(init_key, name, data.connection_map)
            for name in data.connections.initInputs
        }
        prerequisite_inputs = {
            name: ReadEdge._from_connection_map(key, name, data.connection_map, is_prerequisite=True)
            for name in data.connections.prerequisiteInputs
        }
        inputs = {
            name: ReadEdge._from_connection_map(key, name, data.connection_map)
            for name in data.connections.inputs
        }
        init_outputs = {
            name: WriteEdge._from_connection_map(init_key, name, data.connection_map)
            for name in data.connections.initOutputs
        }
        outputs = {
            name: WriteEdge._from_connection_map(key, name, data.connection_map)
            for name in data.connections.outputs
        }
        init = TaskInitNode(
            key=init_key,
            inputs=init_inputs,
            outputs=init_outputs,
            config_output=WriteEdge._from_connection_map(
                init_key, acc.CONFIG_INIT_OUTPUT_CONNECTION_NAME, data.connection_map
            ),
            imported_data=data,
        )
        instance = TaskNode(
            key=key,
            init=init,
            prerequisite_inputs=prerequisite_inputs,
            inputs=inputs,
            outputs=outputs,
            # The log output edge only exists when the config enables it.
            log_output=(
                WriteEdge._from_connection_map(key, acc.LOG_OUTPUT_CONNECTION_NAME, data.connection_map)
                if data.config.saveLogOutput
                else None
            ),
            metadata_output=WriteEdge._from_connection_map(
                key, acc.METADATA_OUTPUT_CONNECTION_NAME, data.connection_map
            ),
            # Without a universe we can only record the raw dimension names.
            dimensions=(
                frozenset(data.connections.dimensions)
                if universe is None
                else universe.extract(data.connections.dimensions)
            ),
        )
        return instance

    key: NodeKey
    """Key that identifies this node in internal and exported networkx graphs.
    """

    # Node representing the initialization of this task.
    init: TaskInitNode

    prerequisite_inputs: Mapping[str, ReadEdge]
    """Graph edges that represent prerequisite inputs to this task.

    Prerequisite inputs must already exist in the data repository when a
    `QuantumGraph` is built, but have more flexibility in how they are looked
    up than regular inputs.
    """

    inputs: Mapping[str, ReadEdge]
    """Graph edges that represent regular runtime inputs to this task.
    """

    outputs: Mapping[str, WriteEdge]
    """Graph edges that represent regular runtime outputs of this task.

    This does not include the special `log_output` and `metadata_output` edges;
    use `iter_all_outputs` to include those, too.
    """

    log_output: WriteEdge | None
    """The special runtime output that persists the task's logs.
    """

    metadata_output: WriteEdge
    """The special runtime output that persists the task's metadata.
    """

    @property
    def label(self) -> str:
        """Label of this configuration of a task in the pipeline."""
        return self.key.name

    @property
    def is_imported(self) -> bool:
        """Whether the task type for this node has been imported and its
        configuration overrides applied.

        If this is `False`, the `task_class` and `config` attributes may not
        be accessed.
        """
        return self.init.is_imported

    @property
    def task_class(self) -> type[PipelineTask]:
        """Type object for the task.

        Accessing this attribute when `is_imported` is `False` will raise
        `TaskNotImportedError`, but accessing `task_class_name` will not.
        """
        return self.init.task_class

    @property
    def task_class_name(self) -> str:
        """The fully-qualified string name of the task class."""
        return self.init.task_class_name

    @property
    def config(self) -> PipelineTaskConfig:
        """Configuration for the task.

        This is always frozen.

        Accessing this attribute when `is_imported` is `False` will raise
        `TaskNotImportedError`, but calling `get_config_str` will not.
        """
        return self.init.config

    @property
    def has_resolved_dimensions(self) -> bool:
        """Whether the `dimensions` attribute may be accessed.

        If `False`, the `raw_dimensions` attribute may be used to obtain a
        set of dimension names that has not been resolved by a
        `~lsst.daf.butler.DimensionUniverse`.
        """
        return type(self._dimensions) is DimensionGraph

    @property
    def dimensions(self) -> DimensionGraph:
        """Standardized dimensions of the task."""
        if not self.has_resolved_dimensions:
            raise UnresolvedGraphError(f"Dimensions for task {self.label!r} have not been resolved.")
        return cast(DimensionGraph, self._dimensions)

    @property
    def raw_dimensions(self) -> frozenset[str]:
        """Raw dimensions of the task, with standardization by a
        `~lsst.daf.butler.DimensionUniverse` not guaranteed.
        """
        if self.has_resolved_dimensions:
            return frozenset(cast(DimensionGraph, self._dimensions).names)
        else:
            return cast(frozenset[str], self._dimensions)

    def __repr__(self) -> str:
        if self.has_resolved_dimensions:
            return f"{self.label} ({self.task_class_name}, {self.dimensions})"
        else:
            return f"{self.label} ({self.task_class_name})"

    def get_config_str(self) -> str:
        """Return the configuration for this task as a string of override
        statements.

        Returns
        -------
        config_str : `str`
            String containing configuration-overload statements.
        """
        return self.init.get_config_str()

    def iter_all_inputs(self) -> Iterator[ReadEdge]:
        """Iterate over all runtime inputs, including both regular inputs and
        prerequisites.
        """
        yield from self.prerequisite_inputs.values()
        yield from self.inputs.values()

    def iter_all_outputs(self) -> Iterator[WriteEdge]:
        """Iterate over all runtime outputs, including special ones."""
        yield from self.outputs.values()
        yield self.metadata_output
        if self.log_output is not None:
            yield self.log_output

    def diff_edges(self, other: TaskNode) -> list[str]:
        """Compare the edges of this task node to those from the same task
        label in a different pipeline.

        This also calls `TaskInitNode.diff_edges`.

        Parameters
        ----------
        other : `TaskNode`
            Other node to compare to.  Must have the same task label, but need
            not have the same configuration or even the same task class.

        Returns
        -------
        differences : `list` [ `str` ]
            List of string messages describing differences between ``self``
            and ``other``.  Will be empty if the two nodes have the same
            edges.  Messages will use 'A' to refer to ``self`` and 'B' to
            refer to ``other``.
        """
        result = self.init.diff_edges(other.init)
        result += _diff_edge_mapping(
            self.prerequisite_inputs, other.prerequisite_inputs, self.label, "prerequisite input"
        )
        result += _diff_edge_mapping(self.inputs, other.inputs, self.label, "input")
        result += _diff_edge_mapping(self.outputs, other.outputs, self.label, "output")
        # log_output is optional on both sides; report presence mismatches
        # explicitly instead of diffing against None.
        if self.log_output is not None:
            if other.log_output is not None:
                result += self.log_output.diff(other.log_output, "log output")
            else:
                result.append("Log output is present in A, but not in B.")
        elif other.log_output is not None:
            result.append("Log output is present in B, but not in A.")
        result += self.metadata_output.diff(other.metadata_output, "metadata output")
        return result

    def get_lookup_function(
        self, connection_name: str
    ) -> Callable[[DatasetType, Registry, DataCoordinate, Sequence[str]], Iterable[DatasetRef]] | None:
        """Return the custom dataset query function for an edge, if one exists.

        Parameters
        ----------
        connection_name : `str`
            Name of the connection.

        Returns
        -------
        lookup_function : `~collections.abc.Callable` or `None`
            Callable that takes a dataset type, a butler registry, a data
            coordinate (the quantum data ID), and an ordered list of
            collections to search, and returns an iterable of
            `~lsst.daf.butler.DatasetRef`.
        """
        return getattr(self._get_imported_data().connection_map[connection_name], "lookupFunction", None)

    def get_spatial_bounds_connections(self) -> frozenset[str]:
        """Return the names of connections whose data IDs should be included
        in the calculation of the spatial bounds for this task's quanta.

        Returns
        -------
        connection_names : `frozenset` [ `str` ]
            Names of connections with spatial dimensions.
        """
        return frozenset(self._get_imported_data().connections.getSpatialBoundsConnections())

    def get_temporal_bounds_connections(self) -> frozenset[str]:
        """Return the names of connections whose data IDs should be included
        in the calculation of the temporal bounds for this task's quanta.

        Returns
        -------
        connection_names : `frozenset` [ `str` ]
            Names of connections with temporal dimensions.
        """
        return frozenset(self._get_imported_data().connections.getTemporalBoundsConnections())

    def _imported_and_configured(self, rebuild: bool) -> TaskNode:
        """Import the task class and use it to construct a new instance.

        Parameters
        ----------
        rebuild : `bool`
            If `True`, import the task class and configure its connections to
            generate new edges that may differ from the current ones.  If
            `False`, import the task class but just update the `task_class`
            and `config` attributes, and assume the edges have not changed.

        Returns
        -------
        node : `TaskNode`
            Task node instance for which `is_imported` is `True`.  Will be
            ``self`` if this is the case already.
        """
        # Local import to avoid an import cycle at module load time.
        from ..pipelineTask import PipelineTask

        if self.is_imported:
            return self
        task_class = doImportType(self.task_class_name)
        if not issubclass(task_class, PipelineTask):
            raise TypeError(f"{self.task_class_name!r} is not a PipelineTask subclass.")
        config = task_class.ConfigClass()
        config.loadFromString(self.get_config_str())
        return self._reconfigured(config, rebuild=rebuild, task_class=task_class)

    def _reconfigured(
        self,
        config: PipelineTaskConfig,
        rebuild: bool,
        task_class: type[PipelineTask] | None = None,
    ) -> TaskNode:
        """Return a version of this node with new configuration.

        Parameters
        ----------
        config : `.PipelineTaskConfig`
            New configuration for the task.
        rebuild : `bool`
            If `True`, use the configured connections to generate new edges
            that may differ from the current ones.  If `False`, just update
            the `task_class` and `config` attributes, and assume the edges
            have not changed.
        task_class : `type` [ `PipelineTask` ], optional
            Subclass of `PipelineTask`.  This defaults to ``self.task_class``,
            but may be passed as an argument if that is not available because
            the task class was not imported when ``self`` was constructed.

        Returns
        -------
        node : `TaskNode`
            Task node instance with the new config.
        """
        if task_class is None:
            task_class = self.task_class
        imported_data = _TaskNodeImportedData.configure(self.key.name, task_class, config)
        if rebuild:
            return self._from_imported_data(
                self.key,
                self.init.key,
                imported_data,
                universe=self._dimensions.universe if type(self._dimensions) is DimensionGraph else None,
            )
        else:
            # Reuse the existing edges; only the imported data changes.
            return TaskNode(
                self.key,
                TaskInitNode(
                    self.init.key,
                    inputs=self.init.inputs,
                    outputs=self.init.outputs,
                    config_output=self.init.config_output,
                    imported_data=imported_data,
                ),
                prerequisite_inputs=self.prerequisite_inputs,
                inputs=self.inputs,
                outputs=self.outputs,
                log_output=self.log_output,
                metadata_output=self.metadata_output,
                dimensions=self._dimensions,
            )

    def _resolved(self, universe: DimensionUniverse | None) -> TaskNode:
        """Return an otherwise-equivalent task node with resolved dimensions.

        Parameters
        ----------
        universe : `lsst.daf.butler.DimensionUniverse` or `None`
            Definitions for all dimensions.

        Returns
        -------
        node : `TaskNode`
            Task node instance with `dimensions` resolved by the given
            universe.  Will be ``self`` if this is the case already.
        """
        if self.has_resolved_dimensions:
            if cast(DimensionGraph, self._dimensions).universe is universe:
                return self
        elif universe is None:
            return self
        return TaskNode(
            key=self.key,
            init=self.init,
            prerequisite_inputs=self.prerequisite_inputs,
            inputs=self.inputs,
            outputs=self.outputs,
            log_output=self.log_output,
            metadata_output=self.metadata_output,
            dimensions=(
                universe.extract(self.raw_dimensions) if universe is not None else self.raw_dimensions
            ),
        )

    def _to_xgraph_state(self) -> dict[str, Any]:
        """Convert this node's attributes into a dictionary suitable for use
        in exported networkx graphs.
        """
        result = self.init._to_xgraph_state()
        if self.has_resolved_dimensions:
            result["dimensions"] = self._dimensions
            result["raw_dimensions"] = self.raw_dimensions
        return result

    def _get_imported_data(self) -> _TaskNodeImportedData:
        """Return the imported data struct.

        Returns
        -------
        imported_data : `_TaskNodeImportedData`
            Internal structure holding state that requires the task class to
            have been imported.

        Raises
        ------
        TaskNotImportedError
            Raised if `is_imported` is `False`.
        """
        return self.init._get_imported_data()
909def _diff_edge_mapping(
910 a_mapping: Mapping[str, Edge], b_mapping: Mapping[str, Edge], task_label: str, connection_type: str
911) -> list[str]:
912 """Compare a pair of mappings of edges.
914 Parameters
915 ----------
916 a_mapping : `~collections.abc.Mapping` [ `str`, `Edge` ]
917 First mapping to compare. Expected to have connection names as keys.
918 b_mapping : `~collections.abc.Mapping` [ `str`, `Edge` ]
919 First mapping to compare. If keys differ from those of ``a_mapping``,
920 this will be reported as a difference (in addition to element-wise
921 comparisons).
922 task_label : `str`
923 Task label associated with both mappings.
924 connection_type : `str`
925 Type of connection (e.g. "input" or "init output") associated with both
926 connections. This is a human-readable string to include in difference
927 messages.
929 Returns
930 -------
931 differences : `list` [ `str` ]
932 List of string messages describing differences between the two
933 mappings. Will be empty if the two mappings have the same edges.
934 Messages will include "A" and "B", and are expected to be a preceded
935 by a message describing what "A" and "B" are in the context in which
936 this method is called.
938 Notes
939 -----
940 This is expected to be used to compare one edge-holding mapping attribute
941 of a task or task init node to the same attribute on another task or task
942 init node (i.e. any of `TaskNode.inputs`, `TaskNode.outputs`,
943 `TaskNode.prerequisite_inputs`, `TaskInitNode.inputs`,
944 `TaskInitNode.outputs`).
945 """
946 results = []
947 b_to_do = set(b_mapping.keys())
948 for connection_name, a_edge in a_mapping.items():
949 if (b_edge := b_mapping.get(connection_name)) is None:
950 results.append(
951 f"{connection_type.capitalize()} {connection_name!r} of task "
952 f"{task_label!r} exists in A, but not in B (or it may have a different connection type)."
953 )
954 else:
955 results.extend(a_edge.diff(b_edge, connection_type))
956 b_to_do.discard(connection_name)
957 for connection_name in b_to_do:
958 results.append(
959 f"{connection_type.capitalize()} {connection_name!r} of task "
960 f"{task_label!r} exists in A, but not in B (or it may have a different connection type)."
961 )
962 return results