Coverage for tests/cqg_test_utils.py: 10%
53 statements
« prev ^ index » next coverage.py v6.4, created at 2022-05-24 11:00 +0000
« prev ^ index » next coverage.py v6.4, created at 2022-05-24 11:00 +0000
1# This file is part of ctrl_bps.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21"""ClusteredQuantumGraph-related utilities to support ctrl_bps testing.
22"""
24import uuid
26from networkx import is_directed_acyclic_graph
def check_cqg(cqg, truth=None):
    """Check a ClusteredQuantumGraph for correctness; used by unit tests.

    Parameters
    ----------
    cqg : `lsst.ctrl.bps.ClusteredQuantumGraph`
        ClusteredQuantumGraph to be checked for correctness.
    truth : `dict` [`str`, `Any`], optional
        Information describing what the ClusteredQuantumGraph should look
        like, in the same format as returned by `dump_cqg` (keys ``name``,
        ``nodes``, ``edges``).  If not given (or empty), only the
        data-independent checks are run.

    Raises
    ------
    AssertionError
        If any of the checks fail.
    """
    # Checks independent of data.

    # The cluster graph must not contain any cycles.
    assert is_directed_acyclic_graph(cqg._cluster_graph), "Cluster graph is not a directed acyclic graph."

    # Every QuantumGraph node must appear in exactly one cluster
    # (include a message about the duplicate node if one is found).
    node_ids = set()
    for cluster in cqg.clusters():
        for id_ in cluster.qgraph_node_ids:
            # Looked up for every id (not only on failure); presumably this
            # also fails for ids unknown to the QuantumGraph -- TODO confirm.
            qnode = cqg.get_quantum_node(id_)
            assert id_ not in node_ids, (
                f"Checking cluster {cluster.name}, id {id_} ({qnode.quantum.dataId.byName()}) appears more "
                "than once in CQG."
            )
            node_ids.add(id_)
    assert len(node_ids) == len(cqg._quantum_graph), (
        f"Mismatch number of quantum nodes: clusters contain {len(node_ids)} unique ids, "
        f"QuantumGraph has {len(cqg._quantum_graph)} nodes."
    )

    # If given what should be there, check values.
    if truth:
        cqg_info = dump_cqg(cqg)
        compare_cqg_dicts(truth, cqg_info)
def replace_node_name(name, label, dims):
    """Replace the node id in a cluster name, because node ids change every
    run and thus make testing difficult.

    A name is treated as containing a node id when the text before its first
    ``_`` parses as a `uuid.UUID`.

    Parameters
    ----------
    name : `str`
        Cluster name.
    label : `str`
        Cluster label.
    dims : `dict` [`str`, `Any`]
        Dimension names and values, used to make the new name unique when
        the cluster name consists of the node id only.

    Returns
    -------
    name : `str`
        New name of cluster, or the original name unchanged if it does not
        start with a node id.
    """
    name_parts = name.split("_")
    # Only the UUID parse may legitimately fail; keep the try body minimal so
    # that unrelated ValueErrors are not silently swallowed.
    try:
        uuid.UUID(name_parts[0])
    except ValueError:
        # Not a node-id based name; leave it unchanged.
        return name

    if len(name_parts) == 1:
        # Name is only the node id, so build one from label and dimensions.
        return f"NODEONLY_{label}_{dims}"
    return f"NODENAME_{'_'.join(name_parts[1:])}"
def dump_cqg(cqg):
    """Build a plain-dictionary view of a ClusteredQuantumGraph for tests.

    Cluster names that embed node ids are normalized via
    `replace_node_name` so the result is stable across runs.

    Parameters
    ----------
    cqg : `lsst.ctrl.bps.ClusteredQuantumGraph`
        ClusteredQuantumGraph to be represented as a dictionary.

    Returns
    -------
    info : `dict` [`str`, `Any`]
        Dictionary representation of the ClusteredQuantumGraph with keys
        ``name``, ``nodes`` (normalized cluster name -> label/dims/counts)
        and ``edges`` (list of normalized-name pairs).
    """
    graph_name = cqg.name
    skipped_tags = ["label", "node_number"]

    renamed = {}
    nodes = {}
    for cluster in cqg.clusters():
        # Every tag other than the bookkeeping ones is treated as a dimension.
        cluster_dims = {tag: val for tag, val in cluster.tags.items() if tag not in skipped_tags}
        stable_name = replace_node_name(cluster.name, cluster.label, cluster_dims)
        renamed[cluster.name] = stable_name
        nodes[stable_name] = {
            "label": cluster.label,
            "dims": cluster_dims,
            "counts": dict(cluster.quanta_counts),
        }

    edges = [(renamed[src], renamed[dst]) for src, dst in cqg._cluster_graph.edges]
    return {"name": graph_name, "nodes": nodes, "edges": edges}
def compare_cqg_dicts(truth, cqg):
    """Assert that two ClusteredQuantumGraph dictionary representations match.

    Parameters
    ----------
    truth : `dict` [`str`, `Any`]
        Representation of the expected ClusteredQuantumGraph.
    cqg : `dict` [`str`, `Any`]
        Representation of the calculated ClusteredQuantumGraph.

    Raises
    ------
    AssertionError
        Raised at the first discrepancy found between the two dicts.
    """
    expected_name, actual_name = truth["name"], cqg["name"]
    assert expected_name == actual_name, f"Mismatch name: truth={expected_name}, cqg={actual_name}"

    expected_nodes, actual_nodes = truth["nodes"], cqg["nodes"]
    n_expected, n_actual = len(expected_nodes), len(actual_nodes)
    assert n_expected == n_actual, f"Mismatch number of nodes: truth={n_expected}, cqg={n_actual}"

    # (node field, description used in the failure message), checked in order.
    field_checks = (
        ("label", "cluster label"),
        ("dims", "cluster dims"),
        ("counts", "cluster quanta counts"),
    )
    for node_name, expected in expected_nodes.items():
        assert node_name in actual_nodes, f"Could not find {node_name} in cqg"
        actual = actual_nodes[node_name]
        for field, description in field_checks:
            assert (
                expected[field] == actual[field]
            ), f"Mismatch {description}: truth={expected[field]}, cqg={actual[field]}"

    # Edge order is irrelevant, so compare as sets.
    expected_edges, actual_edges = truth["edges"], cqg["edges"]
    assert set(expected_edges) == set(
        actual_edges
    ), f"Mismatch edges: truth={expected_edges}, cqg={actual_edges}"