Coverage for python/lsst/pipe/base/pipeline_graph/_tasks.py: 42%
260 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-03-30 10:01 +0000
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
27from __future__ import annotations
29__all__ = ("TaskNode", "TaskInitNode", "TaskImportMode")
31import dataclasses
32import enum
33from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
34from typing import TYPE_CHECKING, Any, cast
36from lsst.daf.butler import (
37 DataCoordinate,
38 DatasetRef,
39 DatasetType,
40 DimensionGroup,
41 DimensionUniverse,
42 Registry,
43)
44from lsst.utils.classes import immutable
45from lsst.utils.doImport import doImportType
46from lsst.utils.introspection import get_full_type_name
48from .. import automatic_connection_constants as acc
49from ..connections import PipelineTaskConnections
50from ..connectionTypes import BaseConnection, InitOutput, Output
51from ._edges import Edge, ReadEdge, WriteEdge
52from ._exceptions import TaskNotImportedError, UnresolvedGraphError
53from ._nodes import NodeKey, NodeType
55if TYPE_CHECKING:
56 from ..config import PipelineTaskConfig
57 from ..pipelineTask import PipelineTask
class TaskImportMode(enum.Enum):
    """Enumeration of the ways to handle importing tasks when reading a
    serialized PipelineGraph.
    """

    DO_NOT_IMPORT = 1
    """Leave tasks unimported; do not instantiate their configs or
    connections."""

    REQUIRE_CONSISTENT_EDGES = 2
    """Import tasks and instantiate their config and connection objects, and
    check that the connections still define the same edges.
    """

    ASSUME_CONSISTENT_EDGES = 3
    """Import tasks and instantiate their config and connection objects, but do
    not check that the connections still define the same edges.

    This is safe only when the caller knows the task definition has not changed
    since the pipeline graph was persisted, such as when it was saved and
    loaded with the same pipeline version.
    """

    OVERRIDE_EDGES = 4
    """Import tasks and instantiate their config and connection objects, and
    allow the edges defined in those connections to override those in the
    persisted graph.

    This may cause dataset type nodes to be unresolved, since resolutions
    consistent with the original edges may be invalidated.
    """
@dataclasses.dataclass(frozen=True)
class _TaskNodeImportedData:
    """An internal struct that holds `TaskNode` and `TaskInitNode` state that
    requires task classes to be imported.
    """

    task_class: type[PipelineTask]
    """Type object for the task."""

    config: PipelineTaskConfig
    """Configuration object for the task."""

    connection_map: dict[str, BaseConnection]
    """Mapping from connection name to connection.

    In addition to ``connections.allConnections``, this also holds the
    "automatic" config, log, and metadata connections using the names defined
    in the `.automatic_connection_constants` module.
    """

    connections: PipelineTaskConnections
    """Configured connections object for the task."""

    @classmethod
    def configure(
        cls,
        label: str,
        task_class: type[PipelineTask],
        config: PipelineTaskConfig,
        connections: PipelineTaskConnections | None = None,
    ) -> _TaskNodeImportedData:
        """Construct while creating a `PipelineTaskConnections` instance if
        necessary.

        Parameters
        ----------
        label : `str`
            Label for the task in the pipeline.  Only used in error messages.
        task_class : `type` [ `.PipelineTask` ]
            Pipeline task `type` object.
        config : `.PipelineTaskConfig`
            Configuration for the task.
        connections : `.PipelineTaskConnections`, optional
            Object that describes the dataset types used by the task.  If not
            provided, one will be constructed from the given configuration.
            If provided, it is assumed that ``config`` has already been
            validated and frozen.

        Returns
        -------
        data : `_TaskNodeImportedData`
            Instance of this struct.
        """
        if connections is None:
            # A missing connections object implies the config has not yet
            # been validated or frozen; do both before building one.
            try:
                config.validate()
            except Exception as err:
                raise ValueError(
                    f"Configuration validation failed for task {label!r} (see chained exception)."
                ) from err
            config.freeze()
            connections = task_class.ConfigClass.ConnectionsClass(config=config)
        # The "automatic" connections (config init-output, metadata, and
        # optionally log) are appended after the task-declared ones.
        automatic: dict[str, BaseConnection] = {
            acc.CONFIG_INIT_OUTPUT_CONNECTION_NAME: InitOutput(
                acc.CONFIG_INIT_OUTPUT_TEMPLATE.format(label=label),
                acc.CONFIG_INIT_OUTPUT_STORAGE_CLASS,
            ),
            acc.METADATA_OUTPUT_CONNECTION_NAME: Output(
                acc.METADATA_OUTPUT_TEMPLATE.format(label=label),
                acc.METADATA_OUTPUT_STORAGE_CLASS,
                dimensions=set(connections.dimensions),
            ),
        }
        if config.saveLogOutput:
            automatic[acc.LOG_OUTPUT_CONNECTION_NAME] = Output(
                acc.LOG_OUTPUT_TEMPLATE.format(label=label),
                acc.LOG_OUTPUT_STORAGE_CLASS,
                dimensions=set(connections.dimensions),
            )
        connection_map = dict(connections.allConnections)
        connection_map.update(automatic)
        return cls(task_class, config, connection_map, connections)
