Coverage for tests/test_graph_walker.py: 17%

48 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-28 08:44 +0000

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <https://www.gnu.org/licenses/>. 

27 

28from __future__ import annotations 

29 

30import unittest 

31import uuid 

32 

33import networkx 

34 

35from lsst.daf.butler import DataCoordinate 

36from lsst.pipe.base.graph_walker import GraphWalker 

37from lsst.pipe.base.tests.mocks import DynamicConnectionConfig, InMemoryRepo 

38 

39 

class GraphWalkerTestCase(unittest.TestCase):
    """Tests for the GraphWalker utility class."""

    def _nodes_by_data_id(self, qg, node_ids) -> dict[DataCoordinate, uuid.UUID]:
        """Map each given quantum node's data ID to its node UUID.

        Parameters
        ----------
        qg
            Assembled quantum graph whose ``quantum_only_xgraph`` holds the
            node attributes.
        node_ids
            Iterable of quantum node UUIDs (e.g. one batch yielded by a
            `GraphWalker`).
        """
        return {qg.quantum_only_xgraph.nodes[n]["data_id"]: n for n in node_ids}

    def _task_labels(self, qg, node_ids) -> list[str]:
        """Return the task label of each given quantum node, in order."""
        return [qg.quantum_only_xgraph.nodes[n]["task_label"] for n in node_ids]

    def test_iteration(self) -> None:
        """Walk a four-task quantum graph, checking that finishing quanta
        unblocks their successors and that failing a quantum reports all of
        its descendants.
        """
        helper = InMemoryRepo("base.yaml", "spatial.yaml")
        self.enterContext(helper)

        def data_id(**kwargs) -> DataCoordinate:
            # Shorthand for the repeated Cam1 data-ID construction.
            return DataCoordinate.standardize(
                instrument="Cam1", universe=helper.butler.dimensions, **kwargs
            )

        helper.add_task("a", dimensions=["visit", "detector"])
        helper.add_task("b", dimensions=["visit", "detector"])
        helper.add_task(
            "c",  # Gathers outputs from 'b' by visit.
            dimensions=["visit"],
            inputs={
                "input_connection": DynamicConnectionConfig(
                    dataset_type_name=f"dataset_auto{helper.last_auto_dataset_type_index}",
                    dimensions=["visit", "detector"],
                    multiple=True,
                )
            },
        )
        helper.add_task(
            "d",  # Scatters visit-level outputs from 'c' per-detector.
            dimensions=["visit", "detector"],
            inputs={
                "input_connection": DynamicConnectionConfig(
                    dataset_type_name=f"dataset_auto{helper.last_auto_dataset_type_index}",
                    dimensions=["visit"],
                )
            },
        )
        qg = helper.make_quantum_graph_builder().finish(attach_datastore_records=False).assemble()
        walker = GraphWalker[uuid.UUID](qg.quantum_only_xgraph.copy())
        # First set of unblocked nodes consists of all 'a' nodes.
        a_nodes = self._nodes_by_data_id(qg, next(walker))
        self.assertEqual(self._task_labels(qg, a_nodes.values()), ["a"] * 8)
        # Mark the detector=1 'a' nodes as finished, and iterate the walker;
        # this should unblock all of the detector=1 'b' nodes.
        walker.finish(a_nodes.pop(data_id(visit=1, detector=1)))
        walker.finish(a_nodes.pop(data_id(visit=2, detector=1)))
        b_nodes = self._nodes_by_data_id(qg, next(walker))
        self.assertEqual(self._task_labels(qg, b_nodes.values()), ["b"] * 2)
        self.assertEqual([d["detector"] for d in b_nodes], [1] * 2)
        # Mark one quantum of 'a' as a failure, and check that this returns
        # all downstream nodes.
        bad_node = a_nodes.pop(data_id(visit=1, detector=2))
        downstream_of_bad = set(walker.fail(bad_node))
        self.assertEqual(
            downstream_of_bad,
            set(networkx.dag.descendants(qg.quantum_only_xgraph, bad_node)),
        )
        # Mark the remaining 'a' nodes as finished. This should unblock the
        # remaining 5 'b' nodes, for a total of 7 unblocked.
        for n in a_nodes.values():
            walker.finish(n)
        a_nodes.clear()
        b_nodes.update(self._nodes_by_data_id(qg, next(walker)))
        self.assertEqual(self._task_labels(qg, b_nodes.values()), ["b"] * 7)
        # Check that iterating the walker again yields nothing new. Note that
        # this is not the same as the iterator being exhausted, which only
        # happens after we've traversed the whole graph.
        self.assertFalse(next(walker))
        # Mark those 'b' nodes as finished, which should unblock one 'c' node
        # (the other was downstream of the failure).
        for n in b_nodes.values():
            walker.finish(n)
        c_nodes = self._nodes_by_data_id(qg, next(walker))
        self.assertEqual(len(c_nodes), 1)
        c_data_id, c_node = c_nodes.popitem()
        self.assertEqual(qg.quantum_only_xgraph.nodes[c_node]["task_label"], "c")
        self.assertEqual(c_data_id, data_id(visit=2))
        # Mark the 'c' node as finished, which should unblock the 4 'd' nodes
        # for visit=2.
        walker.finish(c_node)
        d_nodes = self._nodes_by_data_id(qg, next(walker))
        self.assertEqual(self._task_labels(qg, d_nodes.values()), ["d"] * 4)
        self.assertEqual([d["visit"] for d in d_nodes], [2] * 4)
        # Finish the 'd' nodes and check that the walker iteration is done.
        for n in d_nodes.values():
            walker.finish(n)
        with self.assertRaises(StopIteration):
            next(walker)