Coverage for python/lsst/pipe/base/pipeline_graph/_mapping_views.py: 53%
57 statements
« prev ^ index » next coverage.py v7.4.2, created at 2024-02-21 10:57 +0000
« prev ^ index » next coverage.py v7.4.2, created at 2024-02-21 10:57 +0000
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
27from __future__ import annotations
29from collections.abc import Iterable, Iterator, Mapping, Sequence
30from typing import Any, ClassVar, TypeVar, cast, overload
32import networkx
34from ._dataset_types import DatasetTypeNode
35from ._exceptions import UnresolvedGraphError
36from ._nodes import NodeKey, NodeType
37from ._tasks import TaskInitNode, TaskNode
39_N = TypeVar("_N", covariant=True)
40_T = TypeVar("_T")
class MappingView(Mapping[str, _N]):
    """Read-only mapping view onto the nodes of a single node type in a
    `PipelineGraph`.

    Parameters
    ----------
    parent_xgraph : `networkx.MultiDiGraph`
        The networkx graph that backs the `PipelineGraph` instance.

    Notes
    -----
    Only `PipelineGraph` and its helper classes are expected to construct
    instances.

    Keys are iterated in topological order exactly when the backing
    `PipelineGraph` has been sorted since it was last modified.
    """

    _NODE_TYPE: ClassVar[NodeType]  # assigned by each concrete subclass

    def __init__(self, parent_xgraph: networkx.MultiDiGraph) -> None:
        self._parent_xgraph = parent_xgraph
        # Cached, ordered list of this view's keys; ``None`` means it must be
        # rebuilt from the parent graph the next time it is needed.
        self._keys: list[str] | None = None

    def __contains__(self, key: object) -> bool:
        # A non-str key just won't match any node in the graph, which is the
        # desired answer, so the cast is only there to satisfy the type
        # checker.
        node_key = NodeKey(self._NODE_TYPE, cast(str, key))
        return node_key in self._parent_xgraph

    def __iter__(self) -> Iterator[str]:
        keys = self._keys
        if keys is None:
            keys = self._keys = self._make_keys(self._parent_xgraph)
        return iter(keys)

    def __getitem__(self, key: str) -> _N:
        node_data = self._parent_xgraph.nodes[NodeKey(self._NODE_TYPE, key)]
        return node_data["instance"]

    def __len__(self) -> int:
        keys = self._keys
        if keys is None:
            keys = self._keys = self._make_keys(self._parent_xgraph)
        return len(keys)

    def __repr__(self) -> str:
        return f"{type(self).__name__}({str(self)})"

    def __str__(self) -> str:
        joined = ", ".join(iter(self))
        return "{" + joined + "}"

    def _reorder(self, parent_keys: Sequence[NodeKey]) -> None:
        """Replace this view's iteration order with the one implied by a
        sequence of parent-graph keys.

        Parameters
        ----------
        parent_keys : `~collections.abc.Sequence` [ `NodeKey` ]
            Keys of the parent graph, in the desired new order; must include
            every key of this view (other node types are filtered out).
        """
        self._keys = self._make_keys(parent_keys)

    def _reset(self) -> None:
        """Discard all cached content.

        The parent graph should call this after any modification that might
        invalidate the view; the key list is rebuilt lazily on next use.
        """
        self._keys = None

    def _make_keys(self, parent_keys: Iterable[NodeKey]) -> list[str]:
        """Build this view's key list from an iterable of parent-graph keys.

        Parameters
        ----------
        parent_keys : `~collections.abc.Iterable` [ `NodeKey` ]
            Keys of the parent graph; only those whose node type matches this
            view are kept, in their original order.

        Returns
        -------
        keys : `list` [ `str` ]
            String forms of the matching keys.
        """
        wanted_type = self._NODE_TYPE
        return [str(node_key) for node_key in parent_keys if node_key.node_type is wanted_type]
class TaskMappingView(MappingView[TaskNode]):
    """Mapping view over the task nodes of a `PipelineGraph`.

    Notes
    -----
    Keys are task labels; values are the corresponding `TaskNode` instances.
    Keys come out in topological order exactly when the `PipelineGraph` has
    been sorted since it was last modified.
    """

    _NODE_TYPE = NodeType.TASK
class TaskInitMappingView(MappingView[TaskInitNode]):
    """Mapping view over the task-initialization nodes of a `PipelineGraph`.

    Notes
    -----
    Keys are task labels; values are the corresponding `TaskInitNode`
    instances. Keys come out in topological order exactly when the
    `PipelineGraph` has been sorted since it was last modified.
    """

    _NODE_TYPE = NodeType.TASK_INIT
class DatasetTypeMappingView(MappingView[DatasetTypeNode]):
    """A mapping view of the nodes representing dataset types in a
    `PipelineGraph`.

    Notes
    -----
    Mapping keys are parent dataset type names and values are `DatasetTypeNode`
    instances, but values are only available for nodes that have been resolved
    (see `PipelineGraph.resolve`). Attempting to access an unresolved value
    will result in `UnresolvedGraphError` being raised. Keys for unresolved
    nodes are always present and iterable.

    Iteration order is topological if and only if the `PipelineGraph` has been
    sorted since its last modification.
    """

    _NODE_TYPE = NodeType.DATASET_TYPE

    def __getitem__(self, key: str) -> DatasetTypeNode:
        """Return the resolved node for a dataset type.

        Parameters
        ----------
        key : `str`
            Parent dataset type name.

        Returns
        -------
        node : `DatasetTypeNode`
            The resolved node.

        Raises
        ------
        KeyError
            Raised if the node is not present in the graph at all.
        UnresolvedGraphError
            Raised if the node is present but has not been resolved.
        """
        # Unresolved dataset type nodes are stored with an "instance" of None.
        if (result := super().__getitem__(key)) is None:
            raise UnresolvedGraphError(f"Node for dataset type {key!r} has not been resolved.")
        return result

    def is_resolved(self, key: str) -> bool:
        """Test whether a node has been resolved.

        Parameters
        ----------
        key : `str`
            Parent dataset type name of the node to check.

        Returns
        -------
        resolved : `bool`
            Whether the node has been resolved or not.

        Raises
        ------
        KeyError
            Raised if the node is not present in the graph at all.
        """
        return super().__getitem__(key) is not None

    @overload
    def get_if_resolved(self, key: str) -> DatasetTypeNode | None: ...  # pragma: nocover

    @overload
    def get_if_resolved(self, key: str, default: _T) -> DatasetTypeNode | _T: ...  # pragma: nocover

    def get_if_resolved(self, key: str, default: Any = None) -> DatasetTypeNode | Any:
        """Get a node or return a default if it has not been resolved.

        Parameters
        ----------
        key : `str`
            Parent dataset type name.
        default : `~typing.Any`
            Value to return if this dataset type has not been resolved.

        Returns
        -------
        result : `DatasetTypeNode` or `~typing.Any`
            The resolved node, or ``default`` if the node is present but has
            not been resolved.

        Raises
        ------
        KeyError
            Raised if the node is not present in the graph at all.
        """
        if (result := super().__getitem__(key)) is None:
            return default  # type: ignore
        return result