Coverage for tests/cqg_test_utils.py: 11%

80 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-11-22 11:03 +0000

1# This file is part of ctrl_bps. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <https://www.gnu.org/licenses/>. 

27"""ClusteredQuantumGraph-related utilities to support ctrl_bps testing. 

28""" 

29 

30import uuid 

31from copy import deepcopy 

32 

33from lsst.ctrl.bps import ClusteredQuantumGraph, QuantaCluster 

34from networkx import is_directed_acyclic_graph 

35from qg_test_utils import make_test_quantum_graph 

36 

37 

def check_cqg(cqg, truth=None):
    """Check ClusteredQuantumGraph for correctness used by unit tests.

    Parameters
    ----------
    cqg : `lsst.ctrl.bps.ClusteredQuantumGraph`
        ClusteredQuantumGraph to be checked for correctness.
    truth : `dict` [`str`, `Any`], optional
        Information describing what this cluster should look like, in the
        format produced by `dump_cqg`.  Values are only compared when
        ``truth`` is truthy.

    Raises
    ------
    AssertionError
        Raised if the cluster graph contains a cycle, if a quantum node id
        appears more than once across the clusters, if the number of distinct
        quantum node ids in the clusters differs from
        ``len(cqg._quantum_graph)``, or if the graph does not match ``truth``.
    """
    # Checks independent of data.

    # Check the cluster graph has no cycles.
    assert is_directed_acyclic_graph(cqg._cluster_graph), "Cluster graph is not a DAG."

    # Check has all QGraph nodes (include message about duplicate node).
    node_ids = set()
    for cluster in cqg.clusters():
        for id_ in cluster.qgraph_node_ids:
            # Looked up so the failure message can show the node's dataId.
            qnode = cqg.get_quantum_node(id_)
            assert id_ not in node_ids, (
                f"Checking cluster {cluster.name}, id {id_} ({qnode.quantum.dataId}) appears more "
                "than once in CQG."
            )
            node_ids.add(id_)
    assert len(node_ids) == len(cqg._quantum_graph), (
        f"CQG clusters contain {len(node_ids)} distinct quantum node ids, but the QuantumGraph "
        f"has length {len(cqg._quantum_graph)}."
    )

    # If given what should be there, check values.
    if truth:
        cqg_info = dump_cqg(cqg)
        compare_cqg_dicts(truth, cqg_info)

72 

73 

def replace_node_name(name, label, dims):
    """Replace the quantum node id at the start of a cluster name.

    Node ids change on every run, which makes testing difficult, so names
    that start with one are rewritten into a stable, comparable form.

    Parameters
    ----------
    name : `str`
        Cluster name.
    label : `str`
        Cluster label.
    dims : `dict` [`str`, `Any`]
        Dimension names and values, used to keep the new name unique when
        the original name consists only of the node id.

    Returns
    -------
    name : `str`
        ``"NODEONLY_<label>_<dims>"`` if ``name`` is just a UUID,
        ``"NODENAME_<rest>"`` if ``name`` is a UUID followed by ``_<rest>``,
        otherwise ``name`` unchanged.
    """
    head, *tail = name.split("_")
    try:
        uuid.UUID(head)
    except ValueError:
        # Not prefixed by a node id, so the name is already stable.
        return name
    if tail:
        return "NODENAME_" + "_".join(tail)
    return f"NODEONLY_{label}_{dims}"

102 

103 

def dump_cqg(cqg):
    """Represent a ClusteredQuantumGraph as a plain dictionary for testing.

    Cluster names are passed through `replace_node_name` so that the result
    does not depend on the random node ids of a particular run.

    Parameters
    ----------
    cqg : `lsst.ctrl.bps.ClusteredQuantumGraph`
        ClusteredQuantumGraph to be represented as a dictionary.

    Returns
    -------
    info : `dict` [`str`, `Any`]
        Dictionary representation of the ClusteredQuantumGraph with keys
        ``"name"``, ``"nodes"`` (stable cluster name mapped to its label,
        dims and quanta counts) and ``"edges"`` (list of
        ``(from_name, to_name)`` tuples using the stable names).
    """
    info = {"name": cqg.name}

    ignored_tags = ("label", "node_number")
    nodes = {}
    stable_names = {}
    for cluster in cqg.clusters():
        # Every tag except the bookkeeping ones is treated as a dimension.
        dims = {tag: val for tag, val in cluster.tags.items() if tag not in ignored_tags}
        stable = replace_node_name(cluster.name, cluster.label, dims)
        stable_names[cluster.name] = stable
        nodes[stable] = {"label": cluster.label, "dims": dims, "counts": dict(cluster.quanta_counts)}
    info["nodes"] = nodes

    info["edges"] = [(stable_names[edge[0]], stable_names[edge[1]]) for edge in cqg._cluster_graph.edges]
    return info

134 

135 

