Coverage for python/lsst/pipe/base/pipeline_graph/_mapping_views.py: 44%

69 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-07 02:48 -0700

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27from __future__ import annotations 

28 

29from collections.abc import Iterable, Iterator, Mapping, Sequence 

30from typing import Any, ClassVar, TypeVar, cast, overload 

31 

32import networkx 

33 

34from ._dataset_types import DatasetTypeNode 

35from ._exceptions import UnresolvedGraphError 

36from ._nodes import NodeKey, NodeType 

37from ._tasks import TaskInitNode, TaskNode 

38 

# Value type held by a `MappingView`.  Declared covariant so that a view of a
# more specific node type is accepted where a view of a broader one is
# expected (views are read-only, so covariance is safe).
_N = TypeVar("_N", covariant=True)
# Type of the ``default`` argument accepted by the
# `DatasetTypeMappingView.get_if_resolved` overloads.
_T = TypeVar("_T")

41 

42 

class MappingView(Mapping[str, _N]):
    """Base class for read-only mapping views onto the nodes of one
    `NodeType` in a `PipelineGraph`.

    Parameters
    ----------
    parent_xgraph : `networkx.MultiDiGraph`
        Backing networkx graph for the `PipelineGraph` instance.

    Notes
    -----
    Instances should only be constructed by `PipelineGraph` and its helper
    classes.

    Keys are the string forms of the matching `NodeKey` objects, and values
    are read from each node's ``"instance"`` attribute.  The key list is
    built lazily on first use and cached until `_reset` or `_reorder` is
    called.

    Iteration order is topologically sorted if and only if the backing
    `PipelineGraph` has been sorted since its last modification.
    """

    # Node type exposed by this view; each concrete subclass assigns it.
    _NODE_TYPE: ClassVar[NodeType]

    def __init__(self, parent_xgraph: networkx.MultiDiGraph) -> None:
        self._parent_xgraph = parent_xgraph
        # Cached, ordered list of keys; `None` means it must be rebuilt.
        self._keys: list[str] | None = None

    def __contains__(self, key: object) -> bool:
        # No isinstance check is needed: a key that is not a `str` simply
        # yields a `NodeKey` that fails the membership test, which is the
        # desired answer anyway.
        candidate = NodeKey(self._NODE_TYPE, cast(str, key))
        return candidate in self._parent_xgraph

    def __iter__(self) -> Iterator[str]:
        cached = self._keys
        if cached is None:
            cached = self._keys = self._make_keys(self._parent_xgraph)
        return iter(cached)

    def __getitem__(self, key: str) -> _N:
        attributes = self._parent_xgraph.nodes[NodeKey(self._NODE_TYPE, key)]
        return attributes["instance"]

    def __len__(self) -> int:
        cached = self._keys
        if cached is None:
            cached = self._keys = self._make_keys(self._parent_xgraph)
        return len(cached)

    def __repr__(self) -> str:
        return f"{type(self).__name__}({str(self)})"

    def __str__(self) -> str:
        return "{" + ", ".join(iter(self)) + "}"

    def _reorder(self, parent_keys: Sequence[NodeKey]) -> None:
        """Replace this view's cached iteration order with the order of the
        given parent keys.

        Parameters
        ----------
        parent_keys : `~collections.abc.Sequence` [ `NodeKey` ]
            Superset of the keys in this view, arranged in the desired new
            order.
        """
        self._keys = self._make_keys(parent_keys)

    def _reset(self) -> None:
        """Discard all cached content.

        The parent graph should call this after any modification that could
        invalidate the view; the cache is then rebuilt the next time it is
        needed.
        """
        self._keys = None

    def _make_keys(self, parent_keys: Iterable[NodeKey]) -> list[str]:
        """Build this view's ordered list of keys from parent keys.

        Parameters
        ----------
        parent_keys : `~collections.abc.Iterable` [ `NodeKey` ]
            Superset of the keys in this view.

        Returns
        -------
        keys : `list` [ `str` ]
            String forms of those parent keys whose node type is this view's
            ``_NODE_TYPE``, in the order they were given.
        """
        return [
            str(node_key)
            for node_key in parent_keys
            if node_key.node_type is self._NODE_TYPE
        ]

121 

122 

