Coverage for tests / test_clustered_quantum_graph.py: 34%

112 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-26 09:03 +0000

1# This file is part of ctrl_bps. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <https://www.gnu.org/licenses/>. 

27"""Unit tests for the clustering methods.""" 

28 

29# Turn off "doesn't conform to snake_case naming style" because matching 

30# the unittest casing. 

31# pylint: disable=invalid-name 

32 

33import os 

34import shutil 

35import tempfile 

36import unittest 

37from collections import Counter 

38from pathlib import Path 

39 

40from cqg_test_utils import make_test_clustered_quantum_graph 

41from qg_test_utils import make_test_quantum_graph 

42 

43from lsst.ctrl.bps import ClusteredQuantumGraph, QuantaCluster 

44 

# Absolute path of the directory that contains this test module.
TESTDIR = os.path.dirname(os.path.abspath(__file__))

46 

47 

class TestQuantaCluster(unittest.TestCase):
    """Tests for QuantaCluster."""

    def setUp(self):
        # Use the first two T1 quanta of a small test quantum graph so that
        # clusters are built from real node ids and node info.
        self.qgraph = make_test_quantum_graph()
        self.id1, self.id2, *_ = self.qgraph.quanta_by_task["T1"].values()
        self.info1 = self.qgraph.quantum_only_xgraph.nodes[self.id1]
        self.info2 = self.qgraph.quantum_only_xgraph.nodes[self.id2]

    def testQgraphNodeIds(self):
        """Test qgraph_node_ids of a single-quantum cluster."""
        qc = QuantaCluster.from_quantum_info(self.id1, self.info1, "{node_number}")
        self.assertEqual(qc.qgraph_node_ids, frozenset([self.id1]))

    def testQuantaCountsNone(self):
        """Test quanta_counts is empty for a cluster without quanta."""
        qc = QuantaCluster("NoQuanta", "the_label")
        self.assertEqual(qc.quanta_counts, Counter())

    def testQuantaCounts(self):
        """Test quanta_counts of a single-quantum cluster."""
        qc = QuantaCluster.from_quantum_info(self.id1, self.info1, "{node_number}")
        self.assertEqual(qc.quanta_counts, Counter({"T1": 1}))

    def testAddQuantum(self):
        """Test add_quantum updates both the counts and the node ids."""
        qc = QuantaCluster.from_quantum_info(self.id1, self.info1, "{node_number}")
        qc.add_quantum(self.id2, self.info2["task_label"])
        self.assertEqual(qc.quanta_counts, Counter({"T1": 2}))
        self.assertEqual(qc.qgraph_node_ids, frozenset([self.id1, self.id2]))

    def testStr(self):
        """Test str() mentions the cluster name, the T1 task, and tags."""
        qc = QuantaCluster.from_quantum_info(self.id1, self.info1, "{node_number}")
        text = str(qc)
        self.assertIn(qc.name, text)
        self.assertIn("T1", text)
        self.assertIn("tags", text)

    def testEqual(self):
        """Test clusters built from the same quantum compare equal."""
        qc1 = QuantaCluster.from_quantum_info(self.id1, self.info1, "{node_number}")
        qc2 = QuantaCluster.from_quantum_info(self.id1, self.info1, "{node_number}")
        self.assertEqual(qc1, qc2)

    def testNotEqual(self):
        """Test clusters built from different quanta compare unequal."""
        qc1 = QuantaCluster.from_quantum_info(self.id1, self.info1, "{node_number}")
        qc2 = QuantaCluster.from_quantum_info(self.id2, self.info2, "{node_number}")
        self.assertNotEqual(qc1, qc2)

    def testHash(self):
        """Test hash is consistent with equality."""
        qc1 = QuantaCluster.from_quantum_info(self.id1, self.info1, "{node_number}")
        qc1_again = QuantaCluster.from_quantum_info(self.id1, self.info1, "{node_number}")
        qc2 = QuantaCluster.from_quantum_info(self.id2, self.info2, "{node_number}")
        # Objects that compare equal must hash equal (Python data model).
        self.assertEqual(hash(qc1), hash(qc1_again))
        # Distinct clusters are expected to hash differently.
        self.assertNotEqual(hash(qc1), hash(qc2))

97 

98 

