Coverage for python/lsst/pipe/base/pipeline_graph/_tasks.py: 42%
260 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-11 17:45 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-11 17:45 +0000
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
27from __future__ import annotations
29__all__ = ("TaskNode", "TaskInitNode", "TaskImportMode")
31import dataclasses
32import enum
33from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
34from typing import TYPE_CHECKING, Any, cast
36from lsst.daf.butler import (
37 DataCoordinate,
38 DatasetRef,
39 DatasetType,
40 DimensionGroup,
41 DimensionUniverse,
42 Registry,
43)
44from lsst.utils.classes import immutable
45from lsst.utils.doImport import doImportType
46from lsst.utils.introspection import get_full_type_name
48from .. import automatic_connection_constants as acc
49from ..connections import PipelineTaskConnections
50from ..connectionTypes import BaseConnection, InitOutput, Output
51from ._edges import Edge, ReadEdge, WriteEdge
52from ._exceptions import TaskNotImportedError, UnresolvedGraphError
53from ._nodes import NodeKey, NodeType
55if TYPE_CHECKING:
56 from ..config import PipelineTaskConfig
57 from ..pipelineTask import PipelineTask
class TaskImportMode(enum.Enum):
    """Enumeration of the ways to handle importing tasks when reading a
    serialized PipelineGraph.
    """

    DO_NOT_IMPORT = enum.auto()
    """Leave tasks unimported; do not instantiate their configs or
    connections."""

    REQUIRE_CONSISTENT_EDGES = enum.auto()
    """Import tasks, instantiate their config and connection objects, and
    verify that those connections still define the same edges as the
    persisted graph.
    """

    ASSUME_CONSISTENT_EDGES = enum.auto()
    """Import tasks and instantiate their config and connection objects,
    trusting (without checking) that the connections still define the same
    edges.

    This is safe only when the caller knows the task definition has not changed
    since the pipeline graph was persisted, such as when it was saved and
    loaded with the same pipeline version.
    """

    OVERRIDE_EDGES = enum.auto()
    """Import tasks, instantiate their config and connection objects, and let
    the edges defined in those connections replace the ones in the persisted
    graph.

    This may cause dataset type nodes to be unresolved, since resolutions
    consistent with the original edges may be invalidated.
    """
@dataclasses.dataclass(frozen=True)
class _TaskNodeImportedData:
    """An internal struct that holds `TaskNode` and `TaskInitNode` state that
    requires task classes to be imported.
    """

    task_class: type[PipelineTask]
    """Type object for the task."""

    config: PipelineTaskConfig
    """Configuration object for the task."""

    connection_map: dict[str, BaseConnection]
    """Mapping from connection name to connection.

    In addition to ``connections.allConnections``, this also holds the
    "automatic" config, log, and metadata connections using the names defined
    in the `.automatic_connection_constants` module.
    """

    connections: PipelineTaskConnections
    """Configured connections object for the task."""

    @classmethod
    def configure(
        cls,
        label: str,
        task_class: type[PipelineTask],
        config: PipelineTaskConfig,
        connections: PipelineTaskConnections | None = None,
    ) -> _TaskNodeImportedData:
        """Construct while creating a `PipelineTaskConnections` instance if
        necessary.

        Parameters
        ----------
        label : `str`
            Label for the task in the pipeline.  Only used in error messages.
        task_class : `type` [ `.PipelineTask` ]
            Pipeline task `type` object.
        config : `.PipelineTaskConfig`
            Configuration for the task.
        connections : `.PipelineTaskConnections`, optional
            Object that describes the dataset types used by the task.  If not
            provided, one will be constructed from the given configuration.  If
            provided, it is assumed that ``config`` has already been validated
            and frozen.

        Returns
        -------
        data : `_TaskNodeImportedData`
            Instance of this struct.
        """
        if connections is None:
            # A missing connections object means the config presumably has
            # not been validated yet; validate and freeze it before building
            # the connections from it.
            try:
                config.validate()
            except Exception as err:
                raise ValueError(
                    f"Configuration validation failed for task {label!r} (see chained exception)."
                ) from err
            config.freeze()
            connections = task_class.ConfigClass.ConnectionsClass(config=config)
        # Start from the task-declared connections, then append the
        # "automatic" config/metadata/log connections.
        all_connections = dict(connections.allConnections)
        all_connections[acc.CONFIG_INIT_OUTPUT_CONNECTION_NAME] = InitOutput(
            acc.CONFIG_INIT_OUTPUT_TEMPLATE.format(label=label),
            acc.CONFIG_INIT_OUTPUT_STORAGE_CLASS,
        )
        all_connections[acc.METADATA_OUTPUT_CONNECTION_NAME] = Output(
            acc.METADATA_OUTPUT_TEMPLATE.format(label=label),
            acc.METADATA_OUTPUT_STORAGE_CLASS,
            dimensions=set(connections.dimensions),
        )
        if config.saveLogOutput:
            all_connections[acc.LOG_OUTPUT_CONNECTION_NAME] = Output(
                acc.LOG_OUTPUT_TEMPLATE.format(label=label),
                acc.LOG_OUTPUT_STORAGE_CLASS,
                dimensions=set(connections.dimensions),
            )
        return cls(
            task_class=task_class,
            config=config,
            connection_map=all_connections,
            connections=connections,
        )
