Coverage for python/lsst/pipe/base/pipeline_graph/_mapping_views.py: 53%

59 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-10-11 09:32 +0000

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27from __future__ import annotations 

28 

29from collections.abc import Iterable, Iterator, Mapping, Sequence 

30from typing import Any, ClassVar, TypeVar, cast, overload 

31 

32import networkx 

33 

34from ._dataset_types import DatasetTypeNode 

35from ._exceptions import UnresolvedGraphError 

36from ._nodes import NodeKey, NodeType 

37from ._tasks import TaskInitNode, TaskNode 

38 

# Free type variable used by the overloads of
# `DatasetTypeMappingView.get_if_resolved` for the caller-supplied default.
_T = TypeVar("_T")
# Value type of a `MappingView`; covariant since the views only expose
# values (they derive from the read-only `Mapping` ABC).
_N = TypeVar("_N", covariant=True)

41 

42 

class MappingView(Mapping[str, _N]):
    """Read-only mapping from node names to node objects, restricted to the
    nodes of one `NodeType` in a `PipelineGraph`.

    Parameters
    ----------
    parent_xgraph : `networkx.MultiDiGraph`
        Backing networkx graph for the `PipelineGraph` instance.

    Notes
    -----
    Instances should only be constructed by `PipelineGraph` and its helper
    classes.

    Iteration order is topologically sorted if and only if the backing
    `PipelineGraph` has been sorted since its last modification.
    """

    # Each concrete subclass sets this to choose which graph nodes it exposes.
    _NODE_TYPE: ClassVar[NodeType]

    def __init__(self, parent_xgraph: networkx.MultiDiGraph) -> None:
        self._parent_xgraph = parent_xgraph
        # Lazily-built list of this view's keys; `None` means it must be
        # rebuilt from the parent graph on next use.
        self._keys: list[str] | None = None

    def __contains__(self, key: object) -> bool:
        # A non-str key just yields a NodeKey that is not in the graph, so the
        # membership test fails naturally without an explicit type check.
        node_key = NodeKey(self._NODE_TYPE, cast(str, key))
        return node_key in self._parent_xgraph

    def __iter__(self) -> Iterator[str]:
        return iter(self._cached_keys())

    def __getitem__(self, key: str) -> _N:
        node_key = NodeKey(self._NODE_TYPE, key)
        node_attributes = self._parent_xgraph.nodes[node_key]
        return node_attributes["instance"]

    def __len__(self) -> int:
        return len(self._cached_keys())

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self!s})"

    def __str__(self) -> str:
        joined = ", ".join(iter(self))
        return "{" + joined + "}"

    def _cached_keys(self) -> list[str]:
        """Return this view's keys, rebuilding them from the parent graph's
        node iteration order if the cache is empty.
        """
        if self._keys is None:
            self._keys = self._make_keys(self._parent_xgraph)
        return self._keys

    def _reorder(self, parent_keys: Sequence[NodeKey]) -> None:
        """Replace this view's iteration order with the order of the given
        parent keys.

        Parameters
        ----------
        parent_keys : `~collections.abc.Sequence` [ `NodeKey` ]
            Superset of the keys in this view, in the new order.
        """
        self._keys = self._make_keys(parent_keys)

    def _reset(self) -> None:
        """Drop all cached content.

        The parent graph should call this after any change that could
        invalidate the view; the keys are then rebuilt on next access.
        """
        self._keys = None

    def _make_keys(self, parent_keys: Iterable[NodeKey]) -> list[str]:
        """Build this view's key list from an iterable of parent keys.

        Parameters
        ----------
        parent_keys : `~collections.abc.Iterable` [ `NodeKey` ]
            Superset of the keys in this view.

        Returns
        -------
        keys : `list` [ `str` ]
            String forms of the parent keys whose node type matches this
            view, in the order they were given.
        """
        wanted_type = self._NODE_TYPE
        return [str(node_key) for node_key in parent_keys if node_key.node_type is wanted_type]

122 

123 

class TaskMappingView(MappingView[TaskNode]):
    """Read-only mapping view over the task nodes of a `PipelineGraph`.

    Notes
    -----
    Keys are task labels; values are the corresponding `TaskNode` instances.
    Iteration order is topological if and only if the `PipelineGraph` has
    been sorted since its last modification.
    """

    # Restrict the view to task nodes of the parent graph.
    _NODE_TYPE = NodeType.TASK

135 

136 

class TaskInitMappingView(MappingView[TaskInitNode]):
    """Read-only mapping view over the task-initialization nodes of a
    `PipelineGraph`.

    Notes
    -----
    Keys are task labels; values are the corresponding `TaskInitNode`
    instances. Iteration order is topological if and only if the
    `PipelineGraph` has been sorted since its last modification.
    """

    # Restrict the view to task-initialization nodes of the parent graph.
    _NODE_TYPE = NodeType.TASK_INIT

149 

150 

class DatasetTypeMappingView(MappingView[DatasetTypeNode]):
    """A mapping view of the nodes representing dataset types in a
    `PipelineGraph`.

    Notes
    -----
    Mapping keys are parent dataset type names and values are `DatasetTypeNode`
    instances, but values are only available for nodes that have been resolved
    (see `PipelineGraph.resolve`). Attempting to access an unresolved value
    will result in `UnresolvedGraphError` being raised. Keys for unresolved
    nodes are always present and iterable.

    Iteration order is topological if and only if the `PipelineGraph` has been
    sorted since its last modification.
    """

    _NODE_TYPE = NodeType.DATASET_TYPE

    def __getitem__(self, key: str) -> DatasetTypeNode:
        """Return the resolved node for a dataset type.

        Parameters
        ----------
        key : `str`
            Parent dataset type name.

        Returns
        -------
        node : `DatasetTypeNode`
            The resolved node.

        Raises
        ------
        KeyError
            Raised if the node is not present in the graph at all.
        UnresolvedGraphError
            Raised if the node is present but has not been resolved.
        """
        # The stored instance of an unresolved dataset type node is `None`.
        if (result := super().__getitem__(key)) is None:
            raise UnresolvedGraphError(f"Node for dataset type {key!r} has not been resolved.")
        return result

    def is_resolved(self, key: str) -> bool:
        """Test whether a node has been resolved.

        Parameters
        ----------
        key : `str`
            Parent dataset type name.

        Returns
        -------
        resolved : `bool`
            `True` if the node has a resolved `DatasetTypeNode`, `False` if it
            is present but unresolved.

        Raises
        ------
        KeyError
            Raised if the node is not present in the graph at all.
        """
        return super().__getitem__(key) is not None

    @overload
    def get_if_resolved(self, key: str) -> DatasetTypeNode | None:
        ...  # pragma: nocover

    @overload
    def get_if_resolved(self, key: str, default: _T) -> DatasetTypeNode | _T:
        ...  # pragma: nocover

    def get_if_resolved(self, key: str, default: Any = None) -> DatasetTypeNode | Any:
        """Get a node or return a default if it has not been resolved.

        Parameters
        ----------
        key : `str`
            Parent dataset type name.
        default
            Value to return if this dataset type has not been resolved.

        Returns
        -------
        node : `DatasetTypeNode` or ``default``
            The resolved node, or ``default`` if the node is unresolved.

        Raises
        ------
        KeyError
            Raised if the node is not present in the graph at all.
        """
        if (result := super().__getitem__(key)) is None:
            return default  # type: ignore
        return result