@immutable
class TaskInitNode:
    """A node in a pipeline graph that represents the construction of a
    `PipelineTask`.

    Parameters
    ----------
    key : `NodeKey`
        Key that identifies this node in internal and exported networkx graphs.
    inputs : `~collections.abc.Mapping` [ `str`, `ReadEdge` ]
        Graph edges that represent inputs required just to construct an
        instance of this task, keyed by connection name.
    outputs : `~collections.abc.Mapping` [ `str`, `WriteEdge` ]
        Graph edges that represent outputs of this task that are available
        after just constructing it, keyed by connection name.

        This does not include the special `config_output` edge; use
        `iter_all_outputs` to include that, too.
    config_output : `WriteEdge`
        The special init output edge that persists the task's configuration.
    imported_data : `_TaskNodeImportedData`, optional
        Internal struct that holds information that requires the task class to
        have been imported.
    task_class_name : `str`, optional
        Fully-qualified name of the task class.  Must be provided if
        ``imported_data`` is not.
    config_str : `str`, optional
        Configuration for the task as a string of override statements.  Must
        be provided if ``imported_data`` is not.

    Notes
    -----
    When included in an exported `networkx` graph (e.g.
    `PipelineGraph.make_xgraph`), task initialization nodes set the following
    node attributes:

    - ``task_class_name``
    - ``bipartite`` (see `NodeType.bipartite`)
    - ``task_class`` (only if `is_imported` is `True`)
    - ``config`` (only if `is_imported` is `True`)
    """

    def __init__(
        self,
        key: NodeKey,
        *,
        inputs: Mapping[str, ReadEdge],
        outputs: Mapping[str, WriteEdge],
        config_output: WriteEdge,
        imported_data: _TaskNodeImportedData | None = None,
        task_class_name: str | None = None,
        config_str: str | None = None,
    ):
        self.key = key
        self.inputs = inputs
        self.outputs = outputs
        self.config_output = config_output
        # Instead of setting attributes to None, we do not set them at all;
        # this works better with the @immutable decorator, which supports
        # deferred initialization but not reassignment.
        if task_class_name is not None:
            self._task_class_name = task_class_name
        if config_str is not None:
            self._config_str = config_str
        if imported_data is not None:
            self._imported_data = imported_data
        else:
            assert (
                self._task_class_name is not None and self._config_str is not None
            ), "If imported_data is not present, task_class_name and config_str must be."

    key: NodeKey
    """Key that identifies this node in internal and exported networkx graphs.
    """

    inputs: Mapping[str, ReadEdge]
    """Graph edges that represent inputs required just to construct an instance
    of this task, keyed by connection name.
    """

    outputs: Mapping[str, WriteEdge]
    """Graph edges that represent outputs of this task that are available after
    just constructing it, keyed by connection name.

    This does not include the special `config_output` edge; use
    `iter_all_outputs` to include that, too.
    """

    config_output: WriteEdge
    """The special output edge that persists the task's configuration.
    """

    @property
    def label(self) -> str:
        """Label of this configuration of a task in the pipeline."""
        return str(self.key)

    @property
    def is_imported(self) -> bool:
        """Whether the task type for this node has been imported and
        its configuration overrides applied.

        If this is `False`, the `task_class` and `config` attributes may not
        be accessed.
        """
        return hasattr(self, "_imported_data")

    @property
    def task_class(self) -> type[PipelineTask]:
        """Type object for the task.

        Accessing this attribute when `is_imported` is `False` will raise
        `TaskNotImportedError`, but accessing `task_class_name` will not.
        """
        return self._get_imported_data().task_class

    @property
    def task_class_name(self) -> str:
        """The fully-qualified string name of the task class."""
        try:
            return self._task_class_name
        except AttributeError:
            pass
        # Not set at construction; derive it from the imported task class and
        # cache it (deferred initialization permitted by @immutable).
        self._task_class_name = get_full_type_name(self.task_class)
        return self._task_class_name

    @property
    def config(self) -> PipelineTaskConfig:
        """Configuration for the task.

        This is always frozen.

        Accessing this attribute when `is_imported` is `False` will raise
        `TaskNotImportedError`, but calling `get_config_str` will not.
        """
        return self._get_imported_data().config

    def __repr__(self) -> str:
        return f"{self.label} [init] ({self.task_class_name})"

    def get_config_str(self) -> str:
        """Return the configuration for this task as a string of override
        statements.

        Returns
        -------
        config_str : `str`
            String containing configuration-overload statements.
        """
        try:
            return self._config_str
        except AttributeError:
            pass
        # Not set at construction; serialize the config and cache the result.
        self._config_str = self.config.saveToString()
        return self._config_str

    def iter_all_inputs(self) -> Iterator[ReadEdge]:
        """Iterate over all inputs required for construction.

        This is the same as iteration over ``inputs.values()``, but it will be
        updated to include any automatic init-input connections added in the
        future, while `inputs` will continue to hold only task-defined init
        inputs.

        Yields
        ------
        `ReadEdge`
            All the inputs required for construction.
        """
        return iter(self.inputs.values())

    def iter_all_outputs(self) -> Iterator[WriteEdge]:
        """Iterate over all outputs available after construction, including
        special ones.

        Yields
        ------
        `WriteEdge`
            All the outputs available after construction.
        """
        yield from self.outputs.values()
        yield self.config_output

    def diff_edges(self, other: TaskInitNode) -> list[str]:
        """Compare the edges of this task initialization node to those from the
        same task label in a different pipeline.

        Parameters
        ----------
        other : `TaskInitNode`
            Other node to compare to.  Must have the same task label, but need
            not have the same configuration or even the same task class.

        Returns
        -------
        differences : `list` [ `str` ]
            List of string messages describing differences between ``self``
            and ``other``.  Will be empty if the two nodes have the same edges.
            Messages will use 'A' to refer to ``self`` and 'B' to refer to
            ``other``.
        """
        result = []
        result += _diff_edge_mapping(self.inputs, other.inputs, self.label, "init input")
        result += _diff_edge_mapping(self.outputs, other.outputs, self.label, "init output")
        result += self.config_output.diff(other.config_output, "config init output")
        return result

    def _to_xgraph_state(self) -> dict[str, Any]:
        """Convert this node's attributes into a dictionary suitable for use
        in exported networkx graphs.
        """
        result = {"task_class_name": self.task_class_name, "bipartite": NodeType.TASK_INIT.bipartite}
        if hasattr(self, "_imported_data"):
            result["task_class"] = self.task_class
            result["config"] = self.config
        return result

    def _get_imported_data(self) -> _TaskNodeImportedData:
        """Return the imported data struct.

        Returns
        -------
        imported_data : `_TaskNodeImportedData`
            Internal structure holding state that requires the task class to
            have been imported.

        Raises
        ------
        TaskNotImportedError
            Raised if `is_imported` is `False`.
        """
        try:
            return self._imported_data
        except AttributeError:
            raise TaskNotImportedError(
                f"Task class {self.task_class_name!r} for label {self.label!r} has not been imported "
                "(see PipelineGraph.import_and_configure)."
            ) from None