class TestClusteredQuantumGraph(unittest.TestCase):
    """Tests for ClusteredQuantumGraph methods."""

    def setUp(self):
        # Each test gets a fresh temporary directory (removed in tearDown)
        # that is handed to the test-graph helper and used for save/load.
        self.tmpdir = tempfile.mkdtemp()
        self.qgraph, self.cqg1 = make_test_clustered_quantum_graph(self.tmpdir)

    def tearDown(self):
        shutil.rmtree(self.tmpdir, ignore_errors=True)

    def testName(self):
        """Test name of the clustered quantum graph."""
        self.assertEqual(self.cqg1.name, "cqg1")

    def testQgraph(self):
        """Test qgraph method."""
        self.assertEqual(self.cqg1.qgraph, self.qgraph)

    def testGetClusterExists(self):
        """Test get_cluster method where cluster exists."""
        self.assertEqual("T1_1_2", self.cqg1.get_cluster("T1_1_2").name)

    def testGetClusterMissing(self):
        """Test get_cluster method where cluster doesn't exist."""
        with self.assertRaises(KeyError):
            _ = self.cqg1.get_cluster("Not_There")

    def testClusters(self):
        """Test clusters method returns in correct order."""
        clusters = list(self.cqg1.clusters())

        # Record every position at which each cluster label appears.
        positions = {}
        for index, cluster in enumerate(clusters):
            positions.setdefault(cluster.label, []).append(index)

        # All T1 clusters must come before any clusterT2T3 cluster.
        self.assertLess(max(positions["T1"]), min(positions["clusterT2T3"]))

    def testSuccessorsExisting(self):
        """Test successors method returns existing successors."""
        self.assertEqual(list(self.cqg1.successors("T1_1_2")), ["T23_1_2"])

    def testSuccessorsNone(self):
        """Test successors method handles no successors."""
        # Must be iterable and yield nothing.
        self.assertEqual(list(self.cqg1.successors("T5_1_2")), [])

    def testPredecessorsExisting(self):
        """Test predecessors method returns existing predecessors."""
        self.assertEqual(list(self.cqg1.predecessors("T23_1_2")), ["T1_1_2"])

    def testPredecessorsNone(self):
        """Test predecessors method handles no predecessors."""
        # Must be iterable and yield nothing.
        self.assertEqual(list(self.cqg1.predecessors("T1_1_2")), [])

    def testSaveAndLoad(self):
        """Test a saved graph is non-empty on disk and loads back equal."""
        path = Path(self.tmpdir) / "save_1.pickle"
        self.cqg1.save(path)
        self.assertTrue(path.is_file())
        self.assertGreater(path.stat().st_size, 0)
        test_cqg = ClusteredQuantumGraph.load(path)
        self.assertEqual(self.cqg1, test_cqg)

    def testValidateOK(self):
        """Test validate raises nothing on a valid clustered quantum graph."""
        self.cqg1.validate()

    def testValidateBadTagLabel(self):
        """Test validate reports a cluster whose label no longer matches."""
        # Relabel the cluster so that validate detects a label mismatch.
        qc2 = self.cqg1.get_cluster("T23_1_2")
        qc2.label = "T23"
        with self.assertRaisesRegex(RuntimeError, "Label mismatch in cluster T23_1_2"):
            self.cqg1.validate()

    def testValidateNotDAG(self):
        """Test validate rejects a graph containing a cycle."""
        # Add a back edge (T23_1_2 -> T1_1_2) so the graph is no longer a DAG.
        qc1 = self.cqg1.get_cluster("T1_1_2")
        qc2 = self.cqg1.get_cluster("T23_1_2")
        self.cqg1.add_dependency(qc2, qc1)

        with self.assertRaisesRegex(RuntimeError, "is not a directed acyclic graph"):
            self.cqg1.validate()

    def testValidateMissingQuanta(self):
        """Test validate detects quanta missing from the clusters."""
        # Drop the last quantum id from one cluster (via the private attribute)
        # so the clustered quanta count no longer matches the quantum graph.
        qc2 = self.cqg1.get_cluster("T23_1_2")
        qc2._qgraph_node_ids = qc2._qgraph_node_ids[:-1]

        with self.assertRaisesRegex(RuntimeError, "does not equal number in quantum graph"):
            self.cqg1.validate()

    def testValidateDuplicateId(self):
        """Test validate detects a quantum that appears in two clusters."""
        # Build a new cluster around a quantum already in T1_1_2.
        qc1 = self.cqg1.get_cluster("T1_1_2")
        quantum_id = next(iter(qc1.qgraph_node_ids))
        quantum_info = self.cqg1.get_quantum_info(quantum_id)
        qc = QuantaCluster.from_quantum_info(quantum_id, quantum_info, "DuplicateId")
        self.cqg1.add_cluster(qc)
        qc2 = self.cqg1.get_cluster("T23_1_2")
        self.cqg1.add_dependency(qc2, qc)

        with self.assertRaisesRegex(RuntimeError, "occurs in at least 2 clusters"):
            self.cqg1.validate()

203 

204 

# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()