Coverage for tests / test_pre_transform.py: 30%

119 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-18 08:47 +0000

1# This file is part of ctrl_bps. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <https://www.gnu.org/licenses/>. 

import errno
import logging
import os
import shutil
import sys
import tempfile
import unittest
import unittest.mock
from pathlib import Path

from lsst.ctrl.bps import BpsConfig, BpsSubprocessError, ClusteredQuantumGraph
from lsst.ctrl.bps.pre_transform import cluster_quanta, create_quantum_graph, execute, update_quantum_graph
from lsst.pipe.base.tests.mocks import InMemoryRepo

39 

# Absolute path of the directory holding this test module; the test cases
# create their scratch directories underneath it.
TESTDIR = os.path.dirname(os.path.abspath(__file__))
# Module-level logger named after this module.
_LOG = logging.getLogger(__name__)

42 

43 

class TestExecute(unittest.TestCase):
    """Exercise ``execute`` with a succeeding and a failing command."""

    def setUp(self):
        # Logger whose records the success test captures.
        self.logger = logging.getLogger("lsst.ctrl.bps")
        # Scratch file handed to ``execute`` by name; closed in tearDown.
        self.file = tempfile.NamedTemporaryFile("w+")

    def tearDown(self):
        self.file.close()

    def testSuccessfulExecution(self):
        """Check the log, the output file, and the status on success."""
        content = "Successful execution"
        command = f"{sys.executable} -c 'print(\"{content}\")'"

        with self.assertLogs(logger=self.logger, level="INFO") as captured:
            exit_status = execute(command, self.file.name)

        first_record = captured.output[0]
        self.assertIn(content, first_record)

        # Rewind before reading what was written to the file by name.
        self.file.seek(0)
        written = self.file.read()
        for expected in (command, content):
            self.assertIn(expected, written)

        self.assertEqual(exit_status, 0)

    def testFailingExecution(self):
        """Check the output file and the status when the command fails."""
        exit_status = execute("false", self.file.name)
        written = self.file.read()
        self.assertIn("false", written)
        self.assertNotEqual(exit_status, 0)

72 

73 

class TestCreatingQuantumGraph(unittest.TestCase):
    """Exercise ``create_quantum_graph`` for success and error paths."""

    def setUp(self):
        self.logger = logging.getLogger("lsst.ctrl.bps")
        # Per-test scratch directory, removed again in tearDown.
        self.tmpdir = tempfile.mkdtemp(dir=TESTDIR)
        # Minimal configuration; individual tests tweak it before use.
        self.settings = {
            "createQuantumGraph": "touch {qgraphFile}",
            "submitPath": self.tmpdir,
            "whenSaveJobQgraph": "NEVER",
            "uniqProcName": "my_test",
            "qgraphFileTemplate": "{uniqProcName}.qg",
        }

    def tearDown(self):
        shutil.rmtree(self.tmpdir, ignore_errors=True)

    def testSuccess(self):
        """A working command is logged and produces the graph file."""
        config = BpsConfig(self.settings, search_order=[])

        with self.assertLogs(logger=self.logger, level="INFO") as captured:
            qgraph_filename = create_quantum_graph(config, self.tmpdir)

        # Expand the command template with the returned file name so it can
        # be compared against the logged text.
        search_opts = {"curvals": {"qgraphFile": qgraph_filename}}
        _, command = config.search("createQuantumGraph", opt=search_opts)

        self.assertIn(command, captured.output[0])
        self.assertTrue(os.path.exists(qgraph_filename))

    def testCommandMissing(self):
        """A configuration without the command raises ``KeyError``."""
        del self.settings["createQuantumGraph"]
        config = BpsConfig(self.settings, search_order=[])

        with self.assertRaisesRegex(KeyError, "command.*not found"):
            create_quantum_graph(config, self.tmpdir)

    def testFailure(self):
        """A failing command raises ``BpsSubprocessError``."""
        self.settings["createQuantumGraph"] = "bash -c 'exit 2'"
        config = BpsConfig(self.settings, search_order=[])

        with self.assertRaises(BpsSubprocessError) as caught:
            create_quantum_graph(config, self.tmpdir)

        error = caught.exception
        # NOTE(review): presumably errno carries the command's exit status
        # (2, the usual value of ENOENT) - confirm against
        # BpsSubprocessError.
        self.assertEqual(error.errno, errno.ENOENT)
        self.assertIn("non-zero exit code", str(error))

115 

116 

