Coverage for tests / test_graph_walker.py: 17%
48 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-18 08:44 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-18 08:44 +0000
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
28from __future__ import annotations
30import unittest
31import uuid
33import networkx
35from lsst.daf.butler import DataCoordinate
36from lsst.pipe.base.graph_walker import GraphWalker
37from lsst.pipe.base.tests.mocks import DynamicConnectionConfig, InMemoryRepo
class GraphWalkerTestCase(unittest.TestCase):
    """Tests for the `GraphWalker` utility class."""

    def test_iteration(self) -> None:
        """Walk a scatter/gather quantum graph, finishing and failing quanta,
        and check which quanta each iteration step unblocks.
        """
        helper = InMemoryRepo("base.yaml", "spatial.yaml")
        self.enterContext(helper)
        # Two chained per-{visit, detector} tasks: 'a' feeds 'b'.
        for label in ("a", "b"):
            helper.add_task(label, dimensions=["visit", "detector"])
        # 'c' gathers the per-detector outputs of 'b' into one quantum per
        # visit (the most recent auto-named dataset type is b's output).
        gather_input = DynamicConnectionConfig(
            dataset_type_name=f"dataset_auto{helper.last_auto_dataset_type_index}",
            dimensions=["visit", "detector"],
            multiple=True,
        )
        helper.add_task("c", dimensions=["visit"], inputs={"input_connection": gather_input})
        # 'd' scatters the per-visit output of 'c' back to one quantum per
        # detector.
        scatter_input = DynamicConnectionConfig(
            dataset_type_name=f"dataset_auto{helper.last_auto_dataset_type_index}",
            dimensions=["visit"],
        )
        helper.add_task("d", dimensions=["visit", "detector"], inputs={"input_connection": scatter_input})

        qg = helper.make_quantum_graph_builder().finish(attach_datastore_records=False).assemble()
        xgraph = qg.quantum_only_xgraph
        # The walker gets its own copy, so `xgraph` stays intact for lookups.
        walker = GraphWalker[uuid.UUID](xgraph.copy())
        universe = helper.butler.dimensions

        def by_data_id(quanta) -> dict[DataCoordinate, uuid.UUID]:
            """Key a batch of quantum node IDs by their data IDs."""
            return {xgraph.nodes[quantum]["data_id"]: quantum for quantum in quanta}

        def task_labels(quanta) -> list[str]:
            """Return the task label of each quantum node, in order."""
            return [xgraph.nodes[quantum]["task_label"] for quantum in quanta]

        def data_id(**values: int | str) -> DataCoordinate:
            """Build a standardized data ID for the 'Cam1' instrument."""
            return DataCoordinate.standardize(instrument="Cam1", universe=universe, **values)

        # Nothing has run yet, so the first batch holds every 'a' quantum.
        a_nodes = by_data_id(next(walker))
        self.assertEqual(task_labels(a_nodes.values()), ["a"] * 8)

        # Finishing the detector=1 'a' quanta of both visits should release
        # exactly the two detector=1 'b' quanta.
        for visit in (1, 2):
            walker.finish(a_nodes.pop(data_id(visit=visit, detector=1)))
        b_nodes = by_data_id(next(walker))
        self.assertEqual(task_labels(b_nodes.values()), ["b"] * 2)
        self.assertEqual([key["detector"] for key in b_nodes], [1] * 2)

        # Failing a single 'a' quantum must report every quantum downstream of
        # it in the graph.
        bad_node = a_nodes.pop(data_id(visit=1, detector=2))
        downstream_of_bad = set(walker.fail(bad_node))
        self.assertEqual(downstream_of_bad, set(networkx.dag.descendants(xgraph, bad_node)))

        # Finishing the five 'a' quanta still pending releases their five 'b'
        # quanta, for seven 'b' quanta seen in total.
        for node in a_nodes.values():
            walker.finish(node)
        a_nodes.clear()
        b_nodes.update(by_data_id(next(walker)))
        self.assertEqual(task_labels(b_nodes.values()), ["b"] * 7)
        # Stepping again with nothing newly finished yields an empty batch;
        # iteration only ends once the whole graph has been traversed.
        self.assertFalse(next(walker))

        # Finishing every 'b' quantum releases only one 'c' quantum; the other
        # one is downstream of the failure.
        for node in b_nodes.values():
            walker.finish(node)
        c_nodes = by_data_id(next(walker))
        self.assertEqual(len(c_nodes), 1)
        c_data_id, c_node = c_nodes.popitem()
        self.assertEqual(xgraph.nodes[c_node]["task_label"], "c")
        self.assertEqual(c_data_id, data_id(visit=2))

        # Finishing that 'c' quantum releases the four visit=2 'd' quanta.
        walker.finish(c_node)
        d_nodes = by_data_id(next(walker))
        self.assertEqual(task_labels(d_nodes.values()), ["d"] * 4)
        self.assertEqual([key["visit"] for key in d_nodes], [2] * 4)

        # With the 'd' quanta done, the graph is fully walked and iteration
        # stops.
        for node in d_nodes.values():
            walker.finish(node)
        with self.assertRaises(StopIteration):
            next(walker)