Coverage for python/lsst/pipe/base/pipeline_graph/_tasks.py: 41%
282 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-17 02:45 -0700
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-17 02:45 -0700
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
27from __future__ import annotations
29__all__ = ("TaskNode", "TaskInitNode", "TaskImportMode")
31import dataclasses
32import enum
33from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
34from typing import TYPE_CHECKING, Any, cast
36from lsst.daf.butler import (
37 DataCoordinate,
38 DatasetRef,
39 DatasetType,
40 DimensionGroup,
41 DimensionUniverse,
42 Registry,
43)
44from lsst.pex.config import FieldValidationError
45from lsst.utils.classes import immutable
46from lsst.utils.doImport import doImportType
47from lsst.utils.introspection import get_full_type_name
49from .. import automatic_connection_constants as acc
50from ..connections import PipelineTaskConnections
51from ..connectionTypes import BaseConnection, InitOutput, Output
52from ._edges import Edge, ReadEdge, WriteEdge
53from ._exceptions import TaskNotImportedError, UnresolvedGraphError
54from ._nodes import NodeKey, NodeType
56if TYPE_CHECKING:
57 from ..config import PipelineTaskConfig
58 from ..pipelineTask import PipelineTask
class TaskImportMode(enum.Enum):
    """Enumeration of the ways to handle importing tasks when reading a
    serialized PipelineGraph.
    """

    # Values are written out explicitly; they are exactly what enum.auto()
    # assigns in this declaration order.
    DO_NOT_IMPORT = 1
    """Do not import tasks or instantiate their configs and connections."""

    REQUIRE_CONSISTENT_EDGES = 2
    """Import tasks and instantiate their config and connection objects, and
    check that the connections still define the same edges.
    """

    ASSUME_CONSISTENT_EDGES = 3
    """Import tasks and instantiate their config and connection objects, but
    do not check that the connections still define the same edges.

    This is safe only when the caller knows the task definition has not
    changed since the pipeline graph was persisted, such as when it was saved
    and loaded with the same pipeline version.
    """

    OVERRIDE_EDGES = 4
    """Import tasks and instantiate their config and connection objects, and
    allow the edges defined in those connections to override those in the
    persisted graph.

    This may cause dataset type nodes to be unresolved, since resolutions
    consistent with the original edges may be invalidated.
    """
@dataclasses.dataclass(frozen=True)
class _TaskNodeImportedData:
    """An internal struct that holds the `TaskNode` and `TaskInitNode` state
    that only exists once the task class has been imported.
    """

    task_class: type[PipelineTask]
    """Type object for the task."""

    config: PipelineTaskConfig
    """Configuration object for the task."""

    connection_map: dict[str, BaseConnection]
    """Mapping from connection name to connection.

    In addition to ``connections.allConnections``, this also holds the
    "automatic" config, log, and metadata connections using the names defined
    in the `.automatic_connection_constants` module.
    """

    connections: PipelineTaskConnections
    """Configured connections object for the task."""

    @classmethod
    def configure(
        cls,
        label: str,
        task_class: type[PipelineTask],
        config: PipelineTaskConfig,
        connections: PipelineTaskConnections | None = None,
    ) -> _TaskNodeImportedData:
        """Construct this struct, creating a `PipelineTaskConnections`
        instance first if one was not provided.

        Parameters
        ----------
        label : `str`
            Label for the task in the pipeline.  Only used in error messages.
        task_class : `type` [ `.PipelineTask` ]
            Pipeline task `type` object.
        config : `.PipelineTaskConfig`
            Configuration for the task.
        connections : `.PipelineTaskConnections`, optional
            Object that describes the dataset types used by the task.  If not
            provided, one will be constructed from the given configuration,
            which is then assumed not to have been validated or frozen yet.

        Returns
        -------
        data : `_TaskNodeImportedData`
            Instance of this struct.
        """
        if connections is None:
            # No pre-built connections object, so validate and freeze the
            # config ourselves before constructing one from it.
            try:
                config.validate()
            except FieldValidationError as err:
                # Prefix the task label so the message identifies which
                # pipeline entry failed validation.
                err.fullname = f"{label}: {err.fullname}"
                raise err
            except Exception as err:
                raise ValueError(
                    f"Configuration validation failed for task {label!r} (see chained exception)."
                ) from err
            config.freeze()
            # MyPy doesn't see the metaclass attribute defined for this.
            connections = config.ConnectionsClass(config=config)  # type: ignore
        # Start from the task-declared connections, then append the automatic
        # config, metadata, and (optionally) log output connections.
        edges = dict(connections.allConnections)
        edges[acc.CONFIG_INIT_OUTPUT_CONNECTION_NAME] = InitOutput(
            acc.CONFIG_INIT_OUTPUT_TEMPLATE.format(label=label),
            acc.CONFIG_INIT_OUTPUT_STORAGE_CLASS,
        )
        edges[acc.METADATA_OUTPUT_CONNECTION_NAME] = Output(
            acc.METADATA_OUTPUT_TEMPLATE.format(label=label),
            acc.METADATA_OUTPUT_STORAGE_CLASS,
            dimensions=set(connections.dimensions),
        )
        if config.saveLogOutput:
            edges[acc.LOG_OUTPUT_CONNECTION_NAME] = Output(
                acc.LOG_OUTPUT_TEMPLATE.format(label=label),
                acc.LOG_OUTPUT_STORAGE_CLASS,
                dimensions=set(connections.dimensions),
            )
        return cls(task_class, config, edges, connections)