class TestUpdatingQuantumGraph(unittest.TestCase):
    """Test quantum graph update.

    The source graph file starts out containing ``foo`` and the configured
    update command overwrites it with ``bar``, so the tests can tell an
    updated file from an untouched one and from its backup copy.
    """

    def setUp(self):
        self.tmpdir = tempfile.mkdtemp(dir=TESTDIR)
        self.settings = {
            # Must write content that differs from the initial content of
            # the source file, otherwise an update cannot be detected.
            "updateQuantumGraph": "bash -c 'echo bar > {qgraphFile}'",
            "submitPath": self.tmpdir,
            "whenSaveJobQgraph": "NEVER",
            "uniqProcName": "my_test",
            "qgraphFileTemplate": "{uniqProcName}.qg",
            "inputQgraphFile": f"{self.tmpdir}/src.qg",
        }
        self.logger = logging.getLogger("lsst.ctrl.bps")

        # Create a file in the temporary directory that will serve as
        # the file with a quantum graph that needs updating.
        self.src = Path(self.settings["inputQgraphFile"])
        self.src.write_text("foo\n")

        # Path where the tests expect the pre-update copy of the source
        # file: same directory, with "_orig" inserted before the suffix.
        self.backup = self.src.with_name(f"{self.src.stem}_orig{self.src.suffix}")

    def tearDown(self):
        shutil.rmtree(self.tmpdir, ignore_errors=True)

    def testSuccess(self):
        """Test if the quantum graph was updated."""
        config = BpsConfig(self.settings, search_order=[])
        with self.assertLogs(logger=self.logger, level="INFO") as cm:
            update_quantum_graph(config, str(self.src), self.tmpdir)
        _, command = config.search("updateQuantumGraph", opt={"curvals": {"qgraphFile": str(self.src)}})
        self.assertIn("backing up", cm.output[0].lower())
        self.assertIn("completed", cm.output[1].lower())
        self.assertIn(command, cm.output[2])
        # assertEqual, not assertTrue: the latter would treat the expected
        # text as the failure message and pass for any non-empty file.
        self.assertEqual(self.src.read_text(), "bar\n")
        self.assertTrue(self.backup.is_file())
        self.assertEqual(self.backup.read_text(), "foo\n")

    def testSuccessInPlace(self):
        """Test if a quantum graph was updated inplace."""
        config = BpsConfig(self.settings, search_order=[])
        with self.assertLogs(logger=self.logger, level="INFO") as cm:
            update_quantum_graph(config, str(self.src), self.tmpdir, inplace=True)
        _, command = config.search("updateQuantumGraph", opt={"curvals": {"qgraphFile": str(self.src)}})
        self.assertIn(command, cm.output[0])
        self.assertEqual(self.src.read_text(), "bar\n")
        # An in-place update must not leave a backup copy behind.
        self.assertFalse(self.backup.is_file())

    def testCommandMissing(self):
        """Test if error is caught when the command is missing."""
        del self.settings["updateQuantumGraph"]
        config = BpsConfig(self.settings, search_order=[])
        with self.assertRaisesRegex(KeyError, "command.*not found"):
            update_quantum_graph(config, str(self.src), self.tmpdir)

    def testFailure(self):
        """Test if error is caught when the command fails."""
        self.settings["updateQuantumGraph"] = "bash -c 'exit 2'"
        config = BpsConfig(self.settings, search_order=[])
        with self.assertRaises(BpsSubprocessError) as cm:
            update_quantum_graph(config, str(self.src), self.tmpdir)
        # NOTE(review): presumably errno carries the command's exit status
        # (2, the usual value of ENOENT) - confirm against
        # BpsSubprocessError.
        self.assertEqual(cm.exception.errno, errno.ENOENT)
        self.assertRegex(str(cm.exception), "non-zero exit code")

180 

181 

class TestClusterQuanta(unittest.TestCase):
    """Test cluster_quanta method.

    Other tests cover the functions cluster_quanta calls, so
    ``ClusteredQuantumGraph.validate`` is mocked here and the tests only
    check whether it gets invoked for a given configuration.
    """

    def _make_config(self, validate):
        """Build a config whose ``validateClusteredQgraph`` is ``validate``."""
        settings = {
            "clusterAlgorithm": "lsst.ctrl.bps.quantum_clustering_funcs.single_quantum_clustering",
            "uniqProcName": "my_test",
            "validateClusteredQgraph": validate,
        }
        return BpsConfig(settings, search_order=[])

    @unittest.mock.patch.object(ClusteredQuantumGraph, "validate")
    def testValidate(self, mock_validate):
        """Test that actually calls validate per config."""
        # The error can only surface if cluster_quanta invokes validate.
        mock_validate.side_effect = RuntimeError("Fake error")
        config = self._make_config(True)
        with InMemoryRepo() as repo:
            qgraph = repo.make_quantum_graph()
            with self.assertRaisesRegex(RuntimeError, "Fake error"):
                _ = cluster_quanta(config, qgraph, "a_name")

    @unittest.mock.patch.object(ClusteredQuantumGraph, "validate")
    def testNoValidate(self, mock_validate):
        """Test that doesn't call validate per config."""
        mock_validate.side_effect = RuntimeError("Fake error")
        config = self._make_config(False)
        with InMemoryRepo() as repo:
            qgraph = repo.make_quantum_graph()
            _ = cluster_quanta(config, qgraph, "a_name")
        # State the expectation explicitly rather than relying only on the
        # mock's side effect to fail the test.
        mock_validate.assert_not_called()

215 

216 

# Run the test cases when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()