@immutable
class TaskInitNode:
    """A node in a pipeline graph that represents the construction of a
    `PipelineTask`.

    Parameters
    ----------
    key : `NodeKey`
        Key that identifies this node in internal and exported networkx graphs.
    inputs : `~collections.abc.Mapping` [ `str`, `ReadEdge` ]
        Graph edges that represent inputs required just to construct an
        instance of this task, keyed by connection name.
    outputs : `~collections.abc.Mapping` [ `str`, `WriteEdge` ]
        Graph edges that represent outputs of this task that are available
        after just constructing it, keyed by connection name.

        This does not include the special `config_output` edge; use
        `iter_all_outputs` to include that, too.
    config_output : `WriteEdge`
        The special init output edge that persists the task's configuration.
    imported_data : `_TaskNodeImportedData`, optional
        Internal struct that holds information that requires the task class to
        have been imported.
    task_class_name : `str`, optional
        Fully-qualified name of the task class.  Must be provided if
        ``imported_data`` is not.
    config_str : `str`, optional
        Configuration for the task as a string of override statements.  Must be
        provided if ``imported_data`` is not.

    Notes
    -----
    When included in an exported `networkx` graph (e.g.
    `PipelineGraph.make_xgraph`), task initialization nodes set the following
    node attributes:

    - ``task_class_name``
    - ``bipartite`` (see `NodeType.bipartite`)
    - ``task_class`` (only if `is_imported` is `True`)
    - ``config`` (only if `is_imported` is `True`)
    """

    def __init__(
        self,
        key: NodeKey,
        *,
        inputs: Mapping[str, ReadEdge],
        outputs: Mapping[str, WriteEdge],
        config_output: WriteEdge,
        imported_data: _TaskNodeImportedData | None = None,
        task_class_name: str | None = None,
        config_str: str | None = None,
    ):
        self.key = key
        self.inputs = inputs
        self.outputs = outputs
        self.config_output = config_output
        # Instead of setting attributes to None, we do not set them at all;
        # this works better with the @immutable decorator, which supports
        # deferred initialization but not reassignment.
        if task_class_name is not None:
            self._task_class_name = task_class_name
        if config_str is not None:
            self._config_str = config_str
        if imported_data is not None:
            self._imported_data = imported_data
        else:
            # Check the parameters (not the deferred attributes, which may be
            # unset and would raise AttributeError instead of AssertionError).
            assert (
                task_class_name is not None and config_str is not None
            ), "If imported_data is not present, task_class_name and config_str must be."

    key: NodeKey
    """Key that identifies this node in internal and exported networkx graphs.
    """

    inputs: Mapping[str, ReadEdge]
    """Graph edges that represent inputs required just to construct an instance
    of this task, keyed by connection name.
    """

    outputs: Mapping[str, WriteEdge]
    """Graph edges that represent outputs of this task that are available after
    just constructing it, keyed by connection name.

    This does not include the special `config_output` edge; use
    `iter_all_outputs` to include that, too.
    """

    config_output: WriteEdge
    """The special output edge that persists the task's configuration.
    """

    @property
    def label(self) -> str:
        """Label of this configuration of a task in the pipeline."""
        return str(self.key)

    @property
    def is_imported(self) -> bool:
        """Whether the task type for this node has been imported and
        its configuration overrides applied.

        If this is `False`, the `task_class` and `config` attributes may not
        be accessed.
        """
        return hasattr(self, "_imported_data")

    @property
    def task_class(self) -> type[PipelineTask]:
        """Type object for the task.

        Accessing this attribute when `is_imported` is `False` will raise
        `TaskNotImportedError`, but accessing `task_class_name` will not.
        """
        return self._get_imported_data().task_class

    @property
    def task_class_name(self) -> str:
        """The fully-qualified string name of the task class."""
        try:
            return self._task_class_name
        except AttributeError:
            pass
        # Cache the name; @immutable allows deferred initialization of
        # attributes that have not been set yet.
        self._task_class_name = get_full_type_name(self.task_class)
        return self._task_class_name

    @property
    def config(self) -> PipelineTaskConfig:
        """Configuration for the task.

        This is always frozen.

        Accessing this attribute when `is_imported` is `False` will raise
        `TaskNotImportedError`, but calling `get_config_str` will not.
        """
        return self._get_imported_data().config

    def __repr__(self) -> str:
        return f"{self.label} [init] ({self.task_class_name})"

    def get_config_str(self) -> str:
        """Return the configuration for this task as a string of override
        statements.

        Returns
        -------
        config_str : `str`
            String containing configuration-overload statements.
        """
        try:
            return self._config_str
        except AttributeError:
            pass
        self._config_str = self.config.saveToString()
        return self._config_str

    def iter_all_inputs(self) -> Iterator[ReadEdge]:
        """Iterate over all inputs required for construction.

        This is the same as iteration over ``inputs.values()``, but it will be
        updated to include any automatic init-input connections added in the
        future, while `inputs` will continue to hold only task-defined init
        inputs.

        Yields
        ------
        `ReadEdge`
            All the inputs required for construction.
        """
        return iter(self.inputs.values())

    def iter_all_outputs(self) -> Iterator[WriteEdge]:
        """Iterate over all outputs available after construction, including
        special ones.

        Yields
        ------
        `WriteEdge`
            All the outputs available after construction.
        """
        yield from self.outputs.values()
        yield self.config_output

    def diff_edges(self, other: TaskInitNode) -> list[str]:
        """Compare the edges of this task initialization node to those from the
        same task label in a different pipeline.

        Parameters
        ----------
        other : `TaskInitNode`
            Other node to compare to.  Must have the same task label, but need
            not have the same configuration or even the same task class.

        Returns
        -------
        differences : `list` [ `str` ]
            List of string messages describing differences between ``self`` and
            ``other``.  Will be empty if the two nodes have the same edges.
            Messages will use 'A' to refer to ``self`` and 'B' to refer to
            ``other``.
        """
        result = []
        # Fixed: compare to ``other.inputs``; this previously compared
        # ``self.inputs`` to itself, so init-input differences were never
        # reported.
        result += _diff_edge_mapping(self.inputs, other.inputs, self.label, "init input")
        result += _diff_edge_mapping(self.outputs, other.outputs, self.label, "init output")
        result += self.config_output.diff(other.config_output, "config init output")
        return result

    def _to_xgraph_state(self) -> dict[str, Any]:
        """Convert this node's attributes into a dictionary suitable for use
        in exported networkx graphs.
        """
        result = {"task_class_name": self.task_class_name, "bipartite": NodeType.TASK_INIT.bipartite}
        if hasattr(self, "_imported_data"):
            result["task_class"] = self.task_class
            result["config"] = self.config
        return result

    def _get_imported_data(self) -> _TaskNodeImportedData:
        """Return the imported data struct.

        Returns
        -------
        imported_data : `_TaskNodeImportedData`
            Internal structure holding state that requires the task class to
            have been imported.

        Raises
        ------
        TaskNotImportedError
            Raised if `is_imported` is `False`.
        """
        try:
            return self._imported_data
        except AttributeError:
            raise TaskNotImportedError(
                f"Task class {self.task_class_name!r} for label {self.label!r} has not been imported "
                "(see PipelineGraph.import_and_configure)."
            ) from None
