Coverage for python / lsst / pipe / base / pipeline_graph / _mapping_views.py: 45%
67 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-30 08:49 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-30 08:49 +0000
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
27from __future__ import annotations
29from collections.abc import Iterable, Iterator, Mapping, Sequence
30from typing import Any, ClassVar, cast, overload
32import networkx
34from ._dataset_types import DatasetTypeNode
35from ._exceptions import UnresolvedGraphError
36from ._nodes import NodeKey, NodeType
37from ._tasks import TaskInitNode, TaskNode
40class MappingView[N](Mapping[str, N]):
41 """Base class for mapping views into nodes of certain types in a
42 `PipelineGraph`.
44 Parameters
45 ----------
46 parent_xgraph : `networkx.MultiDiGraph`
47 Backing networkx graph for the `PipelineGraph` instance.
49 Notes
50 -----
51 Instances should only be constructed by `PipelineGraph` and its helper
52 classes.
54 Iteration order is topologically sorted if and only if the backing
55 `PipelineGraph` has been sorted since its last modification.
56 """
58 def __init__(self, parent_xgraph: networkx.MultiDiGraph) -> None:
59 self._parent_xgraph = parent_xgraph
60 self._keys: list[str] | None = None
62 _NODE_TYPE: ClassVar[NodeType] # defined by derived classes
64 def __contains__(self, key: object) -> bool:
65 # The given key may not be a str, but if it isn't it'll just fail the
66 # check, which is what we want anyway.
67 return NodeKey(self._NODE_TYPE, cast(str, key)) in self._parent_xgraph
69 def __iter__(self) -> Iterator[str]:
70 if self._keys is None:
71 self._keys = self._make_keys(self._parent_xgraph)
72 return iter(self._keys)
74 def __getitem__(self, key: str) -> N:
75 return self._parent_xgraph.nodes[NodeKey(self._NODE_TYPE, key)]["instance"]
77 def __len__(self) -> int:
78 if self._keys is None:
79 self._keys = self._make_keys(self._parent_xgraph)
80 return len(self._keys)
82 def __repr__(self) -> str:
83 return f"{type(self).__name__}({self!s})"
85 def __str__(self) -> str:
86 return f"{{{', '.join(iter(self))}}}"
88 def _reorder(self, parent_keys: Sequence[NodeKey]) -> None:
89 """Set this view's iteration order according to the given iterable of
90 parent keys.
92 Parameters
93 ----------
94 parent_keys : `~collections.abc.Sequence` [ `NodeKey` ]
95 Superset of the keys in this view, in the new order.
96 """
97 self._keys = self._make_keys(parent_keys)
99 def _reset(self) -> None:
100 """Reset all cached content.
102 This should be called by the parent graph after any changes that could
103 invalidate the view, causing it to be reconstructed when next
104 requested.
105 """
106 self._keys = None
108 def _make_keys(self, parent_keys: Iterable[NodeKey]) -> list[str]:
109 """Make a sequence of keys for this view from an iterable of parent
110 keys.
112 Parameters
113 ----------
114 parent_keys : `~collections.abc.Iterable` [ `NodeKey` ]
115 Superset of the keys in this view.
116 """
117 return [str(k) for k in parent_keys if k.node_type is self._NODE_TYPE]
class TaskMappingView(MappingView[TaskNode]):
    """A mapping view of the tasks in a `PipelineGraph`.

    Notes
    -----
    Mapping keys are task labels and values are `TaskNode` instances.
    Iteration order is topological if and only if the `PipelineGraph` has been
    sorted since its last modification.
    """

    _NODE_TYPE = NodeType.TASK

    def between(self, first: str | None = None, last: str | None = None) -> Mapping[str, TaskNode]:
        """Return a mapping whose tasks are between a range of tasks.

        Parameters
        ----------
        first : `str`, optional
            Label of the first task to include, inclusive.
        last : `str`, optional
            Label of the last task to include, inclusive.

        Returns
        -------
        between : `~collections.abc.Mapping` [ `str`, `TaskNode` ]
            Tasks that are downstream of ``first`` and upstream of ``last``.
            If either ``first`` or ``last`` is `None` (default), that side
            of the result is unbounded. Tasks that have no dependency
            relationship to either task are not included, and if both bounds
            are provided, included tasks must have the right relationship with
            both bounding tasks. Entries are ordered consistently with this
            view's iteration order (topological if the graph has been sorted
            since its last modification).

        Raises
        ------
        networkx.NetworkXError
            Raised by `networkx` if ``first`` or ``last`` is not the label of
            a task in the graph.
        """
        # This is not the fastest way to compute this subset, but it is a
        # simple one given what networkx provides, and pipeline graphs are
        # never *that* big.
        #
        # Each bound is the set of graph nodes (of any type) satisfying it;
        # `None` means that side is unbounded, so we never copy the whole
        # node set just to intersect with it.
        downstream: set[NodeKey] | None = None
        if first is not None:
            first_key = NodeKey(NodeType.TASK, first)
            downstream = set(networkx.dag.descendants(self._parent_xgraph, first_key))
            downstream.add(first_key)
        upstream: set[NodeKey] | None = None
        if last is not None:
            last_key = NodeKey(NodeType.TASK, last)
            upstream = set(networkx.dag.ancestors(self._parent_xgraph, last_key))
            upstream.add(last_key)
        # Walk this view (instead of a set intersection, whose order depends
        # on hashing) so the result order is deterministic and matches the
        # view's own order.
        result: dict[str, TaskNode] = {}
        for label in self:
            key = NodeKey(NodeType.TASK, label)
            if downstream is not None and key not in downstream:
                continue
            if upstream is not None and key not in upstream:
                continue
            result[label] = self._parent_xgraph.nodes[key]["instance"]
        return result
