Coverage for python/lsst/pipe/base/pipeline_graph/visualization/_layout.py: 23%
154 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-27 10:12 +0000
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-27 10:12 +0000
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
27from __future__ import annotations
29__all__ = ("Layout", "ColumnSelector", "LayoutRow")
31import dataclasses
32from collections.abc import Iterable, Iterator, Mapping, Set
33from typing import Generic, TextIO, TypeVar
35import networkx
36import networkx.algorithms.components
37import networkx.algorithms.dag
38import networkx.algorithms.shortest_paths
39import networkx.algorithms.traversal
_K = TypeVar("_K")


class Layout(Generic[_K]):
    """A class that positions nodes and edges in text-art graph visualizations.

    Parameters
    ----------
    xgraph : `networkx.DiGraph` or `networkx.MultiDiGraph`
        NetworkX export of a `.PipelineGraph` being visualized.  Expected to
        be acyclic, since the algorithm relies on topological sorts.
    column_selector : `ColumnSelector`, optional
        Parameterized helper for selecting which column each node should be
        added to.

    Notes
    -----
    Every node is placed only after all of its ancestors in ``xgraph`` have
    been placed, so iterating over the layout yields rows whose edges always
    point from earlier rows to later ones.
    """

    def __init__(
        self,
        xgraph: networkx.DiGraph | networkx.MultiDiGraph,
        column_selector: ColumnSelector | None = None,
    ):
        if column_selector is None:
            column_selector = ColumnSelector()
        self._xgraph = xgraph
        self._column_selector = column_selector
        # Mapping from the column (i.e. 'x') of an already-positioned node to
        # the node keys its outgoing edges terminate at.  These and all other
        # column/x variables are multiples of 2 when they refer to
        # already-existing columns, allowing potential insertion of new columns
        # between them to be represented by odd integers.  Positions are also
        # inverted from the order they are usually displayed; it's best to
        # think of them as the distance (to the left) from the column where
        # text appears on the right.  This is natural because we prefer for
        # nodes to be close to that text when possible (or maybe it's
        # historical, and it's just a lot of work to re-invert the algorithm
        # now that it's written).  Entries are only present while their set of
        # destinations is non-empty.
        self._active_columns: dict[int, set[_K]] = {}
        # Mapping from node key to its column, in placement (row) order.
        self._locations: dict[_K, int] = {}
        # Minimum and maximum column, both inclusive (may go negative; will be
        # shifted as needed before actual display).
        self._x_min = 0
        self._x_max = 0
        # Run the algorithm!
        self._add_graph(self._xgraph)
        # Only needed while placing nodes.
        del self._active_columns

    def _add_graph(self, xgraph: networkx.DiGraph | networkx.MultiDiGraph) -> None:
        """Highest-level routine for the layout algorithm.

        Parameters
        ----------
        xgraph : `networkx.DiGraph` or `networkx.MultiDiGraph`
            Graph or subgraph to add to the layout.  Callers pass only nodes
            that have not been placed yet, but placing one part of it may
            place other parts (as ancestors of the first), so every step
            re-checks what is still unplaced.
        """
        # Start by identifying unconnected subgraphs ("components"); we'll
        # display these in series, as the first of our many attempts to
        # minimize tangles of edges.
        component_xgraphs_and_orders = []
        single_nodes = []
        for component_nodes in networkx.algorithms.components.weakly_connected_components(xgraph):
            if len(component_nodes) == 1:
                single_nodes.append(component_nodes.pop())
            else:
                component_xgraph = xgraph.subgraph(component_nodes)
                component_order = list(
                    networkx.algorithms.dag.lexicographical_topological_sort(component_xgraph, key=str)
                )
                component_xgraphs_and_orders.append((component_xgraph, component_order))
        # Add all single-node components in lexicographical order.  A node
        # that is isolated in ``xgraph`` may still have ancestors in the full
        # graph (``xgraph`` is often a subgraph), and those must be placed
        # first.  Doing so may also place later single nodes, so skip those.
        single_nodes.sort(key=str)
        for node in single_nodes:
            if node in self._locations:
                continue
            self._add_blockers_of(node)
            self._add_single_node(node)
        # Sort component graphs by their size and then str of their first node.
        component_xgraphs_and_orders.sort(key=lambda t: (len(t[1]), str(t[1][0])))
        # Add subgraphs in that order.
        for component_xgraph, component_order in component_xgraphs_and_orders:
            unplaced = [node for node in component_order if node not in self._locations]
            if len(unplaced) == len(component_order):
                self._add_connected_graph(component_xgraph, component_order)
            else:
                # Placing the blockers of something added earlier already
                # placed part of this component; what is left may no longer
                # be connected, so split it up again.
                self._add_graph(component_xgraph.subgraph(unplaced))

    def _add_single_node(self, node: _K) -> None:
        """Add a single node to the layout.

        Parameters
        ----------
        node
            Key of a node in the full graph that has not been placed yet.
        """
        assert node not in self._locations
        successors = set(self._xgraph.successors(node))
        if not self._locations:
            # The very first node placed in the layout goes in column 0;
            # there are no other columns to weigh it against.
            best_x = 0
        else:
            best_x = self._choose_column(node)
        self._locations[node] = best_x
        # Only record a column as active if it actually carries edges;
        # an empty entry would never be removed and would keep the column
        # from ever being reused.
        if successors:
            self._active_columns[best_x] = successors

    def _choose_column(self, node: _K) -> int:
        """Choose the column for a node when other nodes are already placed.

        This also terminates any active edges that end at ``node`` and
        updates ``_x_min``/``_x_max`` if a new column is created.

        Parameters
        ----------
        node
            Key of the node being placed.

        Returns
        -------
        x : `int`
            The (even) internal column for the node.
        """
        # The candidate_x list holds columns where we could insert this node.
        # We start with new columns just outside both edges of the layout,
        # and then add every currently-empty column in between (below).
        # Odd (between-existing-columns) candidates are not generated here;
        # `_shift` handles them should one ever be chosen.
        candidate_x = [self._x_max + 2, self._x_min - 2]
        # The columns holding edges that will connect to this node.
        connecting_x = [x for x, endpoints in self._active_columns.items() if node in endpoints]
        # Delete this node from the active columns it is in, and delete any
        # entries that now have empty sets (freeing those columns).
        for x in connecting_x:
            destinations = self._active_columns[x]
            destinations.remove(node)
            if not destinations:
                del self._active_columns[x]
        # Add all empty columns between the current min and max as candidates.
        candidate_x.extend(
            x for x in range(self._x_min, self._x_max + 2, 2) if x not in self._active_columns
        )
        # Sort the list of connecting columns so we can easily get min and max.
        connecting_x.sort()
        # min() returns the first candidate among ties, so the ordering of
        # candidate_x above doubles as the tie-breaking rule.
        best_x = min(
            candidate_x,
            key=lambda x: self._column_selector(
                connecting_x, x, self._active_columns, self._x_min, self._x_max
            ),
        )
        if best_x % 2:
            # We're inserting a new column between two existing ones; shift
            # all existing column values above this one to make room while
            # using only even numbers.
            best_x = self._shift(best_x)
        self._x_min = min(self._x_min, best_x)
        self._x_max = max(self._x_max, best_x)
        return best_x

    def _shift(self, x: int) -> int:
        """Shift all columns above the given one up by two, allowing a new
        column to be inserted while leaving all columns as even integers.

        Parameters
        ----------
        x : `int`
            Odd column between the two existing columns where a new column
            should be inserted.

        Returns
        -------
        new_x : `int`
            The (even) column now free for the new node, ``x + 1``.
        """
        # Mutate in place (rather than rebuilding) so the dict keeps both its
        # identity and its insertion (row) order.
        for node, old_x in self._locations.items():
            if old_x > x:
                self._locations[node] += 2
        self._active_columns = {
            old_x + 2 if old_x > x else old_x: destinations
            for old_x, destinations in self._active_columns.items()
        }
        self._x_max += 2
        return x + 1

    def _add_connected_graph(
        self, xgraph: networkx.DiGraph | networkx.MultiDiGraph, order: list[_K] | None = None
    ) -> None:
        """Add a subgraph whose nodes are connected.

        Parameters
        ----------
        xgraph : `networkx.DiGraph` or `networkx.MultiDiGraph`
            Graph or subgraph to add to the layout.  None of its nodes may
            have been placed yet.
        order : `list`, optional
            List providing a lexicographical/topological sort of ``xgraph``.
            Will be computed if not provided.
        """
        if order is None:
            order = list(networkx.algorithms.dag.lexicographical_topological_sort(xgraph, key=str))
        # Find the longest path between two nodes, which we'll call the
        # "backbone" of our layout; we'll step through this path and recurse
        # via calls to `_add_graph` on the nodes that we think should go
        # between the backbone nodes.
        backbone: list[_K] = networkx.algorithms.dag.dag_longest_path(xgraph, topo_order=order)
        first, *rest = backbone
        # Add the first backbone node and any ancestors according to the full
        # graph (it can't have ancestors in this _subgraph_ because they'd have
        # been part of the longest path themselves, but the subgraph doesn't
        # have a complete picture).
        self._add_blockers_of(first)
        self._add_single_node(first)
        # Remember all recursive descendants of the current backbone node for
        # later use.
        descendants = frozenset(networkx.algorithms.dag.descendants(xgraph, first))
        for current in rest:
            # Find descendants of the previous node that are:
            # - in this subgraph
            # - not descendants of the current node
            # - not already placed.
            followers_of_previous = set(descendants)
            descendants = frozenset(networkx.algorithms.dag.descendants(xgraph, current))
            followers_of_previous.remove(current)
            followers_of_previous.difference_update(descendants)
            followers_of_previous.difference_update(self._locations.keys())
            # Add those followers of the previous node.  We like adding these
            # here because this terminates edges as soon as we can, freeing up
            # those columns for new nodes.
            self._add_graph(xgraph.subgraph(followers_of_previous))
            # Add the current backbone node and all remaining blockers (which
            # may not even be in this subgraph).
            self._add_blockers_of(current)
            self._add_single_node(current)
        # Any remaining subgraph nodes were not directly connected to the
        # backbone nodes.
        remaining = xgraph.copy()
        remaining.remove_nodes_from(self._locations.keys())
        self._add_graph(remaining)

    def _add_blockers_of(self, node: _K) -> None:
        """Add all not-yet-placed nodes that are ancestors of the given node
        according to the full graph.
        """
        blockers = set(networkx.algorithms.dag.ancestors(self._xgraph, node))
        blockers.difference_update(self._locations.keys())
        self._add_graph(self._xgraph.subgraph(blockers))

    @property
    def width(self) -> int:
        """The index of the rightmost actual (not multiple-of-two) column.

        Actual column positions (as reported in `LayoutRow.x`) range from
        zero to this value inclusive, so this is one less than the number of
        columns.
        """
        return (self._x_max - self._x_min) // 2

    @property
    def nodes(self) -> Iterable[_K]:
        """The graph nodes in the order they appear in the layout."""
        return self._locations.keys()

    def print(self, stream: TextIO) -> None:
        """Print the nodes (but not their edges) as symbols in the right
        locations.

        This is intended for use as a debugging diagnostic, not part of a real
        visualization system.

        Parameters
        ----------
        stream : `io.TextIO`
            Output stream to use for printing.
        """
        # Inside this method ``print`` still resolves to the builtin, since
        # class-level names are not visible in method bodies.
        for row in self:
            print(f"{' ' * row.x}●{' ' * (self.width - row.x)} {row.node}", file=stream)

    def _external_location(self, x: int) -> int:
        """Return the actual (not multiple-of-two, starting from zero)
        location, given the internal multiple-of-two.
        """
        return (self._x_max - x) // 2

    def __iter__(self) -> Iterator[LayoutRow]:
        """Iterate over one `LayoutRow` per node, in layout order."""
        # Mapping from each already-yielded node to the successors its
        # outgoing edges have not reached yet.
        active_edges: dict[_K, set[_K]] = {}
        for node, node_x in self._locations.items():
            row = LayoutRow(node, self._external_location(node_x))
            finished: list[_K] = []
            for origin, destinations in active_edges.items():
                origin_x = self._external_location(self._locations[origin])
                if node in destinations:
                    row.connecting.append((origin_x, origin))
                    destinations.remove(node)
                if destinations:
                    row.continuing.append((origin_x, origin, frozenset(destinations)))
                else:
                    # Every edge from this origin has terminated; stop
                    # scanning it on later rows.
                    finished.append(origin)
            for origin in finished:
                del active_edges[origin]
            row.connecting.sort(key=str)
            row.continuing.sort(key=str)
            yield row
            active_edges[node] = set(self._xgraph.successors(node))