class TaskMappingView(MappingView[TaskNode]):
    """A mapping view of the tasks in a `PipelineGraph`.

    Notes
    -----
    Mapping keys are task labels and values are `TaskNode` instances.
    Iteration order is topological if and only if the `PipelineGraph` has been
    sorted since its last modification.
    """

    _NODE_TYPE = NodeType.TASK

    def between(self, first: str | None = None, last: str | None = None) -> Mapping[str, TaskNode]:
        """Return a mapping whose tasks are between a range of tasks.

        Parameters
        ----------
        first : `str`, optional
            Label of the first task to include, inclusive.
        last : `str`, optional
            Label of the last task to include, inclusive.

        Returns
        -------
        between : `~collections.abc.Mapping` [ `str`, `TaskNode` ]
            Tasks that are downstream of ``first`` and upstream of ``last``.
            If either ``first`` or ``last`` is `None` (default), that side
            of the result is unbounded.  Tasks that have no dependency
            relationship to either task are not included, and if both bounds
            are provided, included tasks must have the right relationship with
            both bounding tasks.  Entries appear in the same order as
            iteration over this view (topological if the graph has been
            sorted since its last modification).

        Raises
        ------
        networkx.NetworkXError
            Raised by networkx if there is no task labeled ``first`` or
            ``last`` in the graph.
        """
        # This is definitely not the fastest way to compute this subset, but
        # it's a very simple one (given what networkx provides), and pipeline
        # graphs are never *that* big.  `None` means "no bound applied yet",
        # which avoids materializing the full node set for unbounded sides.
        selected: set[NodeKey] | None = None
        if first is not None:
            first_key = NodeKey(NodeType.TASK, first)
            selected = set(networkx.dag.descendants(self._parent_xgraph, first_key))
            selected.add(first_key)
        if last is not None:
            last_key = NodeKey(NodeType.TASK, last)
            upstream: set[NodeKey] = set(networkx.dag.ancestors(self._parent_xgraph, last_key))
            upstream.add(last_key)
            selected = upstream if selected is None else selected & upstream
        # Walk this view rather than the set, so the result has a
        # deterministic order that matches the view's own iteration order;
        # set iteration order depends on element hashes and may differ from
        # one process to the next.
        result: dict[str, TaskNode] = {}
        for label in self:
            key = NodeKey(NodeType.TASK, label)
            if selected is None or key in selected:
                result[label] = self._parent_xgraph.nodes[key]["instance"]
        return result

175 

176 

class TaskInitMappingView(MappingView[TaskInitNode]):
    """A mapping view onto the task-initialization nodes of a
    `PipelineGraph`.

    Notes
    -----
    Keys are task labels; values are the corresponding `TaskInitNode`
    instances.  The view iterates in topological order only when the
    `PipelineGraph` has been sorted since it was last modified.
    """

    # Only nodes of this type are visible through the view.
    _NODE_TYPE = NodeType.TASK_INIT

189 

190 

class DatasetTypeMappingView(MappingView[DatasetTypeNode]):
    """A mapping view of the dataset type nodes in a `PipelineGraph`.

    Notes
    -----
    Mapping keys are parent dataset type names and values are `DatasetTypeNode`
    instances, but values are only available for nodes that have been resolved
    (see `PipelineGraph.resolve`).  Attempting to access an unresolved value
    will result in `UnresolvedGraphError` being raised.  Keys for unresolved
    nodes are always present and iterable.

    Iteration order is topological if and only if the `PipelineGraph` has been
    sorted since its last modification.
    """

    _NODE_TYPE = NodeType.DATASET_TYPE

    def __getitem__(self, key: str) -> DatasetTypeNode:
        """Return the resolved node for a parent dataset type.

        Parameters
        ----------
        key : `str`
            Parent dataset type name.

        Returns
        -------
        node : `DatasetTypeNode`
            The resolved node.

        Raises
        ------
        KeyError
            Raised if the node is not present in the graph at all.
        UnresolvedGraphError
            Raised if the node is present but has not been resolved.
        """
        # Unresolved dataset type nodes store `None` as their instance.
        if (result := super().__getitem__(key)) is None:
            raise UnresolvedGraphError(f"Node for dataset type {key!r} has not been resolved.")
        return result

    def is_resolved(self, key: str) -> bool:
        """Test whether a node has been resolved.

        Parameters
        ----------
        key : `str`
            Parent dataset type name of the node to check.

        Returns
        -------
        resolved : `bool`
            Whether the node has been resolved or not.

        Raises
        ------
        KeyError
            Raised if the node is not present in the graph at all.
        """
        return super().__getitem__(key) is not None

    @overload
    def get_if_resolved(self, key: str) -> DatasetTypeNode | None: ...  # pragma: nocover

    @overload
    def get_if_resolved(self, key: str, default: _T) -> DatasetTypeNode | _T: ...  # pragma: nocover

    def get_if_resolved(self, key: str, default: Any = None) -> DatasetTypeNode | Any:
        """Get a node or return a default if it has not been resolved.

        Parameters
        ----------
        key : `str`
            Parent dataset type name.
        default : `~typing.Any`
            Value to return if this dataset type has not been resolved.

        Returns
        -------
        result : `DatasetTypeNode` or `~typing.Any`
            The resolved node, or ``default`` if the node is present but has
            not been resolved.

        Raises
        ------
        KeyError
            Raised if the node is not present in the graph at all.
        """
        if (result := super().__getitem__(key)) is None:
            return default  # type: ignore
        return result