Coverage for tests/cqg_test_utils.py: 13%

59 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-12-08 14:46 -0800

1# This file is part of ctrl_bps. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21"""ClusteredQuantumGraph-related utilities to support ctrl_bps testing. 

22""" 

23 

24import uuid 

25 

26from lsst.ctrl.bps.quantum_clustering_funcs import dimension_clustering 

27from networkx import is_directed_acyclic_graph 

28from qg_test_utils import make_test_quantum_graph 

29 

30 

def check_cqg(cqg, truth=None):
    """Check ClusteredQuantumGraph for correctness used by unit tests.

    Parameters
    ----------
    cqg : `lsst.ctrl.bps.ClusteredQuantumGraph`
        ClusteredQuantumGraph to be checked for correctness.
    truth : `dict` [`str`, `Any`], optional
        Information describing what this cluster should look like.
        When omitted (or empty) only the data-independent checks run.

    Raises
    ------
    AssertionError
        Raised if the cluster graph is not a directed acyclic graph, if a
        quantum node id is listed by more than one cluster (or twice by
        the same cluster), if the number of clustered ids differs from the
        size of the QuantumGraph, or if the comparison against ``truth``
        fails.
    """
    # Checks independent of data.

    # The graph of clusters must not contain cycles.
    assert is_directed_acyclic_graph(cqg._cluster_graph), "Cluster graph is not a directed acyclic graph."

    # Every QuantumGraph node must be claimed exactly once across clusters.
    seen_ids = set()
    for cluster in cqg.clusters():
        for id_ in cluster.qgraph_node_ids:
            # Looked up before the assert because the failure message
            # reports the data id of the offending quantum.
            qnode = cqg.get_quantum_node(id_)
            assert id_ not in seen_ids, (
                f"Checking cluster {cluster.name}, id {id_} ({qnode.quantum.dataId.byName()}) appears more "
                "than once in CQG."
            )
            seen_ids.add(id_)
    assert len(seen_ids) == len(cqg._quantum_graph), (
        f"Clusters contain {len(seen_ids)} quantum node ids but the QuantumGraph has "
        f"{len(cqg._quantum_graph)} nodes."
    )

    # If given what should be there, check values.
    if truth:
        cqg_info = dump_cqg(cqg)
        compare_cqg_dicts(truth, cqg_info)

65 

66 

def replace_node_name(name, label, dims):
    """Swap the node id in a cluster name for a run-independent token.

    Node ids change every run, so names containing them cannot be compared
    against fixed expected values in tests.

    Parameters
    ----------
    name : `str`
        Cluster name
    label : `str`
        Cluster label
    dims : `dict` [`str`, `Any`]
        Dimension names and values in order to make new name unique.

    Returns
    -------
    name : `str`
        New name of cluster.  The input is returned unchanged when the
        text before the first underscore is not a valid UUID.
    """
    leading, *trailing = name.split("_")
    try:
        uuid.UUID(leading)
    except ValueError:
        # The name does not start with a node id, so it is already stable.
        return name

    if trailing:
        # Keep everything after the node id.
        return f"NODENAME_{'_'.join(trailing)}"
    # The name was only the node id; build a unique name from the label
    # and the dimension values instead.
    return f"NODEONLY_{label}_{str(dims)}"

95 

96 

def dump_cqg(cqg):
    """Convert a ClusteredQuantumGraph into a plain dictionary for testing.

    Parameters
    ----------
    cqg : `lsst.ctrl.bps.ClusteredQuantumGraph`
        ClusteredQuantumGraph to be represented as a dictionary.

    Returns
    -------
    info : `dict` [`str`, `Any`]
        Dictionary representation of ClusteredQuantumGraph with keys:

        ``"name"``
            Name of the ClusteredQuantumGraph.
        ``"nodes"``
            Mapping of run-independent cluster name to a dict holding the
            cluster's ``"label"``, ``"dims"`` and ``"counts"``.
        ``"edges"``
            List of (parent, child) pairs using the run-independent names.
    """
    # Tags that are not treated as dimension values.
    non_dim_tags = ("label", "node_number")

    nodes = {}
    stable_names = {}  # original cluster name -> run-independent name
    for cluster in cqg.clusters():
        dims = {tag: val for tag, val in cluster.tags.items() if tag not in non_dim_tags}
        stable = replace_node_name(cluster.name, cluster.label, dims)
        stable_names[cluster.name] = stable
        nodes[stable] = {"label": cluster.label, "dims": dims, "counts": dict(cluster.quanta_counts)}

    # Edges are stored under the original names; translate both endpoints.
    edges = [(stable_names[edge[0]], stable_names[edge[1]]) for edge in cqg._cluster_graph.edges]

    return {"name": cqg.name, "nodes": nodes, "edges": edges}

127 

128 

def compare_cqg_dicts(truth, cqg):
    """Compare dicts representing two ClusteredQuantumGraphs.

    Parameters
    ----------
    truth : `dict` [`str`, `Any`]
        Representation of the expected ClusteredQuantumGraph.
    cqg : `dict` [`str`, `Any`]
        Representation of the calculated ClusteredQuantumGraph.

    Raises
    ------
    AssertionError
        Whenever a discrepancy is discovered between the dicts.

    Notes
    -----
    Edges are compared as unordered sets of (parent, child) pairs, so
    ordering and duplicate entries are ignored.  Each edge may be given as
    any sequence (e.g., a list from a YAML/JSON file) and not only as a
    tuple.
    """
    assert truth["name"] == cqg["name"], f"Mismatch name: truth={truth['name']}, cqg={cqg['name']}"
    assert len(truth["nodes"]) == len(
        cqg["nodes"]
    ), f"Mismatch number of nodes: truth={len(truth['nodes'])}, cqg={len(cqg['nodes'])}"

    # With equal node counts, finding every truth key in cqg means both
    # dicts hold the same set of cluster names.
    for tkey, tnode in truth["nodes"].items():
        assert tkey in cqg["nodes"], f"Could not find {tkey} in cqg"
        cnode = cqg["nodes"][tkey]
        # Name the cluster in each message so a failure can be located.
        assert (
            tnode["label"] == cnode["label"]
        ), f"Mismatch cluster label for {tkey}: truth={tnode['label']}, cqg={cnode['label']}"
        assert (
            tnode["dims"] == cnode["dims"]
        ), f"Mismatch cluster dims for {tkey}: truth={tnode['dims']}, cqg={cnode['dims']}"
        assert (
            tnode["counts"] == cnode["counts"]
        ), f"Mismatch cluster quanta counts for {tkey}: truth={tnode['counts']}, cqg={cnode['counts']}"

    # Convert each edge to a tuple so unhashable list edges do not make
    # set construction fail with TypeError.
    truth_edges = {tuple(edge) for edge in truth["edges"]}
    cqg_edges = {tuple(edge) for edge in cqg["edges"]}
    assert truth_edges == cqg_edges, f"Mismatch edges: truth={truth['edges']}, cqg={cqg['edges']}"

164 

165 

def make_test_clustered_quantum_graph(config):
    """Build a clustered graph from the standard test QuantumGraph.

    Parameters
    ----------
    config : `Any`
        Configuration passed unchanged to ``dimension_clustering``
        (presumably a `lsst.ctrl.bps.BpsConfig` -- TODO confirm).

    Returns
    -------
    cqg : `Any`
        Result of ``dimension_clustering`` applied to the graph returned
        by ``make_test_quantum_graph`` under the name ``"test_cqg"``
        (presumably a `lsst.ctrl.bps.ClusteredQuantumGraph` -- TODO
        confirm).
    """
    return dimension_clustering(config, make_test_quantum_graph(), "test_cqg")

169 return cqg