@immutable
class TaskInitNode:
    """A node in a pipeline graph that represents the construction of a
    `PipelineTask`.

    Parameters
    ----------
    key : `NodeKey`
        Key that identifies this node in internal and exported networkx
        graphs.
    inputs : `~collections.abc.Mapping` [ `str`, `ReadEdge` ]
        Graph edges that represent inputs required just to construct an
        instance of this task, keyed by connection name.
    outputs : `~collections.abc.Mapping` [ `str`, `WriteEdge` ]
        Graph edges that represent outputs of this task that are available
        after just constructing it, keyed by connection name.

        This does not include the special `config_output` edge; use
        `iter_all_outputs` to include that, too.
    config_output : `WriteEdge`
        The special init output edge that persists the task's configuration.
    imported_data : `_TaskNodeImportedData`, optional
        Internal struct that holds information that requires the task class
        to have been imported.
    task_class_name : `str`, optional
        Fully-qualified name of the task class.  Must be provided if
        ``imported_data`` is not.
    config_str : `str`, optional
        Configuration for the task as a string of override statements.  Must
        be provided if ``imported_data`` is not.

    Notes
    -----
    When included in an exported `networkx` graph (e.g.
    `PipelineGraph.make_xgraph`), task initialization nodes set the following
    node attributes:

    - ``task_class_name``
    - ``bipartite`` (see `NodeType.bipartite`)
    - ``task_class`` (only if `is_imported` is `True`)
    - ``config`` (only if `is_imported` is `True`)
    """

    def __init__(
        self,
        key: NodeKey,
        *,
        inputs: Mapping[str, ReadEdge],
        outputs: Mapping[str, WriteEdge],
        config_output: WriteEdge,
        imported_data: _TaskNodeImportedData | None = None,
        task_class_name: str | None = None,
        config_str: str | None = None,
    ):
        self.key = key
        self.inputs = inputs
        self.outputs = outputs
        self.config_output = config_output
        # Instead of setting attributes to None, we do not set them at all;
        # this works better with the @immutable decorator, which supports
        # deferred initialization but not reassignment.
        if task_class_name is not None:
            self._task_class_name = task_class_name
        if config_str is not None:
            self._config_str = config_str
        if imported_data is not None:
            self._imported_data = imported_data
        else:
            assert (
                self._task_class_name is not None and self._config_str is not None
            ), "If imported_data is not present, task_class_name and config_str must be."

    key: NodeKey
    """Key that identifies this node in internal and exported networkx graphs.
    """

    inputs: Mapping[str, ReadEdge]
    """Graph edges that represent inputs required just to construct an
    instance of this task, keyed by connection name.
    """

    outputs: Mapping[str, WriteEdge]
    """Graph edges that represent outputs of this task that are available
    after just constructing it, keyed by connection name.

    This does not include the special `config_output` edge; use
    `iter_all_outputs` to include that, too.
    """

    config_output: WriteEdge
    """The special output edge that persists the task's configuration.
    """

    @property
    def label(self) -> str:
        """Label of this configuration of a task in the pipeline."""
        return str(self.key)

    @property
    def is_imported(self) -> bool:
        """Whether the task type for this node has been imported and its
        configuration overrides applied.

        If this is `False`, the `task_class` and `config` attributes may not
        be accessed.
        """
        return hasattr(self, "_imported_data")

    @property
    def task_class(self) -> type[PipelineTask]:
        """Type object for the task.

        Accessing this attribute when `is_imported` is `False` will raise
        `TaskNotImportedError`, but accessing `task_class_name` will not.
        """
        return self._get_imported_data().task_class

    @property
    def task_class_name(self) -> str:
        """The fully-qualified string name of the task class."""
        try:
            return self._task_class_name
        except AttributeError:
            pass
        # Not set at construction; derive it (once) from the imported class.
        self._task_class_name = get_full_type_name(self.task_class)
        return self._task_class_name

    @property
    def config(self) -> PipelineTaskConfig:
        """Configuration for the task.

        This is always frozen.

        Accessing this attribute when `is_imported` is `False` will raise
        `TaskNotImportedError`, but calling `get_config_str` will not.
        """
        return self._get_imported_data().config

    def __repr__(self) -> str:
        return f"{self.label} [init] ({self.task_class_name})"

    def get_config_str(self) -> str:
        """Return the configuration for this task as a string of override
        statements.

        Returns
        -------
        config_str : `str`
            String containing configuration-overload statements.
        """
        try:
            return self._config_str
        except AttributeError:
            pass
        # Not set at construction; serialize the config (once) on demand.
        self._config_str = self.config.saveToString()
        return self._config_str

    def iter_all_inputs(self) -> Iterator[ReadEdge]:
        """Iterate over all inputs required for construction.

        This is the same as iteration over ``inputs.values()``, but it will be
        updated to include any automatic init-input connections added in the
        future, while `inputs` will continue to hold only task-defined init
        inputs.

        Yields
        ------
        `ReadEdge`
            All the inputs required for construction.
        """
        return iter(self.inputs.values())

    def iter_all_outputs(self) -> Iterator[WriteEdge]:
        """Iterate over all outputs available after construction, including
        special ones.

        Yields
        ------
        `WriteEdge`
            All the outputs available after construction.
        """
        yield from self.outputs.values()
        yield self.config_output

    def diff_edges(self, other: TaskInitNode) -> list[str]:
        """Compare the edges of this task initialization node to those from
        the same task label in a different pipeline.

        Parameters
        ----------
        other : `TaskInitNode`
            Other node to compare to.  Must have the same task label, but need
            not have the same configuration or even the same task class.

        Returns
        -------
        differences : `list` [ `str` ]
            List of string messages describing differences between ``self``
            and ``other``.  Will be empty if the two nodes have the same
            edges.  Messages will use 'A' to refer to ``self`` and 'B' to
            refer to ``other``.
        """
        result = []
        # Compare this node's init inputs against *other*'s; the previous
        # code compared self.inputs to itself, which silently hid all
        # init-input differences.
        result += _diff_edge_mapping(self.inputs, other.inputs, self.label, "init input")
        result += _diff_edge_mapping(self.outputs, other.outputs, self.label, "init output")
        result += self.config_output.diff(other.config_output, "config init output")
        return result

    def _to_xgraph_state(self) -> dict[str, Any]:
        """Convert this node's attributes into a dictionary suitable for use
        in exported networkx graphs.
        """
        result = {"task_class_name": self.task_class_name, "bipartite": NodeType.TASK_INIT.bipartite}
        if hasattr(self, "_imported_data"):
            result["task_class"] = self.task_class
            result["config"] = self.config
        return result

    def _get_imported_data(self) -> _TaskNodeImportedData:
        """Return the imported data struct.

        Returns
        -------
        imported_data : `_TaskNodeImportedData`
            Internal structure holding state that requires the task class to
            have been imported.

        Raises
        ------
        TaskNotImportedError
            Raised if `is_imported` is `False`.
        """
        try:
            return self._imported_data
        except AttributeError:
            raise TaskNotImportedError(
                f"Task class {self.task_class_name!r} for label {self.label!r} has not been imported "
                "(see PipelineGraph.import_and_configure)."
            ) from None

    @staticmethod
    def _unreduce(kwargs: dict[str, Any]) -> TaskInitNode:
        """Unpickle a `TaskInitNode` instance."""
        # Connections classes are not pickleable, so we can't use the
        # dataclass-provided pickle implementation of _TaskNodeImportedData,
        # and it's easier to just call its `configure` method than to fix it.
        if (imported_data_args := kwargs.pop("imported_data_args", None)) is not None:
            imported_data = _TaskNodeImportedData.configure(*imported_data_args)
        else:
            imported_data = None
        return TaskInitNode(imported_data=imported_data, **kwargs)

    def __reduce__(self) -> tuple[Callable[[dict[str, Any]], TaskInitNode], tuple[dict[str, Any]]]:
        kwargs = dict(
            key=self.key,
            inputs=self.inputs,
            outputs=self.outputs,
            config_output=self.config_output,
            task_class_name=getattr(self, "_task_class_name", None),
            config_str=getattr(self, "_config_str", None),
        )
        if hasattr(self, "_imported_data"):
            kwargs["imported_data_args"] = (
                self.label,
                self.task_class,
                self.config,
            )
        return (self._unreduce, (kwargs,))