@dataclasses.dataclass
class LayoutRow(Generic[_K]):
    """Information about a single text-art row in a graph.

    Iterating over a `Layout` yields one of these per node, in row order.
    """

    node: _K
    """Key for the node in the exported NetworkX graph."""

    x: int
    """Column of the node's symbol and its outgoing edges (an actual,
    zero-based column, not an internal multiple-of-two position).
    """

    connecting: list[tuple[int, _K]] = dataclasses.field(default_factory=list)
    """The edges that terminate at this row's node, as ``(column, origin)``
    tuples giving the column and node key of each edge's origin node.
    """

    continuing: list[tuple[int, _K, frozenset[_K]]] = dataclasses.field(default_factory=list)
    """The edges that continue past this row, as
    ``(column, origin, destinations)`` tuples giving the column and node key
    of the origin node and the destination nodes (in later rows) that its
    edges have not reached yet.
    """
@dataclasses.dataclass
class ColumnSelector:
    """Helper class that weighs the columns a new node could be added to in a
    text DAG visualization.
    """

    crossing_penalty: int = 1
    """Penalty for each ongoing edge the new node's incoming edges would have
    to "hop" if it were put at the candidate column.
    """

    interior_penalty: int = 1
    """Penalty for adding a new column to the layout between existing columns
    (in addition to `insertion_penalty`).
    """

    insertion_penalty: int = 2
    """Penalty for adding a new column to the layout instead of reusing an
    empty one.

    This penalty is applied even when there is no empty column; it just cancels
    out in that case because it's applied to all candidate columns.
    """

    def __call__(
        self,
        connecting_x: list[int],
        node_x: int,
        active_columns: Mapping[int, Set[_K]],
        x_min: int,
        x_max: int,
    ) -> int:
        """Compute the penalty score for adding a node in the given column.

        Parameters
        ----------
        connecting_x : `list` [ `int` ]
            The columns of incoming edges for this node.  All values are even;
            they need not be sorted.
        node_x : `int`
            The column being considered for the new node.  Will be odd if it
            proposes an insertion between existing columns, or outside the
            bounds of ``x_min`` and ``x_max`` if it proposes an insertion
            on a side.
        active_columns : `~collections.abc.Mapping` [ `int`, \
                `~collections.abc.Set` ]
            The columns of nodes already in the visualization (in previous
            lines) and the nodes at which their edges terminate.  All keys are
            even.
        x_min : `int`
            Current minimum column position (inclusive).  Always even.
        x_max : `int`
            Current maximum column position (inclusive).  Always even.

        Returns
        -------
        penalty : `int`
            Penalty score for this location.  Nodes should be placed at the
            column with the lowest penalty.
        """
        # An odd column means inserting a brand-new column between two
        # existing ones, which costs both the interior and insertion
        # penalties.
        penalty = (node_x % 2) * (self.interior_penalty + self.insertion_penalty)
        # A new column on either side costs just the insertion penalty.
        if node_x < x_min or node_x > x_max:
            penalty += self.insertion_penalty
        # If there are no active edges connecting to this node, we're done.
        if not connecting_x:
            return penalty
        # Find the (inclusive) bounds of the horizontal lines that connect
        # the incoming edges to this node.
        horizontal_min_x = min(min(connecting_x), node_x)
        horizontal_max_x = max(max(connecting_x), node_x)
        connecting = set(connecting_x)
        # Count the unrelated continuing (vertical) edges those horizontal
        # lines would have to "hop".  The bounds are inclusive on both ends:
        # when node_x is odd and rightmost, the column at node_x + 1 lies
        # beyond the line and must not be counted.
        n_crossings = sum(
            1
            for x in active_columns
            if horizontal_min_x <= x <= horizontal_max_x and x not in connecting
        )
        return penalty + self.crossing_penalty * n_crossings