Coverage for python / lsst / pipe / base / graph_walker.py: 23%
37 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-30 08:49 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-30 08:49 +0000
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
28from __future__ import annotations
30__all__ = ("GraphWalker",)
32from typing import Self
34import networkx
37class GraphWalker[T]:
38 """A helper for traversing directed acyclic graphs.
40 Parameters
41 ----------
42 xgraph : `networkx.DiGraph` or `networkx.MultiDiGraph`
43 Networkx graph to process. Will be consumed during iteration, so this
44 should often be a copy.
46 Notes
47 -----
48 Each iteration yields a `frozenset` of nodes, which may be empty if there
49 are no nodes ready for processing. A node is only considered ready if all
50 of its predecessor nodes have been marked as complete with `finish`.
51 Iteration only completes when all nodes have been finished or failed.
53 `GraphWalker` is not thread-safe; calling one `GraphWalker` method while
54 another is in progress is undefined behavior. It is designed to be used
55 in the management thread or process in a parallel traversal.
56 """
58 def __init__(self, xgraph: networkx.DiGraph | networkx.MultiDiGraph):
59 self._xgraph = xgraph
60 self._ready: set[T] = set(next(iter(networkx.dag.topological_generations(self._xgraph)), []))
61 self._active: set[T] = set()
62 self._incomplete: set[T] = set(self._xgraph)
64 def __iter__(self) -> Self:
65 return self
67 def __next__(self) -> frozenset[T]:
68 if not self._incomplete:
69 raise StopIteration()
70 new_active = frozenset(self._ready)
71 self._active.update(new_active)
72 self._ready.clear()
73 return new_active
75 def finish(self, key: T) -> None:
76 """Mark a node as successfully processed, unblocking (at least in part)
77 iteration over successor nodes.
79 Parameters
80 ----------
81 key : unspecified
82 NetworkX key of the node to mark finished. Does not need to have
83 been returned by the iterator yet.
84 """
85 self._incomplete.remove(key)
86 self._active.discard(key)
87 self._ready.discard(key)
88 successors = list(self._xgraph.successors(key))
89 for successor in successors:
90 assert successor not in self._active, (
91 "A node downstream of an active one should not have been yielded yet."
92 )
93 if all(
94 predecessor not in self._incomplete for predecessor in self._xgraph.predecessors(successor)
95 ):
96 self._ready.add(successor)
98 def fail(self, key: T) -> list[T]:
99 """Mark a node as unsuccessfully processed, permanently blocking all
100 recursive descendants.
102 Parameters
103 ----------
104 key : unspecified
105 NetworkX key of the node to mark as a failure. Does not need to
106 have been returned by the iterator yet.
108 Returns
109 -------
110 blocked : `list`
111 NetworkX keys of nodes that were recursive descendants of the
112 failed node, and will hence never be yielded by the iterator.
113 """
114 self._incomplete.remove(key)
115 self._active.discard(key)
116 self._ready.discard(key)
117 descendants = list(networkx.dag.descendants(self._xgraph, key))
118 self._xgraph.remove_node(key)
119 self._xgraph.remove_nodes_from(descendants)
120 self._incomplete.difference_update(descendants)
121 return descendants