@immutable
class TaskNode:
    """A node in a pipeline graph that represents a labeled configuration of a
    `PipelineTask`.

    Parameters
    ----------
    key : `NodeKey`
        Identifier for this node in networkx graphs.
    init : `TaskInitNode`
        Node representing the initialization of this task.
    prerequisite_inputs : `~collections.abc.Mapping` [ `str`, `ReadEdge` ]
        Graph edges that represent prerequisite inputs to this task, keyed by
        connection name.

        Prerequisite inputs must already exist in the data repository when a
        `QuantumGraph` is built, but have more flexibility in how they are
        looked up than regular inputs.
    inputs : `~collections.abc.Mapping` [ `str`, `ReadEdge` ]
        Graph edges that represent regular runtime inputs to this task, keyed
        by connection name.
    outputs : `~collections.abc.Mapping` [ `str`, `WriteEdge` ]
        Graph edges that represent regular runtime outputs of this task, keyed
        by connection name.

        This does not include the special `log_output` and `metadata_output`
        edges; use `iter_all_outputs` to include that, too.
    log_output : `WriteEdge` or `None`
        The special runtime output that persists the task's logs.
    metadata_output : `WriteEdge`
        The special runtime output that persists the task's metadata.
    dimensions : `lsst.daf.butler.DimensionGroup` or `frozenset` [ `str` ]
        Dimensions of the task.  If a `frozenset`, the dimensions have not
        been resolved by a `~lsst.daf.butler.DimensionUniverse` and cannot be
        safely compared to other sets of dimensions.

    Notes
    -----
    Task nodes are intentionally not equality comparable, since there are many
    different (and useful) ways to compare these objects with no clear winner
    as the most obvious behavior.

    When included in an exported `networkx` graph (e.g.
    `PipelineGraph.make_xgraph`), task nodes set the following node
    attributes:

    - ``task_class_name``
    - ``bipartite`` (see `NodeType.bipartite`)
    - ``task_class`` (only if `is_imported` is `True`)
    - ``config`` (only if `is_imported` is `True`)
    """

    def __init__(
        self,
        key: NodeKey,
        init: TaskInitNode,
        *,
        prerequisite_inputs: Mapping[str, ReadEdge],
        inputs: Mapping[str, ReadEdge],
        outputs: Mapping[str, WriteEdge],
        log_output: WriteEdge | None,
        metadata_output: WriteEdge,
        dimensions: DimensionGroup | frozenset[str],
    ):
        self.key = key
        self.init = init
        self.prerequisite_inputs = prerequisite_inputs
        self.inputs = inputs
        self.outputs = outputs
        self.log_output = log_output
        self.metadata_output = metadata_output
        self._dimensions = dimensions

    @staticmethod
    def _from_imported_data(
        key: NodeKey,
        init_key: NodeKey,
        data: _TaskNodeImportedData,
        universe: DimensionUniverse | None,
    ) -> TaskNode:
        """Construct from a `PipelineTask` type and its configuration.

        Parameters
        ----------
        key : `NodeKey`
            Identifier for this node in networkx graphs.
        init_key : `NodeKey`
            Identifier for the node representing the initialization of this
            task in networkx graphs.
        data : `_TaskNodeImportedData`
            Internal struct that holds information that requires the task
            class to have been imported.
        universe : `lsst.daf.butler.DimensionUniverse` or `None`
            Definitions of all dimensions.

        Returns
        -------
        node : `TaskNode`
            New task node.

        Raises
        ------
        ValueError
            Raised if configuration validation failed when constructing
            ``connections``.
        """
        init_inputs = {
            name: ReadEdge._from_connection_map(init_key, name, data.connection_map)
            for name in data.connections.initInputs
        }
        prerequisite_inputs = {
            name: ReadEdge._from_connection_map(key, name, data.connection_map, is_prerequisite=True)
            for name in data.connections.prerequisiteInputs
        }
        inputs = {
            name: ReadEdge._from_connection_map(key, name, data.connection_map)
            for name in data.connections.inputs
            if not getattr(data.connections, name).deferBinding
        }
        init_outputs = {
            name: WriteEdge._from_connection_map(init_key, name, data.connection_map)
            for name in data.connections.initOutputs
        }
        outputs = {
            name: WriteEdge._from_connection_map(key, name, data.connection_map)
            for name in data.connections.outputs
        }
        init = TaskInitNode(
            key=init_key,
            inputs=init_inputs,
            outputs=init_outputs,
            config_output=WriteEdge._from_connection_map(
                init_key, acc.CONFIG_INIT_OUTPUT_CONNECTION_NAME, data.connection_map
            ),
            imported_data=data,
        )
        instance = TaskNode(
            key=key,
            init=init,
            prerequisite_inputs=prerequisite_inputs,
            inputs=inputs,
            outputs=outputs,
            log_output=(
                WriteEdge._from_connection_map(key, acc.LOG_OUTPUT_CONNECTION_NAME, data.connection_map)
                if data.config.saveLogOutput
                else None
            ),
            metadata_output=WriteEdge._from_connection_map(
                key, acc.METADATA_OUTPUT_CONNECTION_NAME, data.connection_map
            ),
            dimensions=(
                frozenset(data.connections.dimensions)
                if universe is None
                else universe.conform(data.connections.dimensions)
            ),
        )
        return instance

    key: NodeKey
    """Key that identifies this node in internal and exported networkx graphs.
    """

    # Annotation added for completeness; set in __init__ like the attributes
    # documented below.
    init: TaskInitNode

    prerequisite_inputs: Mapping[str, ReadEdge]
    """Graph edges that represent prerequisite inputs to this task.

    Prerequisite inputs must already exist in the data repository when a
    `QuantumGraph` is built, but have more flexibility in how they are looked
    up than regular inputs.
    """

    inputs: Mapping[str, ReadEdge]
    """Graph edges that represent regular runtime inputs to this task.
    """

    outputs: Mapping[str, WriteEdge]
    """Graph edges that represent regular runtime outputs of this task.

    This does not include the special `log_output` and `metadata_output`
    edges; use `iter_all_outputs` to include that, too.
    """

    log_output: WriteEdge | None
    """The special runtime output that persists the task's logs.
    """

    metadata_output: WriteEdge
    """The special runtime output that persists the task's metadata.
    """

    @property
    def label(self) -> str:
        """Label of this configuration of a task in the pipeline."""
        return self.key.name

    @property
    def is_imported(self) -> bool:
        """Whether the task type for this node has been imported and its
        configuration overrides applied.

        If this is `False`, the `task_class` and `config` attributes may not
        be accessed.
        """
        return self.init.is_imported

    @property
    def task_class(self) -> type[PipelineTask]:
        """Type object for the task.

        Accessing this attribute when `is_imported` is `False` will raise
        `TaskNotImportedError`, but accessing `task_class_name` will not.
        """
        return self.init.task_class

    @property
    def task_class_name(self) -> str:
        """The fully-qualified string name of the task class."""
        return self.init.task_class_name

    @property
    def config(self) -> PipelineTaskConfig:
        """Configuration for the task.

        This is always frozen.

        Accessing this attribute when `is_imported` is `False` will raise
        `TaskNotImportedError`, but calling `get_config_str` will not.
        """
        return self.init.config

    @property
    def has_resolved_dimensions(self) -> bool:
        """Whether the `dimensions` attribute may be accessed.

        If `False`, the `raw_dimensions` attribute may be used to obtain a
        set of dimension names that has not been resolved by a
        `~lsst.daf.butler.DimensionUniverse`.
        """
        return type(self._dimensions) is DimensionGroup

    @property
    def dimensions(self) -> DimensionGroup:
        """Standardized dimensions of the task."""
        if not self.has_resolved_dimensions:
            raise UnresolvedGraphError(f"Dimensions for task {self.label!r} have not been resolved.")
        return cast(DimensionGroup, self._dimensions)

    @property
    def raw_dimensions(self) -> frozenset[str]:
        """Raw dimensions of the task, with standardization by a
        `~lsst.daf.butler.DimensionUniverse` not guaranteed.
        """
        if self.has_resolved_dimensions:
            return frozenset(cast(DimensionGroup, self._dimensions).names)
        else:
            return cast(frozenset[str], self._dimensions)

    def __repr__(self) -> str:
        if self.has_resolved_dimensions:
            return f"{self.label} ({self.task_class_name}, {self.dimensions})"
        else:
            return f"{self.label} ({self.task_class_name})"

    def get_config_str(self) -> str:
        """Return the configuration for this task as a string of override
        statements.

        Returns
        -------
        config_str : `str`
            String containing configuration-overload statements.
        """
        return self.init.get_config_str()

    def iter_all_inputs(self) -> Iterator[ReadEdge]:
        """Iterate over all runtime inputs, including both regular inputs and
        prerequisites.

        Yields
        ------
        `ReadEdge`
            All the runtime inputs.
        """
        yield from self.prerequisite_inputs.values()
        yield from self.inputs.values()

    def iter_all_outputs(self) -> Iterator[WriteEdge]:
        """Iterate over all runtime outputs, including special ones.

        Yields
        ------
        `WriteEdge`
            All the runtime outputs.
        """
        yield from self.outputs.values()
        yield self.metadata_output
        if self.log_output is not None:
            yield self.log_output

    def diff_edges(self, other: TaskNode) -> list[str]:
        """Compare the edges of this task node to those from the same task
        label in a different pipeline.

        This also calls `TaskInitNode.diff_edges`.

        Parameters
        ----------
        other : `TaskNode`
            Other node to compare to.  Must have the same task label, but need
            not have the same configuration or even the same task class.

        Returns
        -------
        differences : `list` [ `str` ]
            List of string messages describing differences between ``self``
            and ``other``.  Will be empty if the two nodes have the same
            edges.  Messages will use 'A' to refer to ``self`` and 'B' to
            refer to ``other``.
        """
        result = self.init.diff_edges(other.init)
        result += _diff_edge_mapping(
            self.prerequisite_inputs, other.prerequisite_inputs, self.label, "prerequisite input"
        )
        result += _diff_edge_mapping(self.inputs, other.inputs, self.label, "input")
        result += _diff_edge_mapping(self.outputs, other.outputs, self.label, "output")
        if self.log_output is not None:
            if other.log_output is not None:
                result += self.log_output.diff(other.log_output, "log output")
            else:
                result.append("Log output is present in A, but not in B.")
        elif other.log_output is not None:
            result.append("Log output is present in B, but not in A.")
        result += self.metadata_output.diff(other.metadata_output, "metadata output")
        return result

    def get_lookup_function(
        self, connection_name: str
    ) -> Callable[[DatasetType, Registry, DataCoordinate, Sequence[str]], Iterable[DatasetRef]] | None:
        """Return the custom dataset query function for an edge, if one
        exists.

        Parameters
        ----------
        connection_name : `str`
            Name of the connection.

        Returns
        -------
        lookup_function : `~collections.abc.Callable` or `None`
            Callable that takes a dataset type, a butler registry, a data
            coordinate (the quantum data ID), and an ordered list of
            collections to search, and returns an iterable of
            `~lsst.daf.butler.DatasetRef`.
        """
        return getattr(self._get_imported_data().connection_map[connection_name], "lookupFunction", None)

    def get_connections(self) -> PipelineTaskConnections:
        """Return the connections class instance for this task.

        Returns
        -------
        connections : `.PipelineTaskConnections`
            Task-provided object that defines inputs and outputs from
            configuration.
        """
        return self._get_imported_data().connections

    def get_spatial_bounds_connections(self) -> frozenset[str]:
        """Return the names of connections whose data IDs should be included
        in the calculation of the spatial bounds for this task's quanta.

        Returns
        -------
        connection_names : `frozenset` [ `str` ]
            Names of connections with spatial dimensions.
        """
        return frozenset(self._get_imported_data().connections.getSpatialBoundsConnections())

    def get_temporal_bounds_connections(self) -> frozenset[str]:
        """Return the names of connections whose data IDs should be included
        in the calculation of the temporal bounds for this task's quanta.

        Returns
        -------
        connection_names : `frozenset` [ `str` ]
            Names of connections with temporal dimensions.
        """
        return frozenset(self._get_imported_data().connections.getTemporalBoundsConnections())

    def _imported_and_configured(self, rebuild: bool) -> TaskNode:
        """Import the task class and use it to construct a new instance.

        Parameters
        ----------
        rebuild : `bool`
            If `True`, import the task class and configure its connections to
            generate new edges that may differ from the current ones.  If
            `False`, import the task class but just update the `task_class`
            and `config` attributes, and assume the edges have not changed.

        Returns
        -------
        node : `TaskNode`
            Task node instance for which `is_imported` is `True`.  Will be
            ``self`` if this is the case already.
        """
        from ..pipelineTask import PipelineTask

        if self.is_imported:
            return self
        task_class = doImportType(self.task_class_name)
        if not issubclass(task_class, PipelineTask):
            raise TypeError(f"{self.task_class_name!r} is not a PipelineTask subclass.")
        config = task_class.ConfigClass()
        config.loadFromString(self.get_config_str())
        return self._reconfigured(config, rebuild=rebuild, task_class=task_class)

    def _reconfigured(
        self,
        config: PipelineTaskConfig,
        rebuild: bool,
        task_class: type[PipelineTask] | None = None,
    ) -> TaskNode:
        """Return a version of this node with new configuration.

        Parameters
        ----------
        config : `.PipelineTaskConfig`
            New configuration for the task.
        rebuild : `bool`
            If `True`, use the configured connections to generate new edges
            that may differ from the current ones.  If `False`, just update
            the `task_class` and `config` attributes, and assume the edges
            have not changed.
        task_class : `type` [ `PipelineTask` ], optional
            Subclass of `PipelineTask`.  This defaults to ``self.task_class``,
            but may be passed as an argument if that is not available because
            the task class was not imported when ``self`` was constructed.

        Returns
        -------
        node : `TaskNode`
            Task node instance with the new config.
        """
        if task_class is None:
            task_class = self.task_class
        imported_data = _TaskNodeImportedData.configure(self.key.name, task_class, config)
        if rebuild:
            return self._from_imported_data(
                self.key,
                self.init.key,
                imported_data,
                universe=self._dimensions.universe if type(self._dimensions) is DimensionGroup else None,
            )
        else:
            return TaskNode(
                self.key,
                TaskInitNode(
                    self.init.key,
                    inputs=self.init.inputs,
                    outputs=self.init.outputs,
                    config_output=self.init.config_output,
                    imported_data=imported_data,
                ),
                prerequisite_inputs=self.prerequisite_inputs,
                inputs=self.inputs,
                outputs=self.outputs,
                log_output=self.log_output,
                metadata_output=self.metadata_output,
                dimensions=self._dimensions,
            )

    def _resolved(self, universe: DimensionUniverse | None) -> TaskNode:
        """Return an otherwise-equivalent task node with resolved dimensions.

        Parameters
        ----------
        universe : `lsst.daf.butler.DimensionUniverse` or `None`
            Definitions for all dimensions.

        Returns
        -------
        node : `TaskNode`
            Task node instance with `dimensions` resolved by the given
            universe.  Will be ``self`` if this is the case already.
        """
        if self.has_resolved_dimensions:
            if cast(DimensionGroup, self._dimensions).universe is universe:
                return self
        elif universe is None:
            return self
        return TaskNode(
            key=self.key,
            init=self.init,
            prerequisite_inputs=self.prerequisite_inputs,
            inputs=self.inputs,
            outputs=self.outputs,
            log_output=self.log_output,
            metadata_output=self.metadata_output,
            dimensions=(
                universe.conform(self.raw_dimensions) if universe is not None else self.raw_dimensions
            ),
        )

    def _to_xgraph_state(self) -> dict[str, Any]:
        """Convert this node's attributes into a dictionary suitable for use
        in exported networkx graphs.
        """
        result = self.init._to_xgraph_state()
        if self.has_resolved_dimensions:
            result["dimensions"] = self._dimensions
            result["raw_dimensions"] = self.raw_dimensions
        return result

    def _get_imported_data(self) -> _TaskNodeImportedData:
        """Return the imported data struct.

        Returns
        -------
        imported_data : `_TaskNodeImportedData`
            Internal structure holding state that requires the task class to
            have been imported.

        Raises
        ------
        TaskNotImportedError
            Raised if `is_imported` is `False`.
        """
        return self.init._get_imported_data()

    @staticmethod
    def _unreduce(kwargs: dict[str, Any]) -> TaskNode:
        """Unpickle a `TaskNode` instance."""
        return TaskNode(**kwargs)

    def __reduce__(self) -> tuple[Callable[[dict[str, Any]], TaskNode], tuple[dict[str, Any]]]:
        return (
            self._unreduce,
            (
                dict(
                    key=self.key,
                    init=self.init,
                    prerequisite_inputs=self.prerequisite_inputs,
                    inputs=self.inputs,
                    outputs=self.outputs,
                    log_output=self.log_output,
                    metadata_output=self.metadata_output,
                    dimensions=self._dimensions,
                ),
            ),
        )