@immutable
class TaskNode:
    """A node in a pipeline graph that represents a labeled configuration of a
    `PipelineTask`.

    Parameters
    ----------
    key : `NodeKey`
        Identifier for this node in networkx graphs.
    init : `TaskInitNode`
        Node representing the initialization of this task.
    prerequisite_inputs : `~collections.abc.Mapping` [ `str`, `ReadEdge` ]
        Graph edges that represent prerequisite inputs to this task, keyed by
        connection name.

        Prerequisite inputs must already exist in the data repository when a
        `QuantumGraph` is built, but have more flexibility in how they are
        looked up than regular inputs.
    inputs : `~collections.abc.Mapping` [ `str`, `ReadEdge` ]
        Graph edges that represent regular runtime inputs to this task, keyed
        by connection name.
    outputs : `~collections.abc.Mapping` [ `str`, `WriteEdge` ]
        Graph edges that represent regular runtime outputs of this task, keyed
        by connection name.

        This does not include the special `log_output` and `metadata_output`
        edges; use `iter_all_outputs` to include those, too.
    log_output : `WriteEdge` or `None`
        The special runtime output that persists the task's logs.
    metadata_output : `WriteEdge`
        The special runtime output that persists the task's metadata.
    dimensions : `lsst.daf.butler.DimensionGroup` or `frozenset`
        Dimensions of the task.  If a `frozenset`, the dimensions have not been
        resolved by a `~lsst.daf.butler.DimensionUniverse` and cannot be safely
        compared to other sets of dimensions.

    Notes
    -----
    Task nodes are intentionally not equality comparable, since there are many
    different (and useful) ways to compare these objects with no clear winner
    as the most obvious behavior.

    When included in an exported `networkx` graph (e.g.
    `PipelineGraph.make_xgraph`), task nodes set the following node attributes:

    - ``task_class_name``
    - ``bipartite`` (see `NodeType.bipartite`)
    - ``task_class`` (only if `is_imported` is `True`)
    - ``config`` (only if `is_imported` is `True`)
    """

    def __init__(
        self,
        key: NodeKey,
        init: TaskInitNode,
        *,
        prerequisite_inputs: Mapping[str, ReadEdge],
        inputs: Mapping[str, ReadEdge],
        outputs: Mapping[str, WriteEdge],
        log_output: WriteEdge | None,
        metadata_output: WriteEdge,
        dimensions: DimensionGroup | frozenset,
    ):
        self.key = key
        self.init = init
        self.prerequisite_inputs = prerequisite_inputs
        self.inputs = inputs
        self.outputs = outputs
        self.log_output = log_output
        self.metadata_output = metadata_output
        self._dimensions = dimensions

    @staticmethod
    def _from_imported_data(
        key: NodeKey,
        init_key: NodeKey,
        data: _TaskNodeImportedData,
        universe: DimensionUniverse | None,
    ) -> TaskNode:
        """Construct from a `PipelineTask` type and its configuration.

        Parameters
        ----------
        key : `NodeKey`
            Identifier for this node in networkx graphs.
        init_key : `NodeKey`
            Identifier for the node representing the initialization of this
            task in networkx graphs.
        data : `_TaskNodeImportedData`
            Internal struct that holds information that requires the task class
            to have been imported.
        universe : `lsst.daf.butler.DimensionUniverse` or `None`
            Definitions of all dimensions.

        Returns
        -------
        node : `TaskNode`
            New task node.

        Raises
        ------
        ValueError
            Raised if configuration validation failed when constructing
            ``connections``.
        """
        init_inputs = {
            name: ReadEdge._from_connection_map(init_key, name, data.connection_map)
            for name in data.connections.initInputs
        }
        prerequisite_inputs = {
            name: ReadEdge._from_connection_map(key, name, data.connection_map, is_prerequisite=True)
            for name in data.connections.prerequisiteInputs
        }
        # Connections with deferBinding set are deliberately excluded from the
        # graph's regular input edges.
        inputs = {
            name: ReadEdge._from_connection_map(key, name, data.connection_map)
            for name in data.connections.inputs
            if not getattr(data.connections, name).deferBinding
        }
        init_outputs = {
            name: WriteEdge._from_connection_map(init_key, name, data.connection_map)
            for name in data.connections.initOutputs
        }
        outputs = {
            name: WriteEdge._from_connection_map(key, name, data.connection_map)
            for name in data.connections.outputs
        }
        init = TaskInitNode(
            key=init_key,
            inputs=init_inputs,
            outputs=init_outputs,
            config_output=WriteEdge._from_connection_map(
                init_key, acc.CONFIG_INIT_OUTPUT_CONNECTION_NAME, data.connection_map
            ),
            imported_data=data,
        )
        instance = TaskNode(
            key=key,
            init=init,
            prerequisite_inputs=prerequisite_inputs,
            inputs=inputs,
            outputs=outputs,
            log_output=(
                WriteEdge._from_connection_map(key, acc.LOG_OUTPUT_CONNECTION_NAME, data.connection_map)
                if data.config.saveLogOutput
                else None
            ),
            metadata_output=WriteEdge._from_connection_map(
                key, acc.METADATA_OUTPUT_CONNECTION_NAME, data.connection_map
            ),
            dimensions=(
                frozenset(data.connections.dimensions)
                if universe is None
                else universe.conform(data.connections.dimensions)
            ),
        )
        return instance

    key: NodeKey
    """Key that identifies this node in internal and exported networkx graphs.
    """

    prerequisite_inputs: Mapping[str, ReadEdge]
    """Graph edges that represent prerequisite inputs to this task.

    Prerequisite inputs must already exist in the data repository when a
    `QuantumGraph` is built, but have more flexibility in how they are looked
    up than regular inputs.
    """

    inputs: Mapping[str, ReadEdge]
    """Graph edges that represent regular runtime inputs to this task.
    """

    outputs: Mapping[str, WriteEdge]
    """Graph edges that represent regular runtime outputs of this task.

    This does not include the special `log_output` and `metadata_output` edges;
    use `iter_all_outputs` to include that, too.
    """

    log_output: WriteEdge | None
    """The special runtime output that persists the task's logs.
    """

    metadata_output: WriteEdge
    """The special runtime output that persists the task's metadata.
    """

    @property
    def label(self) -> str:
        """Label of this configuration of a task in the pipeline."""
        return self.key.name

    @property
    def is_imported(self) -> bool:
        """Whether the task type for this node has been imported and
        its configuration overrides applied.

        If this is `False`, the `task_class` and `config` attributes may not
        be accessed.
        """
        return self.init.is_imported

    @property
    def task_class(self) -> type[PipelineTask]:
        """Type object for the task.

        Accessing this attribute when `is_imported` is `False` will raise
        `TaskNotImportedError`, but accessing `task_class_name` will not.
        """
        return self.init.task_class

    @property
    def task_class_name(self) -> str:
        """The fully-qualified string name of the task class."""
        return self.init.task_class_name

    @property
    def config(self) -> PipelineTaskConfig:
        """Configuration for the task.

        This is always frozen.

        Accessing this attribute when `is_imported` is `False` will raise
        `TaskNotImportedError`, but calling `get_config_str` will not.
        """
        return self.init.config

    @property
    def has_resolved_dimensions(self) -> bool:
        """Whether the `dimensions` attribute may be accessed.

        If `False`, the `raw_dimensions` attribute may be used to obtain a
        set of dimension names that has not been resolved by a
        `~lsst.daf.butler.DimensionUniverse`.
        """
        return type(self._dimensions) is DimensionGroup

    @property
    def dimensions(self) -> DimensionGroup:
        """Standardized dimensions of the task."""
        if not self.has_resolved_dimensions:
            raise UnresolvedGraphError(f"Dimensions for task {self.label!r} have not been resolved.")
        return cast(DimensionGroup, self._dimensions)

    @property
    def raw_dimensions(self) -> frozenset[str]:
        """Raw dimensions of the task, with standardization by a
        `~lsst.daf.butler.DimensionUniverse` not guaranteed.
        """
        if self.has_resolved_dimensions:
            return frozenset(cast(DimensionGroup, self._dimensions).names)
        else:
            return cast(frozenset[str], self._dimensions)

    def __repr__(self) -> str:
        if self.has_resolved_dimensions:
            return f"{self.label} ({self.task_class_name}, {self.dimensions})"
        else:
            return f"{self.label} ({self.task_class_name})"

    def get_config_str(self) -> str:
        """Return the configuration for this task as a string of override
        statements.

        Returns
        -------
        config_str : `str`
            String containing configuration-overload statements.
        """
        return self.init.get_config_str()

    def iter_all_inputs(self) -> Iterator[ReadEdge]:
        """Iterate over all runtime inputs, including both regular inputs and
        prerequisites.

        Yields
        ------
        `ReadEdge`
            All the runtime inputs.
        """
        yield from self.prerequisite_inputs.values()
        yield from self.inputs.values()

    def iter_all_outputs(self) -> Iterator[WriteEdge]:
        """Iterate over all runtime outputs, including special ones.

        Yields
        ------
        `WriteEdge`
            All the runtime outputs.
        """
        yield from self.outputs.values()
        yield self.metadata_output
        if self.log_output is not None:
            yield self.log_output

    def diff_edges(self, other: TaskNode) -> list[str]:
        """Compare the edges of this task node to those from the same task
        label in a different pipeline.

        This also calls `TaskInitNode.diff_edges`.

        Parameters
        ----------
        other : `TaskNode`
            Other node to compare to.  Must have the same task label, but need
            not have the same configuration or even the same task class.

        Returns
        -------
        differences : `list` [ `str` ]
            List of string messages describing differences between ``self`` and
            ``other``.  Will be empty if the two nodes have the same edges.
            Messages will use 'A' to refer to ``self`` and 'B' to refer to
            ``other``.
        """
        result = self.init.diff_edges(other.init)
        result += _diff_edge_mapping(
            self.prerequisite_inputs, other.prerequisite_inputs, self.label, "prerequisite input"
        )
        result += _diff_edge_mapping(self.inputs, other.inputs, self.label, "input")
        result += _diff_edge_mapping(self.outputs, other.outputs, self.label, "output")
        if self.log_output is not None:
            if other.log_output is not None:
                result += self.log_output.diff(other.log_output, "log output")
            else:
                result.append("Log output is present in A, but not in B.")
        elif other.log_output is not None:
            result.append("Log output is present in B, but not in A.")
        result += self.metadata_output.diff(other.metadata_output, "metadata output")
        return result

    def get_lookup_function(
        self, connection_name: str
    ) -> Callable[[DatasetType, Registry, DataCoordinate, Sequence[str]], Iterable[DatasetRef]] | None:
        """Return the custom dataset query function for an edge, if one exists.

        Parameters
        ----------
        connection_name : `str`
            Name of the connection.

        Returns
        -------
        lookup_function : `~collections.abc.Callable` or `None`
            Callable that takes a dataset type, a butler registry, a data
            coordinate (the quantum data ID), and an ordered list of
            collections to search, and returns an iterable of
            `~lsst.daf.butler.DatasetRef`.
        """
        return getattr(self._get_imported_data().connection_map[connection_name], "lookupFunction", None)

    def get_spatial_bounds_connections(self) -> frozenset[str]:
        """Return the names of connections whose data IDs should be included
        in the calculation of the spatial bounds for this task's quanta.

        Returns
        -------
        connection_names : `frozenset` [ `str` ]
            Names of connections with spatial dimensions.
        """
        return frozenset(self._get_imported_data().connections.getSpatialBoundsConnections())

    def get_temporal_bounds_connections(self) -> frozenset[str]:
        """Return the names of connections whose data IDs should be included
        in the calculation of the temporal bounds for this task's quanta.

        Returns
        -------
        connection_names : `frozenset` [ `str` ]
            Names of connections with temporal dimensions.
        """
        return frozenset(self._get_imported_data().connections.getTemporalBoundsConnections())

    def _imported_and_configured(self, rebuild: bool) -> TaskNode:
        """Import the task class and use it to construct a new instance.

        Parameters
        ----------
        rebuild : `bool`
            If `True`, import the task class and configure its connections to
            generate new edges that may differ from the current ones.  If
            `False`, import the task class but just update the `task_class` and
            `config` attributes, and assume the edges have not changed.

        Returns
        -------
        node : `TaskNode`
            Task node instance for which `is_imported` is `True`.  Will be
            ``self`` if this is the case already.
        """
        # Local import avoids a circular dependency at module import time.
        from ..pipelineTask import PipelineTask

        if self.is_imported:
            return self
        task_class = doImportType(self.task_class_name)
        if not issubclass(task_class, PipelineTask):
            raise TypeError(f"{self.task_class_name!r} is not a PipelineTask subclass.")
        config = task_class.ConfigClass()
        config.loadFromString(self.get_config_str())
        return self._reconfigured(config, rebuild=rebuild, task_class=task_class)

    def _reconfigured(
        self,
        config: PipelineTaskConfig,
        rebuild: bool,
        task_class: type[PipelineTask] | None = None,
    ) -> TaskNode:
        """Return a version of this node with new configuration.

        Parameters
        ----------
        config : `.PipelineTaskConfig`
            New configuration for the task.
        rebuild : `bool`
            If `True`, use the configured connections to generate new edges
            that may differ from the current ones.  If `False`, just update the
            `task_class` and `config` attributes, and assume the edges have not
            changed.
        task_class : `type` [ `PipelineTask` ], optional
            Subclass of `PipelineTask`.  This defaults to ``self.task_class``,
            but may be passed as an argument if that is not available because
            the task class was not imported when ``self`` was constructed.

        Returns
        -------
        node : `TaskNode`
            Task node instance with the new config.
        """
        if task_class is None:
            task_class = self.task_class
        imported_data = _TaskNodeImportedData.configure(self.key.name, task_class, config)
        if rebuild:
            return self._from_imported_data(
                self.key,
                self.init.key,
                imported_data,
                universe=self._dimensions.universe if type(self._dimensions) is DimensionGroup else None,
            )
        else:
            # Reuse the existing edges and dimensions; only the imported data
            # (task class, config, connections) is replaced.
            return TaskNode(
                self.key,
                TaskInitNode(
                    self.init.key,
                    inputs=self.init.inputs,
                    outputs=self.init.outputs,
                    config_output=self.init.config_output,
                    imported_data=imported_data,
                ),
                prerequisite_inputs=self.prerequisite_inputs,
                inputs=self.inputs,
                outputs=self.outputs,
                log_output=self.log_output,
                metadata_output=self.metadata_output,
                dimensions=self._dimensions,
            )

    def _resolved(self, universe: DimensionUniverse | None) -> TaskNode:
        """Return an otherwise-equivalent task node with resolved dimensions.

        Parameters
        ----------
        universe : `lsst.daf.butler.DimensionUniverse` or `None`
            Definitions for all dimensions.

        Returns
        -------
        node : `TaskNode`
            Task node instance with `dimensions` resolved by the given
            universe.  Will be ``self`` if this is the case already.
        """
        if self.has_resolved_dimensions:
            if cast(DimensionGroup, self._dimensions).universe is universe:
                return self
        elif universe is None:
            return self
        return TaskNode(
            key=self.key,
            init=self.init,
            prerequisite_inputs=self.prerequisite_inputs,
            inputs=self.inputs,
            outputs=self.outputs,
            log_output=self.log_output,
            metadata_output=self.metadata_output,
            dimensions=(
                universe.conform(self.raw_dimensions) if universe is not None else self.raw_dimensions
            ),
        )

    def _to_xgraph_state(self) -> dict[str, Any]:
        """Convert this node's attributes into a dictionary suitable for use
        in exported networkx graphs.
        """
        result = self.init._to_xgraph_state()
        if self.has_resolved_dimensions:
            result["dimensions"] = self._dimensions
            result["raw_dimensions"] = self.raw_dimensions
        return result

    def _get_imported_data(self) -> _TaskNodeImportedData:
        """Return the imported data struct.

        Returns
        -------
        imported_data : `_TaskNodeImportedData`
            Internal structure holding state that requires the task class to
            have been imported.

        Raises
        ------
        TaskNotImportedError
            Raised if `is_imported` is `False`.
        """
        return self.init._get_imported_data()
