Coverage for python/lsst/pipe/base/pipeline_graph/visualization/_layout.py: 23%

154 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-11-18 10:50 +0000

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27from __future__ import annotations 

28 

29__all__ = ("Layout", "ColumnSelector", "LayoutRow") 

30 

31import dataclasses 

32from collections.abc import Iterable, Iterator, Mapping, Set 

33from typing import Generic, TextIO, TypeVar 

34 

35import networkx 

36import networkx.algorithms.components 

37import networkx.algorithms.dag 

38import networkx.algorithms.shortest_paths 

39import networkx.algorithms.traversal 

40 

41_K = TypeVar("_K") 

42 

43 

class Layout(Generic[_K]):
    """A class that positions nodes and edges in text-art graph visualizations.

    Parameters
    ----------
    xgraph : `networkx.DiGraph` or `networkx.MultiDiGraph`
        NetworkX export of a `.PipelineGraph` being visualized.
    column_selector : `ColumnSelector`, optional
        Parameterized helper for selecting which column each node should be
        added to.

    Notes
    -----
    The whole layout is computed in the constructor.  Iterating over the
    instance then yields one `LayoutRow` per node, in placement order.
    """

    def __init__(
        self,
        xgraph: networkx.DiGraph | networkx.MultiDiGraph,
        column_selector: ColumnSelector | None = None,
    ):
        if column_selector is None:
            column_selector = ColumnSelector()
        self._xgraph = xgraph
        self._column_selector = column_selector
        # Mapping from the column (i.e. 'x') of an already-positioned node to
        # the not-yet-positioned node keys its outgoing edges terminate at.
        # Entries are removed as soon as their sets become empty, so every
        # key here is a column that is still occupied by a vertical edge.
        # These and all other column/x variables are multiples of 2 when they
        # refer to already-existing columns, allowing potential insertion of
        # new columns between them to be represented by odd integers.
        # Positions are also inverted from the order they are usually
        # displayed; it's best to think of them as the distance (to the left)
        # from the column where text appears on the right.  This is natural
        # because we prefer for nodes to be close to that text when possible
        # (or maybe it's historical, and it's just a lot of work to re-invert
        # the algorithm now that it's written).
        self._active_columns: dict[int, set[_K]] = {}
        # Mapping from node key to its (internal) column, in placement order.
        self._locations: dict[_K, int] = {}
        # Minimum and maximum occupied column (both inclusive; may go
        # negative; will be shifted as needed before actual display).
        self._x_min = 0
        self._x_max = 0
        # Run the algorithm!
        self._add_graph(self._xgraph)
        # Active-column bookkeeping is only needed while placing nodes.
        del self._active_columns

    def _add_graph(self, xgraph: networkx.DiGraph | networkx.MultiDiGraph) -> None:
        """Highest-level routine for the layout algorithm.

        Parameters
        ----------
        xgraph : `networkx.DiGraph` or `networkx.MultiDiGraph`
            Graph or subgraph to add to the layout.
        """
        # Start by identifying unconnected subgraphs ("components"); we'll
        # display these in series, as the first of our many attempts to
        # minimize tangles of edges.
        component_xgraphs_and_orders = []
        single_nodes = []
        for component_nodes in networkx.algorithms.components.weakly_connected_components(xgraph):
            if len(component_nodes) == 1:
                single_nodes.append(component_nodes.pop())
            else:
                component_xgraph = xgraph.subgraph(component_nodes)
                component_order = list(
                    networkx.algorithms.dag.lexicographical_topological_sort(component_xgraph, key=str)
                )
                component_xgraphs_and_orders.append((component_xgraph, component_order))
        # Add all single-node components in lexicographical order.
        single_nodes.sort(key=str)
        for node in single_nodes:
            self._add_single_node(node)
        # Sort component graphs by their size and then str of their first node
        # so the result is deterministic.
        component_xgraphs_and_orders.sort(key=lambda t: (len(t[1]), str(t[1][0])))
        # Add subgraphs in that order.
        for component_xgraph, component_order in component_xgraphs_and_orders:
            self._add_connected_graph(component_xgraph, component_order)

    def _add_single_node(self, node: _K) -> None:
        """Add a single node to the layout.

        Parameters
        ----------
        node
            Key of the node to add; it must not have been placed already.
        """
        assert node not in self._locations
        successors = set(self._xgraph.successors(node))
        if not self._locations:
            # Special-case the very first node of the whole layout: it always
            # goes in column 0.  Only record an active column when there are
            # outgoing edges; an empty entry would never be removed, blocking
            # column 0 from being reused and counting as a crossing for every
            # later node (the general path below applies the same rule).
            self._locations[node] = 0
            if successors:
                self._active_columns[0] = successors
            return
        # Candidate columns for this node: a brand-new column on either outer
        # edge, plus (below) every currently-empty column between _x_min and
        # _x_max.  NOTE: all candidates are even, so no interior (odd)
        # insertion is proposed here, and the `_shift` branch below is
        # currently unreachable; ColumnSelector can still score odd columns.
        candidate_x = [self._x_max + 2, self._x_min - 2]
        # The columns holding edges that will connect to this node.
        connecting_x = [
            active_x for active_x, endpoints in self._active_columns.items() if node in endpoints
        ]
        # Delete this node from the active columns it is in, and delete any
        # entries that now have empty sets (freeing those columns for reuse,
        # possibly by this very node).
        for x in connecting_x:
            destinations = self._active_columns[x]
            destinations.remove(node)
            if not destinations:
                del self._active_columns[x]
        # Add all empty columns between the current min and max as candidates.
        candidate_x.extend(
            x for x in range(self._x_min, self._x_max + 2, 2) if x not in self._active_columns
        )
        # Sort the list of connecting columns so we can easily get min and max.
        connecting_x.sort()
        # Ties keep the first candidate in `candidate_x` order.
        best_x = min(
            candidate_x,
            key=lambda x: self._column_selector(
                connecting_x, x, self._active_columns, self._x_min, self._x_max
            ),
        )
        if best_x % 2:
            # We're inserting a new column between two existing ones; shift
            # all existing column values above this one to make room while
            # using only even numbers.
            best_x = self._shift(best_x)
        self._x_min = min(self._x_min, best_x)
        self._x_max = max(self._x_max, best_x)
        self._locations[node] = best_x
        if successors:
            self._active_columns[best_x] = successors

    def _shift(self, x: int) -> int:
        """Shift all columns above the given one up by two, allowing a new
        column to be inserted while leaving all columns as even integers.

        Parameters
        ----------
        x : `int`
            Odd column at which a new column is being inserted.

        Returns
        -------
        new_x : `int`
            The even column the inserted column should use (``x + 1``).
        """
        for node, old_x in self._locations.items():
            if old_x > x:
                # Updating values (not keys) while iterating is safe.
                self._locations[node] = old_x + 2
        self._active_columns = {
            old_x + 2 if old_x > x else old_x: destinations
            for old_x, destinations in self._active_columns.items()
        }
        self._x_max += 2
        return x + 1

    def _add_connected_graph(
        self, xgraph: networkx.DiGraph | networkx.MultiDiGraph, order: list[_K] | None = None
    ) -> None:
        """Add a subgraph whose nodes are connected.

        Parameters
        ----------
        xgraph : `networkx.DiGraph` or `networkx.MultiDiGraph`
            Graph or subgraph to add to the layout.
        order : `list`, optional
            List providing a lexicographical/topological sort of ``xgraph``.
            Will be computed if not provided.
        """
        if order is None:
            order = list(networkx.algorithms.dag.lexicographical_topological_sort(xgraph, key=str))
        # Find the longest path between two nodes, which we'll call the
        # "backbone" of our layout; we'll step through this path and recurse
        # via calls to `_add_graph` on the nodes that we think should go
        # between the backbone nodes.
        backbone: list[_K] = networkx.algorithms.dag.dag_longest_path(xgraph, topo_order=order)
        if not backbone:
            # Nothing to place (empty graph).
            return
        first, *rest = backbone
        # Add the first backbone node and any ancestors according to the full
        # graph (it can't have ancestors in this _subgraph_ because they'd have
        # been part of the longest path themselves, but the subgraph doesn't
        # have a complete picture).
        self._add_blockers_of(first)
        self._add_single_node(first)
        # Remember all recursive descendants of the current backbone node for
        # later use.
        descendants = frozenset(networkx.algorithms.dag.descendants(xgraph, first))
        for current in rest:
            # Find descendants of the previous node that are:
            # - in this subgraph
            # - not descendants of the current node
            # - not already placed.
            followers_of_previous = set(descendants)
            descendants = frozenset(networkx.algorithms.dag.descendants(xgraph, current))
            followers_of_previous.remove(current)
            followers_of_previous.difference_update(descendants)
            followers_of_previous.difference_update(self._locations.keys())
            # Add those followers of the previous node.  We like adding these
            # here because this terminates edges as soon as we can, freeing up
            # those columns for new nodes.
            self._add_graph(xgraph.subgraph(followers_of_previous))
            # Add the current backbone node and all remaining blockers (which
            # may not even be in this subgraph).
            self._add_blockers_of(current)
            self._add_single_node(current)
        # Any remaining subgraph nodes were not directly connected to the
        # backbone nodes.
        remaining = xgraph.copy()
        remaining.remove_nodes_from(self._locations.keys())
        self._add_graph(remaining)

    def _add_blockers_of(self, node: _K) -> None:
        """Add all not-yet-placed nodes that are ancestors of the given node
        according to the full graph.
        """
        blockers = set(networkx.algorithms.dag.ancestors(self._xgraph, node))
        blockers.difference_update(self._locations.keys())
        self._add_graph(self._xgraph.subgraph(blockers))

    @property
    def width(self) -> int:
        """The index of the right-most actual (not multiple-of-two) column.

        This is one less than the number of columns the layout occupies; it
        is the largest value `LayoutRow.x` can take.
        """
        return (self._x_max - self._x_min) // 2

    @property
    def nodes(self) -> Iterable[_K]:
        """The graph nodes in the order they appear in the layout."""
        return self._locations.keys()

    def print(self, stream: TextIO) -> None:
        """Print the nodes (but not their edges) as symbols in the right
        locations.

        This is intended for use as a debugging diagnostic, not part of a real
        visualization system.
        """
        for row in self:
            print(f"{' ' * row.x}●{' ' * (self.width - row.x)} {row.node}", file=stream)

    def _external_location(self, x: int) -> int:
        """Return the actual (not multiple-of-two, starting from zero)
        location, given the internal multiple-of-two.
        """
        return (self._x_max - x) // 2

    def __iter__(self) -> Iterator[LayoutRow]:
        # Mapping from already-yielded nodes to the successors whose rows have
        # not been reached yet.  Entries are dropped once exhausted, since
        # they can no longer contribute to any row.
        active_edges: dict[_K, set[_K]] = {}
        for node, node_x in self._locations.items():
            row = LayoutRow(node, self._external_location(node_x))
            exhausted = []
            for origin, destinations in active_edges.items():
                origin_x = self._external_location(self._locations[origin])
                if node in destinations:
                    row.connecting.append((origin_x, origin))
                    destinations.remove(node)
                if destinations:
                    row.continuing.append((origin_x, origin, frozenset(destinations)))
                else:
                    exhausted.append(origin)
            for origin in exhausted:
                del active_edges[origin]
            row.connecting.sort(key=str)
            row.continuing.sort(key=str)
            yield row
            successors = set(self._xgraph.successors(node))
            if successors:
                active_edges[node] = successors

