Coverage for python/lsst/pipe/base/pipeline_graph/_mapping_views.py: 53%

59 statements  

« prev     ^ index     » next       coverage.py v7.3.0, created at 2023-08-23 10:31 +0000

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23from collections.abc import Iterable, Iterator, Mapping, Sequence 

24from typing import Any, ClassVar, TypeVar, cast, overload 

25 

26import networkx 

27 

28from ._dataset_types import DatasetTypeNode 

29from ._exceptions import UnresolvedGraphError 

30from ._nodes import NodeKey, NodeType 

31from ._tasks import TaskInitNode, TaskNode 

32 

33_N = TypeVar("_N", covariant=True) 

34_T = TypeVar("_T") 

35 

36 

class MappingView(Mapping[str, _N]):
    """Read-only mapping from names to the nodes of a single `NodeType` in a
    `PipelineGraph`.

    Parameters
    ----------
    parent_xgraph : `networkx.MultiDiGraph`
        The networkx graph that backs the `PipelineGraph` instance.

    Notes
    -----
    Only `PipelineGraph` and its helper classes should construct instances.

    The view iterates in topological order exactly when the backing
    `PipelineGraph` has been sorted and not modified since.
    """

    # Each concrete subclass sets this to the node type it exposes.
    _NODE_TYPE: ClassVar[NodeType]

    def __init__(self, parent_xgraph: networkx.MultiDiGraph) -> None:
        self._parent_xgraph = parent_xgraph
        # Cached key list; ``None`` means it must be (re)built on demand.
        self._keys: list[str] | None = None

    def _cached_keys(self) -> list[str]:
        """Return the cached key list, building it from the parent graph's
        node iteration order if it has not been computed since construction
        or the last `_reset`.
        """
        keys = self._keys
        if keys is None:
            keys = self._make_keys(self._parent_xgraph)
            self._keys = keys
        return keys

    def __contains__(self, key: object) -> bool:
        # A non-str key simply won't match any node, which is the desired
        # outcome, so it is cast rather than rejected up front.
        node_key = NodeKey(self._NODE_TYPE, cast(str, key))
        return node_key in self._parent_xgraph

    def __iter__(self) -> Iterator[str]:
        return iter(self._cached_keys())

    def __getitem__(self, key: str) -> _N:
        node_data = self._parent_xgraph.nodes[NodeKey(self._NODE_TYPE, key)]
        return node_data["instance"]

    def __len__(self) -> int:
        return len(self._cached_keys())

    def __repr__(self) -> str:
        class_name = type(self).__name__
        return f"{class_name}({self})"

    def __str__(self) -> str:
        return "{" + ", ".join(self) + "}"

    def _reorder(self, parent_keys: Sequence[NodeKey]) -> None:
        """Replace this view's iteration order with the order of the given
        parent keys.

        Parameters
        ----------
        parent_keys : `~collections.abc.Sequence` [ `NodeKey` ]
            Keys in the desired order; must include every key in this view
            and may include keys of other node types.
        """
        self._keys = self._make_keys(parent_keys)

    def _reset(self) -> None:
        """Discard all cached content.

        The parent graph should call this after any change that might make
        the view stale; the key list is then rebuilt the next time it is
        needed.
        """
        self._keys = None

    def _make_keys(self, parent_keys: Iterable[NodeKey]) -> list[str]:
        """Build this view's key list from an iterable of parent keys.

        Parameters
        ----------
        parent_keys : `~collections.abc.Iterable` [ `NodeKey` ]
            Keys that include every key in this view; keys of other node
            types are filtered out, and the input order is preserved.
        """
        wanted_type = self._NODE_TYPE
        return [str(node_key) for node_key in parent_keys if node_key.node_type is wanted_type]

116 

117 

class TaskMappingView(MappingView[TaskNode]):
    """Mapping view over the task nodes of a `PipelineGraph`.

    Notes
    -----
    Keys are task labels and values are `TaskNode` instances. The view
    iterates in topological order exactly when the `PipelineGraph` has been
    sorted and not modified since.
    """

    # Restrict this view to task nodes.
    _NODE_TYPE = NodeType.TASK

129 

130 

class TaskInitMappingView(MappingView[TaskInitNode]):
    """Mapping view over the task-initialization nodes of a `PipelineGraph`.

    Notes
    -----
    Keys are task labels and values are `TaskInitNode` instances. The view
    iterates in topological order exactly when the `PipelineGraph` has been
    sorted and not modified since.
    """

    # Restrict this view to task-initialization nodes.
    _NODE_TYPE = NodeType.TASK_INIT

143 

144 

class DatasetTypeMappingView(MappingView[DatasetTypeNode]):
    """A mapping view of the nodes representing dataset types in a
    `PipelineGraph`.

    Notes
    -----
    Mapping keys are parent dataset type names and values are `DatasetTypeNode`
    instances, but values are only available for nodes that have been resolved
    (see `PipelineGraph.resolve`). Attempting to access an unresolved value
    will result in `UnresolvedGraphError` being raised. Keys for unresolved
    nodes are always present and iterable.

    Iteration order is topological if and only if the `PipelineGraph` has been
    sorted since its last modification.
    """

    _NODE_TYPE = NodeType.DATASET_TYPE

    def __getitem__(self, key: str) -> DatasetTypeNode:
        """Return the resolved node for a dataset type.

        Parameters
        ----------
        key : `str`
            Parent dataset type name.

        Returns
        -------
        node : `DatasetTypeNode`
            The resolved node.

        Raises
        ------
        KeyError
            Raised if the node is not present in the graph at all.
        UnresolvedGraphError
            Raised if the node is present but has not been resolved.
        """
        # Unresolved dataset type nodes store ``None`` as their instance.
        if (result := super().__getitem__(key)) is None:
            raise UnresolvedGraphError(f"Node for dataset type {key!r} has not been resolved.")
        return result

    def is_resolved(self, key: str) -> bool:
        """Test whether a node has been resolved.

        Parameters
        ----------
        key : `str`
            Parent dataset type name.

        Returns
        -------
        resolved : `bool`
            `True` if a `DatasetTypeNode` is available for this key, `False`
            if the node is present but unresolved.

        Raises
        ------
        KeyError
            Raised if the node is not present in the graph at all.
        """
        return super().__getitem__(key) is not None

    @overload
    def get_if_resolved(self, key: str) -> DatasetTypeNode | None:
        ...  # pragma: nocover

    @overload
    def get_if_resolved(self, key: str, default: _T) -> DatasetTypeNode | _T:
        ...  # pragma: nocover

    def get_if_resolved(self, key: str, default: Any = None) -> DatasetTypeNode | Any:
        """Get a node or return a default if it has not been resolved.

        Parameters
        ----------
        key : `str`
            Parent dataset type name.
        default
            Value to return if this dataset type has not been resolved.

        Returns
        -------
        node : `DatasetTypeNode` or ``default``
            The resolved node, or ``default`` if the node is unresolved.

        Raises
        ------
        KeyError
            Raised if the node is not present in the graph at all.
        """
        if (result := super().__getitem__(key)) is None:
            return default  # type: ignore
        return result

197 return result