998def _diff_edge_mapping(
999 a_mapping: Mapping[str, Edge], b_mapping: Mapping[str, Edge], task_label: str, connection_type: str
1000) -> list[str]:
1001 """Compare a pair of mappings of edges.
1003 Parameters
1004 ----------
1005 a_mapping : `~collections.abc.Mapping` [ `str`, `Edge` ]
1006 First mapping to compare. Expected to have connection names as keys.
1007 b_mapping : `~collections.abc.Mapping` [ `str`, `Edge` ]
1008 First mapping to compare. If keys differ from those of ``a_mapping``,
1009 this will be reported as a difference (in addition to element-wise
1010 comparisons).
1011 task_label : `str`
1012 Task label associated with both mappings.
1013 connection_type : `str`
1014 Type of connection (e.g. "input" or "init output") associated with both
1015 connections. This is a human-readable string to include in difference
1016 messages.
1018 Returns
1019 -------
1020 differences : `list` [ `str` ]
1021 List of string messages describing differences between the two
1022 mappings. Will be empty if the two mappings have the same edges.
1023 Messages will include "A" and "B", and are expected to be a preceded
1024 by a message describing what "A" and "B" are in the context in which
1025 this method is called.
1027 Notes
1028 -----
1029 This is expected to be used to compare one edge-holding mapping attribute
1030 of a task or task init node to the same attribute on another task or task
1031 init node (i.e. any of `TaskNode.inputs`, `TaskNode.outputs`,
1032 `TaskNode.prerequisite_inputs`, `TaskInitNode.inputs`,
1033 `TaskInitNode.outputs`).
1034 """
1035 results = []
1036 b_to_do = set(b_mapping.keys())
1037 for connection_name, a_edge in a_mapping.items():
1038 if (b_edge := b_mapping.get(connection_name)) is None:
1039 results.append(
1040 f"{connection_type.capitalize()} {connection_name!r} of task "
1041 f"{task_label!r} exists in A, but not in B (or it may have a different connection type)."
1042 )
1043 else:
1044 results.extend(a_edge.diff(b_edge, connection_type))
1045 b_to_do.discard(connection_name)
1046 for connection_name in b_to_do:
1047 results.append(
1048 f"{connection_type.capitalize()} {connection_name!r} of task "
1049 f"{task_label!r} exists in A, but not in B (or it may have a different connection type)."
1050 )
1051 return results