Coverage for tests/cqg_test_utils.py: 11%
80 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-07 17:21 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-07 17:21 +0000
1# This file is part of ctrl_bps.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <https://www.gnu.org/licenses/>.
27"""ClusteredQuantumGraph-related utilities to support ctrl_bps testing.
28"""
30import uuid
31from copy import deepcopy
33from lsst.ctrl.bps import ClusteredQuantumGraph, QuantaCluster
34from networkx import is_directed_acyclic_graph
35from qg_test_utils import make_test_quantum_graph
def check_cqg(cqg, truth=None):
    """Check a ClusteredQuantumGraph for correctness in unit tests.

    Parameters
    ----------
    cqg : `lsst.ctrl.bps.ClusteredQuantumGraph`
        ClusteredQuantumGraph to be checked for correctness.
    truth : `dict` [`str`, `Any`], optional
        Information describing what the ClusteredQuantumGraph should look
        like, in the format returned by `dump_cqg`.  When not given (or
        empty), only the data-independent checks are made.

    Raises
    ------
    AssertionError
        Raised if any check fails.
    """
    # Checks independent of data.

    # The cluster graph must not contain any cycles.
    assert is_directed_acyclic_graph(cqg._cluster_graph), "CQG cluster graph contains a cycle."

    # Every QuantumGraph node must be in exactly one cluster.
    node_ids = set()
    for cluster in cqg.clusters():
        for id_ in cluster.qgraph_node_ids:
            # Looked up for every id (not only on failure); presumably this
            # also catches ids unknown to the QuantumGraph -- TODO confirm.
            qnode = cqg.get_quantum_node(id_)
            assert id_ not in node_ids, (
                f"Checking cluster {cluster.name}, id {id_} ({qnode.quantum.dataId}) appears more "
                "than once in CQG."
            )
            node_ids.add(id_)

    num_qgraph_nodes = len(cqg._quantum_graph)
    assert len(node_ids) == num_qgraph_nodes, (
        f"CQG clusters contain {len(node_ids)} distinct quantum nodes, but the QuantumGraph has "
        f"{num_qgraph_nodes}."
    )

    # If given what should be there, check values.
    if truth:
        compare_cqg_dicts(truth, dump_cqg(cqg))
def replace_node_name(name, label, dims):
    """Make a cluster name independent of the (random) quantum node id.

    If the part of ``name`` before the first underscore parses as a UUID,
    that id is replaced by a fixed marker so names compare equal across
    runs.  Any other name is returned unchanged.

    Parameters
    ----------
    name : `str`
        Cluster name.
    label : `str`
        Cluster label.
    dims : `dict` [`str`, `Any`]
        Dimension names and values, used to keep the new name unique when
        the cluster name is only a node id.

    Returns
    -------
    name : `str`
        New name of cluster (or ``name`` itself if it has no node id prefix).
    """
    head, separator, tail = name.partition("_")
    try:
        uuid.UUID(head)
    except ValueError:
        # Not prefixed with a node id, so already stable between runs.
        return name

    if separator:
        return f"NODENAME_{tail}"
    return f"NODEONLY_{label}_{str(dims)}"
def dump_cqg(cqg):
    """Build a plain-dictionary view of a ClusteredQuantumGraph for tests.

    Cluster names are passed through `replace_node_name` so that names
    built from random quantum node ids become stable between runs.

    Parameters
    ----------
    cqg : `lsst.ctrl.bps.ClusteredQuantumGraph`
        ClusteredQuantumGraph to be represented as a dictionary.

    Returns
    -------
    info : `dict` [`str`, `Any`]
        Dictionary representation of the ClusteredQuantumGraph with keys
        ``"name"``, ``"nodes"`` (stable cluster name -> label, dims and
        quanta counts) and ``"edges"`` (list of (source, target) tuples of
        stable cluster names taken from the cluster graph edges).
    """
    graph_name = cqg.name
    excluded_tags = ["label", "node_number"]

    stable_names = {}
    nodes = {}
    for cluster in cqg.clusters():
        # Every tag except label/node_number is treated as a dimension.
        cluster_dims = {tag: val for tag, val in cluster.tags.items() if tag not in excluded_tags}
        stable = replace_node_name(cluster.name, cluster.label, cluster_dims)
        stable_names[cluster.name] = stable
        nodes[stable] = {"label": cluster.label, "dims": cluster_dims, "counts": dict(cluster.quanta_counts)}

    edges = [(stable_names[edge[0]], stable_names[edge[1]]) for edge in cqg._cluster_graph.edges]

    return {"name": graph_name, "nodes": nodes, "edges": edges}
