Coverage for tests / test_pre_transform.py: 30%
119 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-30 08:53 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-30 08:53 +0000
1# This file is part of ctrl_bps.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <https://www.gnu.org/licenses/>.
import errno
import logging
import os
import shlex
import shutil
import sys
import tempfile
import unittest
import unittest.mock
from pathlib import Path

from lsst.ctrl.bps import BpsConfig, BpsSubprocessError, ClusteredQuantumGraph
from lsst.ctrl.bps.pre_transform import cluster_quanta, create_quantum_graph, execute, update_quantum_graph
from lsst.pipe.base.tests.mocks import InMemoryRepo
# Absolute path of the directory containing this test module; the tests
# create their temporary working directories inside it.
TESTDIR = os.path.dirname(os.path.abspath(__file__))
_LOG = logging.getLogger(__name__)
class TestExecute(unittest.TestCase):
    """Test execution of commands via ``execute``."""

    def setUp(self):
        # ``execute`` is given this file's name; the open handle is used to
        # read back what it wrote (the command and its output).
        self.file = tempfile.NamedTemporaryFile("w+")
        self.logger = logging.getLogger("lsst.ctrl.bps")

    def tearDown(self):
        self.file.close()

    def testSuccessfulExecution(self):
        """Test exit status if command succeeded."""
        content = "Successful execution"
        # Quote the interpreter path so the command stays valid even when
        # the path contains spaces or other shell-special characters.
        command = f"{shlex.quote(sys.executable)} -c 'print(\"{content}\")'"
        with self.assertLogs(logger=self.logger, level="INFO") as cm:
            status = execute(command, self.file.name)
        self.assertIn(content, cm.output[0])
        self.file.seek(0)
        file_contents = self.file.read()
        self.assertIn(command, file_contents)
        self.assertIn(content, file_contents)
        self.assertEqual(status, 0)

    def testFailingExecution(self):
        """Test exit status if command failed."""
        status = execute("false", self.file.name)
        # Rewind explicitly, as testSuccessfulExecution does, instead of
        # relying on the handle's position never having moved.
        self.file.seek(0)
        self.assertIn("false", self.file.read())
        self.assertNotEqual(status, 0)
class TestCreatingQuantumGraph(unittest.TestCase):
    """Test quantum graph creation."""

    def setUp(self):
        # Each test gets its own scratch directory under the test directory.
        self.tmpdir = tempfile.mkdtemp(dir=TESTDIR)
        self.settings = {
            "createQuantumGraph": "touch {qgraphFile}",
            "submitPath": self.tmpdir,
            "whenSaveJobQgraph": "NEVER",
            "uniqProcName": "my_test",
            "qgraphFileTemplate": "{uniqProcName}.qg",
        }
        self.logger = logging.getLogger("lsst.ctrl.bps")

    def tearDown(self):
        shutil.rmtree(self.tmpdir, ignore_errors=True)

    def testSuccess(self):
        """Test if a new quantum graph was created successfully."""
        config = BpsConfig(self.settings, search_order=[])
        with self.assertLogs(logger=self.logger, level="INFO") as cm:
            qgraph_filename = create_quantum_graph(config, self.tmpdir)
        _, command = config.search("createQuantumGraph", opt={"curvals": {"qgraphFile": qgraph_filename}})
        self.assertIn(command, cm.output[0])
        self.assertTrue(os.path.exists(qgraph_filename))

    def testCommandMissing(self):
        """Test if error is caught when the command is missing."""
        del self.settings["createQuantumGraph"]
        config = BpsConfig(self.settings, search_order=[])
        with self.assertRaisesRegex(KeyError, "command.*not found"):
            create_quantum_graph(config, self.tmpdir)

    def testFailure(self):
        """Test if error is caught when the quantum graph creation fails."""
        exit_code = 2
        self.settings["createQuantumGraph"] = f"bash -c 'exit {exit_code}'"
        config = BpsConfig(self.settings, search_order=[])
        with self.assertRaises(BpsSubprocessError) as cm:
            create_quantum_graph(config, self.tmpdir)
        # Compare against the command's exit code directly. The previous
        # check against errno.ENOENT only passed because ENOENT == 2 on
        # common platforms, not because a "file not found" error occurred.
        self.assertEqual(cm.exception.errno, exit_code)
        self.assertIn("non-zero exit code", str(cm.exception))