@immutable
class TaskNode:
    """A node in a pipeline graph that represents a labeled configuration of a
    `PipelineTask`.

    Parameters
    ----------
    key : `NodeKey`
        Identifier for this node in networkx graphs.
    init : `TaskInitNode`
        Node representing the initialization of this task.
    prerequisite_inputs : `~collections.abc.Mapping` [ `str`, `ReadEdge` ]
        Graph edges that represent prerequisite inputs to this task, keyed by
        connection name.

        Prerequisite inputs must already exist in the data repository when a
        `QuantumGraph` is built, but have more flexibility in how they are
        looked up than regular inputs.
    inputs : `~collections.abc.Mapping` [ `str`, `ReadEdge` ]
        Graph edges that represent regular runtime inputs to this task, keyed
        by connection name.
    outputs : `~collections.abc.Mapping` [ `str`, `WriteEdge` ]
        Graph edges that represent regular runtime outputs of this task, keyed
        by connection name.

        This does not include the special `log_output` and `metadata_output`
        edges; use `iter_all_outputs` to include that, too.
    log_output : `WriteEdge` or `None`
        The special runtime output that persists the task's logs.
    metadata_output : `WriteEdge`
        The special runtime output that persists the task's metadata.
    dimensions : `lsst.daf.butler.DimensionGroup` or `frozenset`
        Dimensions of the task.  If a `frozenset`, the dimensions have not been
        resolved by a `~lsst.daf.butler.DimensionUniverse` and cannot be safely
        compared to other sets of dimensions.

    Notes
    -----
    Task nodes are intentionally not equality comparable, since there are many
    different (and useful) ways to compare these objects with no clear winner
    as the most obvious behavior.

    When included in an exported `networkx` graph (e.g.
    `PipelineGraph.make_xgraph`), task nodes set the following node attributes:

    - ``task_class_name``
    - ``bipartite`` (see `NodeType.bipartite`)
    - ``task_class`` (only if `is_imported` is `True`)
    - ``config`` (only if `is_imported` is `True`)
    """

    def __init__(
        self,
        key: NodeKey,
        init: TaskInitNode,
        *,
        prerequisite_inputs: Mapping[str, ReadEdge],
        inputs: Mapping[str, ReadEdge],
        outputs: Mapping[str, WriteEdge],
        log_output: WriteEdge | None,
        metadata_output: WriteEdge,
        dimensions: DimensionGroup | frozenset,
    ):
        self.key = key
        self.init = init
        self.prerequisite_inputs = prerequisite_inputs
        self.inputs = inputs
        self.outputs = outputs
        self.log_output = log_output
        self.metadata_output = metadata_output
        # Either a resolved DimensionGroup or an unresolved frozenset[str];
        # see has_resolved_dimensions.
        self._dimensions = dimensions

    @staticmethod
    def _from_imported_data(
        key: NodeKey,
        init_key: NodeKey,
        data: _TaskNodeImportedData,
        universe: DimensionUniverse | None,
    ) -> TaskNode:
        """Construct from a `PipelineTask` type and its configuration.

        Parameters
        ----------
        key : `NodeKey`
            Identifier for this node in networkx graphs.
        init_key : `NodeKey`
            Identifier for the node representing the initialization of this
            task.
        data : `_TaskNodeImportedData`
            Internal struct that holds information that requires the task class
            to have been imported.
        universe : `lsst.daf.butler.DimensionUniverse` or `None`
            Definitions of all dimensions.

        Returns
        -------
        node : `TaskNode`
            New task node.

        Raises
        ------
        ValueError
            Raised if configuration validation failed when constructing
            ``connections``.
        """
        init_inputs = {
            name: ReadEdge._from_connection_map(init_key, name, data.connection_map)
            for name in data.connections.initInputs
        }
        prerequisite_inputs = {
            name: ReadEdge._from_connection_map(key, name, data.connection_map, is_prerequisite=True)
            for name in data.connections.prerequisiteInputs
        }
        inputs = {
            name: ReadEdge._from_connection_map(key, name, data.connection_map)
            for name in data.connections.inputs
        }
        init_outputs = {
            name: WriteEdge._from_connection_map(init_key, name, data.connection_map)
            for name in data.connections.initOutputs
        }
        outputs = {
            name: WriteEdge._from_connection_map(key, name, data.connection_map)
            for name in data.connections.outputs
        }
        init = TaskInitNode(
            key=init_key,
            inputs=init_inputs,
            outputs=init_outputs,
            config_output=WriteEdge._from_connection_map(
                init_key, acc.CONFIG_INIT_OUTPUT_CONNECTION_NAME, data.connection_map
            ),
            imported_data=data,
        )
        instance = TaskNode(
            key=key,
            init=init,
            prerequisite_inputs=prerequisite_inputs,
            inputs=inputs,
            outputs=outputs,
            log_output=(
                WriteEdge._from_connection_map(key, acc.LOG_OUTPUT_CONNECTION_NAME, data.connection_map)
                if data.config.saveLogOutput
                else None
            ),
            metadata_output=WriteEdge._from_connection_map(
                key, acc.METADATA_OUTPUT_CONNECTION_NAME, data.connection_map
            ),
            dimensions=(
                frozenset(data.connections.dimensions)
                if universe is None
                else universe.conform(data.connections.dimensions)
            ),
        )
        return instance

    key: NodeKey
    """Key that identifies this node in internal and exported networkx graphs.
    """

    prerequisite_inputs: Mapping[str, ReadEdge]
    """Graph edges that represent prerequisite inputs to this task.

    Prerequisite inputs must already exist in the data repository when a
    `QuantumGraph` is built, but have more flexibility in how they are looked
    up than regular inputs.
    """

    inputs: Mapping[str, ReadEdge]
    """Graph edges that represent regular runtime inputs to this task.
    """

    outputs: Mapping[str, WriteEdge]
    """Graph edges that represent regular runtime outputs of this task.

    This does not include the special `log_output` and `metadata_output` edges;
    use `iter_all_outputs` to include that, too.
    """

    log_output: WriteEdge | None
    """The special runtime output that persists the task's logs.
    """

    metadata_output: WriteEdge
    """The special runtime output that persists the task's metadata.
    """

    @property
    def label(self) -> str:
        """Label of this configuration of a task in the pipeline."""
        return self.key.name

    @property
    def is_imported(self) -> bool:
        """Whether the task type for this node has been imported and
        its configuration overrides applied.

        If this is `False`, the `task_class` and `config` attributes may not
        be accessed.
        """
        return self.init.is_imported

    @property
    def task_class(self) -> type[PipelineTask]:
        """Type object for the task.

        Accessing this attribute when `is_imported` is `False` will raise
        `TaskNotImportedError`, but accessing `task_class_name` will not.
        """
        return self.init.task_class

    @property
    def task_class_name(self) -> str:
        """The fully-qualified string name of the task class."""
        return self.init.task_class_name

    @property
    def config(self) -> PipelineTaskConfig:
        """Configuration for the task.

        This is always frozen.

        Accessing this attribute when `is_imported` is `False` will raise
        `TaskNotImportedError`, but calling `get_config_str` will not.
        """
        return self.init.config

    @property
    def has_resolved_dimensions(self) -> bool:
        """Whether the `dimensions` attribute may be accessed.

        If `False`, the `raw_dimensions` attribute may be used to obtain a
        set of dimension names that has not been resolved by a
        `~lsst.daf.butler.DimensionUniverse`.
        """
        return type(self._dimensions) is DimensionGroup

    @property
    def dimensions(self) -> DimensionGroup:
        """Standardized dimensions of the task."""
        if not self.has_resolved_dimensions:
            raise UnresolvedGraphError(f"Dimensions for task {self.label!r} have not been resolved.")
        return cast(DimensionGroup, self._dimensions)

    @property
    def raw_dimensions(self) -> frozenset[str]:
        """Raw dimensions of the task, with standardization by a
        `~lsst.daf.butler.DimensionUniverse` not guaranteed.
        """
        if self.has_resolved_dimensions:
            return frozenset(cast(DimensionGroup, self._dimensions).names)
        else:
            return cast(frozenset[str], self._dimensions)

    def __repr__(self) -> str:
        if self.has_resolved_dimensions:
            return f"{self.label} ({self.task_class_name}, {self.dimensions})"
        else:
            return f"{self.label} ({self.task_class_name})"

    def get_config_str(self) -> str:
        """Return the configuration for this task as a string of override
        statements.

        Returns
        -------
        config_str : `str`
            String containing configuration-overload statements.
        """
        return self.init.get_config_str()

    def iter_all_inputs(self) -> Iterator[ReadEdge]:
        """Iterate over all runtime inputs, including both regular inputs and
        prerequisites.

        Yields
        ------
        `ReadEdge`
            All the runtime inputs.
        """
        yield from self.prerequisite_inputs.values()
        yield from self.inputs.values()

    def iter_all_outputs(self) -> Iterator[WriteEdge]:
        """Iterate over all runtime outputs, including special ones.

        Yields
        ------
        `WriteEdge`
            All the runtime outputs.
        """
        yield from self.outputs.values()
        yield self.metadata_output
        if self.log_output is not None:
            yield self.log_output

    def diff_edges(self, other: TaskNode) -> list[str]:
        """Compare the edges of this task node to those from the same task
        label in a different pipeline.

        This also calls `TaskInitNode.diff_edges`.

        Parameters
        ----------
        other : `TaskNode`
            Other node to compare to.  Must have the same task label, but need
            not have the same configuration or even the same task class.

        Returns
        -------
        differences : `list` [ `str` ]
            List of string messages describing differences between ``self`` and
            ``other``.  Will be empty if the two nodes have the same edges.
            Messages will use 'A' to refer to ``self`` and 'B' to refer to
            ``other``.
        """
        result = self.init.diff_edges(other.init)
        result += _diff_edge_mapping(
            self.prerequisite_inputs, other.prerequisite_inputs, self.label, "prerequisite input"
        )
        result += _diff_edge_mapping(self.inputs, other.inputs, self.label, "input")
        result += _diff_edge_mapping(self.outputs, other.outputs, self.label, "output")
        if self.log_output is not None:
            if other.log_output is not None:
                result += self.log_output.diff(other.log_output, "log output")
            else:
                result.append("Log output is present in A, but not in B.")
        elif other.log_output is not None:
            result.append("Log output is present in B, but not in A.")
        result += self.metadata_output.diff(other.metadata_output, "metadata output")
        return result

    def get_lookup_function(
        self, connection_name: str
    ) -> Callable[[DatasetType, Registry, DataCoordinate, Sequence[str]], Iterable[DatasetRef]] | None:
        """Return the custom dataset query function for an edge, if one exists.

        Parameters
        ----------
        connection_name : `str`
            Name of the connection.

        Returns
        -------
        lookup_function : `~collections.abc.Callable` or `None`
            Callable that takes a dataset type, a butler registry, a data
            coordinate (the quantum data ID), and an ordered list of
            collections to search, and returns an iterable of
            `~lsst.daf.butler.DatasetRef`.
        """
        return getattr(self._get_imported_data().connection_map[connection_name], "lookupFunction", None)

    def get_spatial_bounds_connections(self) -> frozenset[str]:
        """Return the names of connections whose data IDs should be included
        in the calculation of the spatial bounds for this task's quanta.

        Returns
        -------
        connection_names : `frozenset` [ `str` ]
            Names of connections with spatial dimensions.
        """
        return frozenset(self._get_imported_data().connections.getSpatialBoundsConnections())

    def get_temporal_bounds_connections(self) -> frozenset[str]:
        """Return the names of connections whose data IDs should be included
        in the calculation of the temporal bounds for this task's quanta.

        Returns
        -------
        connection_names : `frozenset` [ `str` ]
            Names of connections with temporal dimensions.
        """
        return frozenset(self._get_imported_data().connections.getTemporalBoundsConnections())

    def _imported_and_configured(self, rebuild: bool) -> TaskNode:
        """Import the task class and use it to construct a new instance.

        Parameters
        ----------
        rebuild : `bool`
            If `True`, import the task class and configure its connections to
            generate new edges that may differ from the current ones.  If
            `False`, import the task class but just update the `task_class` and
            `config` attributes, and assume the edges have not changed.

        Returns
        -------
        node : `TaskNode`
            Task node instance for which `is_imported` is `True`.  Will be
            ``self`` if this is the case already.
        """
        from ..pipelineTask import PipelineTask

        if self.is_imported:
            return self
        task_class = doImportType(self.task_class_name)
        if not issubclass(task_class, PipelineTask):
            raise TypeError(f"{self.task_class_name!r} is not a PipelineTask subclass.")
        config = task_class.ConfigClass()
        config.loadFromString(self.get_config_str())
        return self._reconfigured(config, rebuild=rebuild, task_class=task_class)

    def _reconfigured(
        self,
        config: PipelineTaskConfig,
        rebuild: bool,
        task_class: type[PipelineTask] | None = None,
    ) -> TaskNode:
        """Return a version of this node with new configuration.

        Parameters
        ----------
        config : `.PipelineTaskConfig`
            New configuration for the task.
        rebuild : `bool`
            If `True`, use the configured connections to generate new edges
            that may differ from the current ones.  If `False`, just update the
            `task_class` and `config` attributes, and assume the edges have not
            changed.
        task_class : `type` [ `PipelineTask` ], optional
            Subclass of `PipelineTask`.  This defaults to ``self.task_class``,
            but may be passed as an argument if that is not available because
            the task class was not imported when ``self`` was constructed.

        Returns
        -------
        node : `TaskNode`
            Task node instance with the new config.
        """
        if task_class is None:
            task_class = self.task_class
        imported_data = _TaskNodeImportedData.configure(self.key.name, task_class, config)
        if rebuild:
            return self._from_imported_data(
                self.key,
                self.init.key,
                imported_data,
                universe=self._dimensions.universe if type(self._dimensions) is DimensionGroup else None,
            )
        else:
            # Keep the existing edges; only swap in the new imported data.
            return TaskNode(
                self.key,
                TaskInitNode(
                    self.init.key,
                    inputs=self.init.inputs,
                    outputs=self.init.outputs,
                    config_output=self.init.config_output,
                    imported_data=imported_data,
                ),
                prerequisite_inputs=self.prerequisite_inputs,
                inputs=self.inputs,
                outputs=self.outputs,
                log_output=self.log_output,
                metadata_output=self.metadata_output,
                dimensions=self._dimensions,
            )

    def _resolved(self, universe: DimensionUniverse | None) -> TaskNode:
        """Return an otherwise-equivalent task node with resolved dimensions.

        Parameters
        ----------
        universe : `lsst.daf.butler.DimensionUniverse` or `None`
            Definitions for all dimensions.

        Returns
        -------
        node : `TaskNode`
            Task node instance with `dimensions` resolved by the given
            universe.  Will be ``self`` if this is the case already.
        """
        if self.has_resolved_dimensions:
            if cast(DimensionGroup, self._dimensions).universe is universe:
                return self
        elif universe is None:
            return self
        return TaskNode(
            key=self.key,
            init=self.init,
            prerequisite_inputs=self.prerequisite_inputs,
            inputs=self.inputs,
            outputs=self.outputs,
            log_output=self.log_output,
            metadata_output=self.metadata_output,
            dimensions=(
                universe.conform(self.raw_dimensions) if universe is not None else self.raw_dimensions
            ),
        )

    def _to_xgraph_state(self) -> dict[str, Any]:
        """Convert this node's attributes into a dictionary suitable for use
        in exported networkx graphs.
        """
        result = self.init._to_xgraph_state()
        if self.has_resolved_dimensions:
            result["dimensions"] = self._dimensions
        result["raw_dimensions"] = self.raw_dimensions
        return result

    def _get_imported_data(self) -> _TaskNodeImportedData:
        """Return the imported data struct.

        Returns
        -------
        imported_data : `_TaskNodeImportedData`
            Internal structure holding state that requires the task class to
            have been imported.

        Raises
        ------
        TaskNotImportedError
            Raised if `is_imported` is `False`.
        """
        return self.init._get_imported_data()