def compare_cqg_dicts(truth, cqg):
    """Compare dicts representing two ClusteredQuantumGraphs.

    Parameters
    ----------
    truth : `dict` [`str`, `Any`]
        Representation of the expected ClusteredQuantumGraph, with keys
        ``"name"``, ``"nodes"`` and ``"edges"`` as produced by `dump_cqg`.
    cqg : `dict` [`str`, `Any`]
        Representation of the calculated ClusteredQuantumGraph, in the same
        format.

    Raises
    ------
    AssertionError
        Raised whenever a discrepancy between the dicts is discovered.

    Notes
    -----
    Edges are compared as sets, so their order and duplicates are ignored.
    Each edge is converted to a tuple before comparison so that edges given
    as lists (e.g. loaded from a YAML or JSON truth file) are hashable.
    """
    assert truth["name"] == cqg["name"], f"Mismatch name: truth={truth['name']}, cqg={cqg['name']}"

    truth_nodes = truth["nodes"]
    cqg_nodes = cqg["nodes"]
    assert len(truth_nodes) == len(
        cqg_nodes
    ), f"Mismatch number of nodes: truth={len(truth_nodes)}, cqg={len(cqg_nodes)}"
    # With equal sizes, finding every truth key in cqg means the key sets are equal.
    for tkey, tnode in truth_nodes.items():
        assert tkey in cqg_nodes, f"Could not find {tkey} in cqg"
        cnode = cqg_nodes[tkey]
        assert (
            tnode["label"] == cnode["label"]
        ), f"Mismatch cluster label: truth={tnode['label']}, cqg={cnode['label']}"
        assert (
            tnode["dims"] == cnode["dims"]
        ), f"Mismatch cluster dims: truth={tnode['dims']}, cqg={cnode['dims']}"
        assert (
            tnode["counts"] == cnode["counts"]
        ), f"Mismatch cluster quanta counts: truth={tnode['counts']}, cqg={cnode['counts']}"

    truth_edges = {tuple(edge) for edge in truth["edges"]}
    cqg_edges = {tuple(edge) for edge in cqg["edges"]}
    assert truth_edges == cqg_edges, f"Mismatch edges: truth={truth['edges']}, cqg={cqg['edges']}"

171 

172 

# Quanta in the test QuantumGraph, as TASK(D1,D2), and their dependencies:
#
#   T1(1,2)   T1(3,4)   T4(1,2)   T4(3,4)
#      |         |
#   T2(1,2)   T2(3,4)
#      |         |
#   T3(1,2)   T3(3,4)
def make_test_clustered_quantum_graph(outdir):
    """Make a ClusteredQuantumGraph for testing.

    The clusters created are:

    - ``T4_1_2`` and ``T4_3_4``: single-quantum clusters with no
      dependencies.
    - ``T1_1_2`` / ``T1_3_4``: the T1 quantum for that (D1, D2).
    - ``T23_1_2`` / ``T23_3_4``: the T2 and T3 quanta for that (D1, D2),
      each depending on the matching ``T1_*`` cluster.

    Parameters
    ----------
    outdir : `str`
        Root used for the QuantumGraph filename stored
        in the ClusteredQuantumGraph.

    Returns
    -------
    qgraph : `lsst.pipe.base.QuantumGraph`
        A deep copy of the fake QuantumGraph, taken before the
        ClusteredQuantumGraph was built from the original.
    cqg : `lsst.ctrl.bps.ClusteredQuantumGraph`
        The ClusteredQuantumGraph containing the clusters described above.
    """
    qgraph = make_test_quantum_graph()
    qgraph2 = deepcopy(qgraph)  # keep separate copy

    cqg = ClusteredQuantumGraph("cqg1", qgraph, f"{outdir}/test_file.qgraph")

    # Node ids are random hashes, so index the nodes by "<label>_<D1>_<D2>".
    test_lookup = {}
    for qnode in qgraph:
        data_id = dict(qnode.quantum.dataId.required)
        key = f"{qnode.taskDef.label}_{data_id['D1']}_{data_id['D2']}"
        test_lookup[key] = qnode

    # Add orphans (single-quantum clusters with no dependencies).
    for orphan in ("T4_1_2", "T4_3_4"):
        cqg.add_cluster(QuantaCluster.from_quantum_node(test_lookup[orphan], orphan))

    # For each (D1, D2): T1 alone, T2+T3 together, and a T1 -> T23 dependency.
    for dims in ("1_2", "3_4"):
        qc1 = QuantaCluster.from_quantum_node(test_lookup[f"T1_{dims}"], f"T1_{dims}")
        qc2 = QuantaCluster.from_quantum_node(test_lookup[f"T2_{dims}"], f"T23_{dims}")
        qc2.add_quantum_node(test_lookup[f"T3_{dims}"])
        cqg.add_cluster([qc2, qc1])  # reversed to check order is corrected in tests
        cqg.add_dependency(qc1, qc2)

    return qgraph2, cqg