Coverage for tests/cqg_test_utils.py: 11%

80 statements  

« prev     ^ index     » next       coverage.py v7.2.5, created at 2023-05-09 02:20 -0700

1# This file is part of ctrl_bps. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21"""ClusteredQuantumGraph-related utilities to support ctrl_bps testing. 

22""" 

23 

24import uuid 

25from copy import deepcopy 

26 

27from lsst.ctrl.bps import ClusteredQuantumGraph, QuantaCluster 

28from networkx import is_directed_acyclic_graph 

29from qg_test_utils import make_test_quantum_graph 

30 

31 

def check_cqg(cqg, truth=None):
    """Check a ClusteredQuantumGraph for correctness; used by unit tests.

    Parameters
    ----------
    cqg : `lsst.ctrl.bps.ClusteredQuantumGraph`
        ClusteredQuantumGraph to be checked for correctness.
    truth : `dict` [`str`, `Any`], optional
        Information describing what this clustered graph should look like,
        in the same form as produced by ``dump_cqg``.  If not given (or
        empty), only the data-independent checks are run.

    Raises
    ------
    AssertionError
        Raised if any of the checks fail.
    """
    # Checks independent of data.

    # The cluster graph must not contain any cycles.
    assert is_directed_acyclic_graph(cqg._cluster_graph)

    # Every QuantumGraph node must appear in exactly one cluster
    # (report the duplicated node's data id if one is found twice).
    node_ids = set()
    for cluster in cqg.clusters():
        for id_ in cluster.qgraph_node_ids:
            qnode = cqg.get_quantum_node(id_)
            assert id_ not in node_ids, (
                f"Checking cluster {cluster.name}, id {id_} ({qnode.quantum.dataId.byName()}) appears more "
                "than once in CQG."
            )
            node_ids.add(id_)
    assert len(node_ids) == len(cqg._quantum_graph), (
        f"Mismatch number of quanta: clusters contain {len(node_ids)}, "
        f"QuantumGraph has {len(cqg._quantum_graph)}"
    )

    # If given what should be there, check values.
    if truth:
        cqg_info = dump_cqg(cqg)
        compare_cqg_dicts(truth, cqg_info)

66 

67 

def replace_node_name(name, label, dims):
    """Replace the node id in a cluster name because node ids
    change every run and thus make testing difficult.

    Parameters
    ----------
    name : `str`
        Cluster name.
    label : `str`
        Cluster label.
    dims : `dict` [`str`, `Any`]
        Dimension names and values used to make the new name unique.

    Returns
    -------
    name : `str`
        New name of cluster.  If the part of ``name`` before the first
        underscore is not a valid UUID, ``name`` is returned unchanged.
    """
    name_parts = name.split("_")
    try:
        uuid.UUID(name_parts[0])
    except ValueError:
        # Name does not start with a node id, so there is nothing to replace.
        return name

    if len(name_parts) == 1:
        # The name is only the node id; build a stable name from label and dims.
        return f"NODEONLY_{label}_{dims}"
    return f"NODENAME_{'_'.join(name_parts[1:])}"

96 

97 

def dump_cqg(cqg):
    """Convert a ClusteredQuantumGraph into a plain dictionary for tests.

    Cluster names are passed through ``replace_node_name`` so that names
    containing run-dependent node ids become stable.

    Parameters
    ----------
    cqg : `lsst.ctrl.bps.ClusteredQuantumGraph`
        ClusteredQuantumGraph to be represented as a dictionary.

    Returns
    -------
    info : `dict` [`str`, `Any`]
        Dictionary representation of the ClusteredQuantumGraph with keys
        ``name``, ``nodes`` (normalized cluster name mapped to its label,
        dims and quanta counts) and ``edges`` (list of normalized name pairs).
    """
    graph_name = cqg.name
    skipped_tags = ("label", "node_number")

    renamed = {}
    nodes = {}
    for cluster in cqg.clusters():
        # Every tag other than the label and node number is treated as a dimension.
        cluster_dims = {tag: val for tag, val in cluster.tags.items() if tag not in skipped_tags}
        new_name = replace_node_name(cluster.name, cluster.label, cluster_dims)
        renamed[cluster.name] = new_name
        nodes[new_name] = {
            "label": cluster.label,
            "dims": cluster_dims,
            "counts": dict(cluster.quanta_counts),
        }

    edges = [(renamed[src], renamed[dst]) for src, dst in cqg._cluster_graph.edges]

    return {"name": graph_name, "nodes": nodes, "edges": edges}

128 

129 

def compare_cqg_dicts(truth, cqg):
    """Assert that two dictionaries describing ClusteredQuantumGraphs match.

    Parameters
    ----------
    truth : `dict` [`str`, `Any`]
        Representation of the expected ClusteredQuantumGraph.
    cqg : `dict` [`str`, `Any`]
        Representation of the calculated ClusteredQuantumGraph.

    Raises
    ------
    AssertionError
        Raised on the first discrepancy found between the dictionaries.
    """
    assert truth["name"] == cqg["name"], f"Mismatch name: truth={truth['name']}, cqg={cqg['name']}"

    truth_nodes = truth["nodes"]
    cqg_nodes = cqg["nodes"]
    assert len(truth_nodes) == len(
        cqg_nodes
    ), f"Mismatch number of nodes: truth={len(truth_nodes)}, cqg={len(cqg_nodes)}"

    # (key within a node entry, description used in the failure message)
    node_fields = (
        ("label", "cluster label"),
        ("dims", "cluster dims"),
        ("counts", "cluster quanta counts"),
    )
    for node_name, expected in truth_nodes.items():
        assert node_name in cqg_nodes, f"Could not find {node_name} in cqg"
        actual = cqg_nodes[node_name]
        for field, what in node_fields:
            assert (
                expected[field] == actual[field]
            ), f"Mismatch {what}: truth={expected[field]}, cqg={actual[field]}"

    # Edge order is irrelevant, so compare as sets.
    truth_edges = set(truth["edges"])
    cqg_edges = set(cqg["edges"])
    assert truth_edges == cqg_edges, f"Mismatch edges: truth={truth['edges']}, cqg={cqg['edges']}"

165 

166 

167# T1(1,2) T1(3,4) T4(1,2) T4(3,4) 

168# | | 

169# T2(1,2) T2(3,4) 

170# | | 

171# T3(1,2) T3(3,4) 

def make_test_clustered_quantum_graph(outdir):
    """Make a ClusteredQuantumGraph for testing.

    The clusters created are:

    - ``T4_1_2`` and ``T4_3_4``: one T4 quantum each, no dependencies.
    - ``T1_1_2`` and ``T23_1_2`` (the T2 and T3 quanta), with
      ``T1_1_2`` -> ``T23_1_2``.
    - ``T1_3_4`` and ``T23_3_4`` (the T2 and T3 quanta), with
      ``T1_3_4`` -> ``T23_3_4``.

    Parameters
    ----------
    outdir : `str`
        Root used for the QuantumGraph filename stored
        in the ClusteredQuantumGraph.

    Returns
    -------
    qgraph : `lsst.pipe.base.QuantumGraph`
        A deep copy of the fake QuantumGraph from which the returned
        ClusteredQuantumGraph was built, taken before any clusters
        were added.
    cqg : `lsst.ctrl.bps.ClusteredQuantumGraph`
        The ClusteredQuantumGraph built from the fake QuantumGraph.
    """
    qgraph = make_test_quantum_graph()
    qgraph2 = deepcopy(qgraph)  # keep separate copy

    cqg = ClusteredQuantumGraph("cqg1", qgraph, f"{outdir}/test_file.qgraph")

    # Node ids are random, so create a lookup keyed by "<label>_<D1>_<D2>".
    test_lookup = {}
    for qnode in qgraph:
        data_id = qnode.quantum.dataId.byName()
        key = f"{qnode.taskDef.label}_{data_id['D1']}_{data_id['D2']}"
        test_lookup[key] = qnode

    # "<D1>_<D2>" values of the quanta in the fake QuantumGraph.
    dim_values = ("1_2", "3_4")

    # Add orphans (T4 clusters have no dependencies).
    for dims in dim_values:
        name = f"T4_{dims}"
        cqg.add_cluster(QuantaCluster.from_quantum_node(test_lookup[name], name))

    # For each (D1, D2): T1 cluster -> cluster holding the T2 and T3 quanta.
    for dims in dim_values:
        qc1 = QuantaCluster.from_quantum_node(test_lookup[f"T1_{dims}"], f"T1_{dims}")
        qc2 = QuantaCluster.from_quantum_node(test_lookup[f"T2_{dims}"], f"T23_{dims}")
        qc2.add_quantum_node(test_lookup[f"T3_{dims}"])
        cqg.add_cluster([qc2, qc1])  # reversed to check order is corrected in tests
        cqg.add_dependency(qc1, qc2)

    return qgraph2, cqg