930def _diff_edge_mapping(
931 a_mapping: Mapping[str, Edge], b_mapping: Mapping[str, Edge], task_label: str, connection_type: str
932) -> list[str]:
933 """Compare a pair of mappings of edges.
935 Parameters
936 ----------
937 a_mapping : `~collections.abc.Mapping` [ `str`, `Edge` ]
938 First mapping to compare. Expected to have connection names as keys.
939 b_mapping : `~collections.abc.Mapping` [ `str`, `Edge` ]
940 First mapping to compare. If keys differ from those of ``a_mapping``,
941 this will be reported as a difference (in addition to element-wise
942 comparisons).
943 task_label : `str`
944 Task label associated with both mappings.
945 connection_type : `str`
946 Type of connection (e.g. "input" or "init output") associated with both
947 connections. This is a human-readable string to include in difference
948 messages.
950 Returns
951 -------
952 differences : `list` [ `str` ]
953 List of string messages describing differences between the two
954 mappings. Will be empty if the two mappings have the same edges.
955 Messages will include "A" and "B", and are expected to be a preceded
956 by a message describing what "A" and "B" are in the context in which
957 this method is called.
959 Notes
960 -----
961 This is expected to be used to compare one edge-holding mapping attribute
962 of a task or task init node to the same attribute on another task or task
963 init node (i.e. any of `TaskNode.inputs`, `TaskNode.outputs`,
964 `TaskNode.prerequisite_inputs`, `TaskInitNode.inputs`,
965 `TaskInitNode.outputs`).
966 """
967 results = []
968 b_to_do = set(b_mapping.keys())
969 for connection_name, a_edge in a_mapping.items():
970 if (b_edge := b_mapping.get(connection_name)) is None:
971 results.append(
972 f"{connection_type.capitalize()} {connection_name!r} of task "
973 f"{task_label!r} exists in A, but not in B (or it may have a different connection type)."
974 )
975 else:
976 results.extend(a_edge.diff(b_edge, connection_type))
977 b_to_do.discard(connection_name)
978 for connection_name in b_to_do:
979 results.append(
980 f"{connection_type.capitalize()} {connection_name!r} of task "
981 f"{task_label!r} exists in A, but not in B (or it may have a different connection type)."
982 )
983 return results