class TaskInitMappingView(MappingView[TaskInitNode]):
    """Mapping view onto the task-initialization nodes of a
    `PipelineGraph`.

    Notes
    -----
    Keys are task labels; values are the corresponding `TaskInitNode`
    instances. Keys iterate in topological order exactly when the
    `PipelineGraph` has been sorted since it was last modified.
    """

    # Restricts this view to task-initialization nodes.
    _NODE_TYPE = NodeType.TASK_INIT
188class DatasetTypeMappingView(MappingView[DatasetTypeNode]):
189 """A mapping view of the nodes representing task initialization in a
190 `PipelineGraph`.
192 Notes
193 -----
194 Mapping keys are parent dataset type names and values are `DatasetTypeNode`
195 instances, but values are only available for nodes that have been resolved
196 (see `PipelineGraph.resolve`). Attempting to access an unresolved value
197 will result in `UnresolvedGraphError` being raised. Keys for unresolved
198 nodes are always present and iterable.
200 Iteration order is topological if and only if the `PipelineGraph` has been
201 sorted since its last modification.
202 """
204 _NODE_TYPE = NodeType.DATASET_TYPE
206 def __getitem__(self, key: str) -> DatasetTypeNode:
207 if (result := super().__getitem__(key)) is None:
208 raise UnresolvedGraphError(f"Node for dataset type {key!r} has not been resolved.")
209 return result
211 def is_resolved(self, key: str) -> bool:
212 """Test whether a node has been resolved.
214 Parameters
215 ----------
216 key : `str`
217 Node to check.
219 Returns
220 -------
221 `bool`
222 Whether the node has been resolved or not.
223 """
224 return super().__getitem__(key) is not None
226 @overload
227 def get_if_resolved(self, key: str) -> DatasetTypeNode | None: ... # pragma: nocover 227 ↛ exitline 227 didn't return from function 'get_if_resolved' because
229 @overload
230 def get_if_resolved[T](self, key: str, default: T) -> DatasetTypeNode | T: ... # pragma: nocover
232 def get_if_resolved(self, key: str, default: Any = None) -> DatasetTypeNode | Any:
233 """Get a node or return a default if it has not been resolved.
235 Parameters
236 ----------
237 key : `str`
238 Parent dataset type name.
239 default : `~typing.Any`
240 Value to return if this dataset type has not been resolved.
242 Returns
243 -------
244 result : `DatasetTypeNode`
245 The resolved node or the default value.
247 Raises
248 ------
249 KeyError
250 Raised if the node is not present in the graph at all.
251 """
252 if (result := super().__getitem__(key)) is None:
253 return default # type: ignore
254 return result