931def _diff_edge_mapping(
932 a_mapping: Mapping[str, Edge], b_mapping: Mapping[str, Edge], task_label: str, connection_type: str
933) -> list[str]:
934 """Compare a pair of mappings of edges.
936 Parameters
937 ----------
938 a_mapping : `~collections.abc.Mapping` [ `str`, `Edge` ]
939 First mapping to compare. Expected to have connection names as keys.
940 b_mapping : `~collections.abc.Mapping` [ `str`, `Edge` ]
941 First mapping to compare. If keys differ from those of ``a_mapping``,
942 this will be reported as a difference (in addition to element-wise
943 comparisons).
944 task_label : `str`
945 Task label associated with both mappings.
946 connection_type : `str`
947 Type of connection (e.g. "input" or "init output") associated with both
948 connections. This is a human-readable string to include in difference
949 messages.
951 Returns
952 -------
953 differences : `list` [ `str` ]
954 List of string messages describing differences between the two
955 mappings. Will be empty if the two mappings have the same edges.
956 Messages will include "A" and "B", and are expected to be a preceded
957 by a message describing what "A" and "B" are in the context in which
958 this method is called.
960 Notes
961 -----
962 This is expected to be used to compare one edge-holding mapping attribute
963 of a task or task init node to the same attribute on another task or task
964 init node (i.e. any of `TaskNode.inputs`, `TaskNode.outputs`,
965 `TaskNode.prerequisite_inputs`, `TaskInitNode.inputs`,
966 `TaskInitNode.outputs`).
967 """
968 results = []
969 b_to_do = set(b_mapping.keys())
970 for connection_name, a_edge in a_mapping.items():
971 if (b_edge := b_mapping.get(connection_name)) is None:
972 results.append(
973 f"{connection_type.capitalize()} {connection_name!r} of task "
974 f"{task_label!r} exists in A, but not in B (or it may have a different connection type)."
975 )
976 else:
977 results.extend(a_edge.diff(b_edge, connection_type))
978 b_to_do.discard(connection_name)
979 for connection_name in b_to_do:
980 results.append(
981 f"{connection_type.capitalize()} {connection_name!r} of task "
982 f"{task_label!r} exists in A, but not in B (or it may have a different connection type)."
983 )
984 return results