def compare_cqg_dicts(truth, cqg):
    """Assert that two ClusteredQuantumGraph dictionary representations match.

    The name, the set of node names, each node's label, dims and quanta
    counts, and the set of edges are compared, stopping at the first
    difference.

    Parameters
    ----------
    truth : `dict` [`str`, `Any`]
        Representation of the expected ClusteredQuantumGraph.
    cqg : `dict` [`str`, `Any`]
        Representation of the calculated ClusteredQuantumGraph.

    Raises
    ------
    AssertionError
        Raised at the first discrepancy found between the dictionaries.
    """
    name_truth, name_cqg = truth["name"], cqg["name"]
    assert name_truth == name_cqg, f"Mismatch name: truth={name_truth}, cqg={name_cqg}"

    expected_nodes = truth["nodes"]
    actual_nodes = cqg["nodes"]
    assert len(expected_nodes) == len(actual_nodes), (
        f"Mismatch number of nodes: truth={len(expected_nodes)}, cqg={len(actual_nodes)}"
    )

    # (key inside each node dict, prefix of the failure message)
    node_fields = (
        ("label", "Mismatch cluster label"),
        ("dims", "Mismatch cluster dims"),
        ("counts", "Mismatch cluster quanta counts"),
    )
    for node_name, expected in expected_nodes.items():
        assert node_name in actual_nodes, f"Could not find {node_name} in cqg"
        actual = actual_nodes[node_name]
        for field, prefix in node_fields:
            assert expected[field] == actual[field], f"{prefix}: truth={expected[field]}, cqg={actual[field]}"

    # Edge order and duplicates are irrelevant; compare as sets.
    assert set(truth["edges"]) == set(cqg["edges"]), f"Mismatch edges: truth={truth['edges']}, cqg={cqg['edges']}"
# T1(1,2)   T1(3,4)     T4(1,2)   T4(3,4)
#    |         |
# T2(1,2)   T2(3,4)
#    |         |
# T3(1,2)   T3(3,4)
def make_test_clustered_quantum_graph(outdir):
    """Make a ClusteredQuantumGraph for testing.

    Clusters created (see diagram above):

    * ``T4_1_2`` and ``T4_3_4``: single-quantum clusters with no
      dependencies (orphans).
    * ``T1_1_2`` and ``T1_3_4``: single-quantum T1 clusters.
    * ``T23_1_2`` and ``T23_3_4``: clusters holding the T2 and T3 quanta
      for the same data ID, each added as a dependency of the matching
      T1 cluster.

    Parameters
    ----------
    outdir : `str`
        Root used for the QuantumGraph filename stored
        in the ClusteredQuantumGraph.

    Returns
    -------
    qgraph : `lsst.pipe.base.QuantumGraph`
        Deep copy of the fake QuantumGraph, taken before the
        ClusteredQuantumGraph was built and kept separate from the
        graph held by the returned ClusteredQuantumGraph.
    cqg : `lsst.ctrl.bps.ClusteredQuantumGraph`
        The ClusteredQuantumGraph built from the fake QuantumGraph.
    """
    qgraph = make_test_quantum_graph()
    qgraph2 = deepcopy(qgraph)  # keep separate copy

    cqg = ClusteredQuantumGraph("cqg1", qgraph, f"{outdir}/test_file.qgraph")

    # Node ids are random hashes, so look quantum nodes up by
    # "<task label>_<D1>_<D2>" instead.
    test_lookup = {}
    for qnode in qgraph:
        data_id = dict(qnode.quantum.dataId.required)
        test_lookup[f"{qnode.taskDef.label}_{data_id['D1']}_{data_id['D2']}"] = qnode

    # Add orphans.
    for orphan_key in ("T4_1_2", "T4_3_4"):
        cqg.add_cluster(QuantaCluster.from_quantum_node(test_lookup[orphan_key], orphan_key))

    # For each (Dim1, Dim2): a T1 cluster and a combined T2+T3 cluster
    # that depends on it.
    for dims_key in ("1_2", "3_4"):
        qc1 = QuantaCluster.from_quantum_node(test_lookup[f"T1_{dims_key}"], f"T1_{dims_key}")
        qc2 = QuantaCluster.from_quantum_node(test_lookup[f"T2_{dims_key}"], f"T23_{dims_key}")
        qc2.add_quantum_node(test_lookup[f"T3_{dims_key}"])
        cqg.add_cluster([qc2, qc1])  # reversed to check order is corrected in tests
        cqg.add_dependency(qc1, qc2)

    return qgraph2, cqg