290 

291 

@dataclasses.dataclass
class LayoutRow(Generic[_K]):
    """Information about a single text-art row in a graph.

    Column values in this class are external (display) columns: zero-based
    indices counted from the left, as produced by iterating over a `Layout`.
    """

    node: _K
    """Key for the node in the exported NetworkX graph."""

    x: int
    """Column of the node's symbol and its outgoing edges."""

    connecting: list[tuple[int, _K]] = dataclasses.field(default_factory=list)
    """The columns and node keys of edges that terminate at this row.

    Each entry is ``(column, origin)``, where ``origin`` is the key of the
    node the edge starts from and ``column`` is that node's column.
    """

    continuing: list[tuple[int, _K, frozenset[_K]]] = dataclasses.field(default_factory=list)
    """The columns and node keys of edges that continue through this row.

    Each entry is ``(column, origin, destinations)``, where ``destinations``
    holds the keys of the nodes in later rows that edges from ``origin``
    still have to reach.
    """

309 

310 

@dataclasses.dataclass
class ColumnSelector:
    """Helper class that weighs the columns a new node could be added to in a
    text DAG visualization.

    Instances are callable; see `__call__` for how a candidate column is
    scored.
    """

    crossing_penalty: int = 1
    """Penalty for each ongoing edge the new node's incoming edges would have
    to "hop" if it were put at the candidate column.
    """

    interior_penalty: int = 1
    """Penalty for adding a new column to the layout between existing columns
    (in addition to `insertion_penalty`).
    """

    insertion_penalty: int = 2
    """Penalty for adding a new column to the layout instead of reusing an
    empty one.

    This penalty is applied even when there is no empty column; it just cancels
    out in that case because it's applied to all candidate columns.
    """

    def __call__(
        self,
        connecting_x: list[int],
        node_x: int,
        active_columns: Mapping[int, Set[_K]],
        x_min: int,
        x_max: int,
    ) -> int:
        """Compute the penalty score for adding a node in the given column.

        Parameters
        ----------
        connecting_x : `list` [ `int` ]
            The columns of incoming edges for this node.  All values are even.
        node_x : `int`
            The column being considered for the new node.  Will be odd if it
            proposes an insertion between existing columns, or outside the
            bounds of ``x_min`` and ``x_max`` if it proposes an insertion
            on a side.
        active_columns : `~collections.abc.Mapping` [ `int`, \
                `~collections.abc.Set` ]
            The columns of nodes already in the visualization (in previous
            lines) and the nodes at which their edges terminate.  All keys are
            even.
        x_min : `int`
            Current minimum column position (inclusive).  Always even.
        x_max : `int`
            Current maximum column position (inclusive).  Always even.

        Returns
        -------
        penalty : `int`
            Penalty score for this location.  Nodes should be placed at the
            column with the lowest penalty.
        """
        # Start with a penalty for inserting a new column between two existing
        # columns (i.e. x is odd) or on either side.
        penalty = (node_x % 2) * (self.interior_penalty + self.insertion_penalty)
        if node_x < x_min or node_x > x_max:
            penalty += self.insertion_penalty
        # If there are no active edges connecting to this node, we're done.
        if not connecting_x:
            return penalty
        connecting = set(connecting_x)
        # Find the bounds of the horizontal line that joins the incoming edges
        # to the node's column (both ends inclusive).
        horizontal_min_x = min(min(connecting), node_x)
        horizontal_max_x = max(max(connecting), node_x)
        # Add the (scaled) number of unrelated continuing (vertical) edges that
        # the (horizontal) input edges for this node would have to "hop".
        # Only columns within the inclusive span count; a column just beyond
        # an odd ``node_x`` is not crossed.
        n_crossings = sum(
            1
            for x in active_columns
            if horizontal_min_x <= x <= horizontal_max_x and x not in connecting
        )
        return penalty + self.crossing_penalty * n_crossings