class TestUpdatingQuantumGraph(unittest.TestCase):
    """Test quantum graph update."""

    def setUp(self):
        self.tmpdir = tempfile.mkdtemp(dir=TESTDIR)
        self.settings = {
            # Write content that differs from the original file ("foo") so
            # the tests can tell whether the update command actually ran.
            "updateQuantumGraph": "bash -c 'echo bar > {qgraphFile}'",
            "submitPath": self.tmpdir,
            "whenSaveJobQgraph": "NEVER",
            "uniqProcName": "my_test",
            "qgraphFileTemplate": "{uniqProcName}.qg",
            "inputQgraphFile": f"{self.tmpdir}/src.qg",
        }
        self.logger = logging.getLogger("lsst.ctrl.bps")

        # Create a file in the temporary directory that will serve as
        # the file with a quantum graph that needs updating.
        self.src = Path(self.settings["inputQgraphFile"])
        self.src.write_text("foo\n")

        # Expected location of the backup made when updating not in place.
        self.backup = self.src.with_name(f"{self.src.stem}_orig{self.src.suffix}")

    def tearDown(self):
        shutil.rmtree(self.tmpdir, ignore_errors=True)

    def testSuccess(self):
        """Test if the quantum graph was updated."""
        config = BpsConfig(self.settings, search_order=[])
        with self.assertLogs(logger=self.logger, level="INFO") as cm:
            update_quantum_graph(config, str(self.src), self.tmpdir)
        _, command = config.search("updateQuantumGraph", opt={"curvals": {"qgraphFile": str(self.src)}})
        self.assertIn("backing up", cm.output[0].lower())
        self.assertIn("completed", cm.output[1].lower())
        self.assertIn(command, cm.output[2])
        # assertTrue(x, y) treats y as a failure message and passes for any
        # non-empty x, so compare the file contents explicitly.
        self.assertEqual(self.src.read_text(), "bar\n")
        self.assertTrue(self.backup.is_file())
        self.assertEqual(self.backup.read_text(), "foo\n")

    def testSuccessInPlace(self):
        """Test if a quantum graph was updated inplace."""
        config = BpsConfig(self.settings, search_order=[])
        with self.assertLogs(logger=self.logger, level="INFO") as cm:
            update_quantum_graph(config, str(self.src), self.tmpdir, inplace=True)
        _, command = config.search("updateQuantumGraph", opt={"curvals": {"qgraphFile": str(self.src)}})
        self.assertIn(command, cm.output[0])
        self.assertEqual(self.src.read_text(), "bar\n")
        self.assertFalse(self.backup.is_file())

    def testCommandMissing(self):
        """Test if error is caught when the command is missing."""
        del self.settings["updateQuantumGraph"]
        config = BpsConfig(self.settings, search_order=[])
        with self.assertRaisesRegex(KeyError, "command.*not found"):
            update_quantum_graph(config, str(self.src), self.tmpdir)

    def testFailure(self):
        """Test if error is caught when the command fails."""
        exit_code = 2
        self.settings["updateQuantumGraph"] = f"bash -c 'exit {exit_code}'"
        config = BpsConfig(self.settings, search_order=[])
        with self.assertRaises(BpsSubprocessError) as cm:
            update_quantum_graph(config, str(self.src), self.tmpdir)
        # Compare against the command's exit code rather than errno.ENOENT,
        # which only matched because ENOENT == 2 on common platforms.
        self.assertEqual(cm.exception.errno, exit_code)
        self.assertIn("non-zero exit code", str(cm.exception))
class TestClusterQuanta(unittest.TestCase):
    """Test cluster_quanta method. Other tests cover functions
    cluster_quanta calls so mocking them here.
    """

    @staticmethod
    def _make_config(validate):
        """Return a config using single-quantum clustering.

        Parameters
        ----------
        validate : `bool`
            Value of the ``validateClusteredQgraph`` setting.

        Returns
        -------
        config : `lsst.ctrl.bps.BpsConfig`
            Config used by the tests in this class.
        """
        settings = {
            "clusterAlgorithm": "lsst.ctrl.bps.quantum_clustering_funcs.single_quantum_clustering",
            "uniqProcName": "my_test",
            "validateClusteredQgraph": validate,
        }
        return BpsConfig(settings, search_order=[])

    @unittest.mock.patch.object(ClusteredQuantumGraph, "validate")
    def testValidate(self, mock_validate):
        """Test that actually calls validate per config."""
        mock_validate.side_effect = RuntimeError("Fake error")
        config = self._make_config(True)
        with InMemoryRepo() as repo:
            qgraph = repo.make_quantum_graph()
            with self.assertRaisesRegex(RuntimeError, "Fake error"):
                _ = cluster_quanta(config, qgraph, "a_name")
        # The first call raises, so validate must have run exactly once.
        mock_validate.assert_called_once()

    @unittest.mock.patch.object(ClusteredQuantumGraph, "validate")
    def testNoValidate(self, mock_validate):
        """Test that doesn't call validate per config."""
        mock_validate.side_effect = RuntimeError("Fake error")
        config = self._make_config(False)
        with InMemoryRepo() as repo:
            qgraph = repo.make_quantum_graph()
            _ = cluster_quanta(config, qgraph, "a_name")
        # Not raising only shows that no exception escaped; check the mock
        # directly so the test also fails if validate is called and its
        # error is swallowed.
        mock_validate.assert_not_called()
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()