Coverage for tests / test_drivers.py: 39%

153 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-30 08:53 +0000

1# This file is part of ctrl_bps. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <https://www.gnu.org/licenses/>. 

27"""Unit tests for drivers.py.""" 

28 

import logging
import os
import shutil
import tempfile
import unittest
import unittest.mock

import yaml

from lsst.ctrl.bps import WmsRunReport, WmsStates
from lsst.ctrl.bps.bps_reports import compile_code_summary, compile_job_summary
from lsst.ctrl.bps.drivers import _init_submission_driver, ping_driver, report_driver, status_driver

40 

# Absolute path of the directory containing this test module; scratch
# directories for the tests are created beneath it.
TESTDIR = os.path.dirname(os.path.abspath(__file__))

42 

43 

class TestInitSubmissionDriver(unittest.TestCase):
    """Test that ``_init_submission_driver`` rejects incomplete configs."""

    def setUp(self):
        self.cwd = os.getcwd()
        self.tmpdir = tempfile.mkdtemp(dir=TESTDIR)

    def tearDown(self):
        # Return to the original working directory (in case the driver
        # changed it) before removing the scratch directory.
        os.chdir(self.cwd)
        shutil.rmtree(self.tmpdir, ignore_errors=True)

    def _assert_init_raises_key_error(self, config, pattern):
        """Write ``config`` to a YAML file and check the driver rejects it.

        The file is fully written and closed before the driver opens it by
        name, so the driver never sees a partially flushed buffer.

        Parameters
        ----------
        config : `dict`
            Submission configuration to serialize.
        pattern : `str`
            Regular expression expected to match the ``KeyError`` message.
        """
        config_path = os.path.join(self.tmpdir, "config.yaml")
        with open(config_path, "w", encoding="utf-8") as stream:
            yaml.dump(config, stream=stream)
        with self.assertRaisesRegex(KeyError, pattern):
            _init_submission_driver(config_path)

    @unittest.mock.patch("lsst.ctrl.bps.initialize.BPS_DEFAULTS", {})
    def testDeprecatedOutCollection(self):
        config = {
            "submitPath": "bad",
            "payload": {
                "outCollection": "bad",
                "outputRun": "bad",
            },
        }
        self._assert_init_raises_key_error(config, "outCollection")

    @unittest.mock.patch("lsst.ctrl.bps.initialize.BPS_DEFAULTS", {})
    def testMissingOutputRun(self):
        config = {"submitPath": "bad"}
        self._assert_init_raises_key_error(config, "outputRun")

    @unittest.mock.patch("lsst.ctrl.bps.initialize.BPS_DEFAULTS", {})
    def testMissingSubmitPath(self):
        config = {"payload": {"outputRun": "bad"}}
        self._assert_init_raises_key_error(config, "submitPath")

83 

84 

class TestPingDriver(unittest.TestCase):
    """Test ping."""

    def testWmsServiceSuccess(self):
        retval = ping_driver("wms_test_utils.WmsServiceSuccess")
        self.assertEqual(retval, 0)

    def testWmsServiceFailure(self):
        with self.assertLogs(level=logging.ERROR) as cm:
            retval = ping_driver("wms_test_utils.WmsServiceFailure")
        self.assertNotEqual(retval, 0)
        self.assertEqual(cm.records[0].getMessage(), "Couldn't contact service X")

    def testWmsServiceEnvVar(self):
        with unittest.mock.patch.dict(
            os.environ, {"BPS_WMS_SERVICE_CLASS": "wms_test_utils.WmsServiceSuccess"}
        ):
            retval = ping_driver()
        self.assertEqual(retval, 0)

    @unittest.mock.patch(
        "lsst.ctrl.bps.drivers.BPS_DEFAULTS", {"wmsServiceClass": "wms_test_utils.WmsServiceDefault"}
    )
    def testWmsServiceNone(self):
        # patch.dict restores os.environ on exit. Remove any inherited
        # BPS_WMS_SERVICE_CLASS so the service really comes from BPS_DEFAULTS
        # regardless of the developer's shell environment.
        with unittest.mock.patch.dict(os.environ):
            os.environ.pop("BPS_WMS_SERVICE_CLASS", None)
            with self.assertLogs(level=logging.INFO) as cm:
                retval = ping_driver()
        self.assertEqual(retval, 0)
        self.assertEqual(cm.records[0].getMessage(), "DEFAULT None")

    def testWmsServicePassThru(self):
        with self.assertLogs(level=logging.INFO) as cm:
            retval = ping_driver("wms_test_utils.WmsServicePassThru", "EXTRA_VALUES")
        self.assertEqual(retval, 0)
        self.assertRegex(cm.output[0], "INFO.+EXTRA_VALUES")

120 

121 

class TestStatusDriver(unittest.TestCase):
    """Test status_driver function."""

    def testWmsServiceSuccess(self):
        with self.assertLogs(level=logging.INFO) as cm:
            retval = status_driver("wms_test_utils.WmsServiceSuccess", run_id="/dummy/path", hist_days=3)
        self.assertEqual(retval, WmsStates.SUCCEEDED.value)
        self.assertEqual(cm.records[0].getMessage(), "status: SUCCEEDED")

    def testWmsServiceFailure(self):
        with self.assertLogs(level=logging.WARNING) as cm:
            retval = status_driver("wms_test_utils.WmsServiceFailure", run_id="/dummy/path", hist_days=3)
        self.assertEqual(retval, WmsStates.FAILED.value)
        self.assertEqual(cm.records[0].getMessage(), "Dummy error message.")

    @unittest.mock.patch(
        "lsst.ctrl.bps.drivers.BPS_DEFAULTS", {"wmsServiceClass": "wms_test_utils.WmsServiceDefault"}
    )
    def testWmsServiceNone(self):
        # patch.dict restores os.environ on exit. Remove any inherited
        # BPS_WMS_SERVICE_CLASS so the service really comes from BPS_DEFAULTS
        # regardless of the developer's shell environment.
        with unittest.mock.patch.dict(os.environ):
            os.environ.pop("BPS_WMS_SERVICE_CLASS", None)
            retval = status_driver(None, run_id="/dummy/path", hist_days=3)
        self.assertEqual(retval, WmsStates.RUNNING.value)

144 

145 

class TestReportDriver(unittest.TestCase):
    """Test report_driver function."""

    @staticmethod
    def _last_retrieve_kwargs(mock_retrieve):
        """Return the keyword arguments of the latest ``retrieve_report`` call.

        Checking that the mock was called first turns "never called" into a
        clear assertion failure instead of a ``TypeError`` from unpacking
        ``call_args``, which is ``None`` for an uncalled mock.
        """
        mock_retrieve.assert_called()
        return mock_retrieve.call_args.kwargs

    @unittest.mock.patch(
        "lsst.ctrl.bps.drivers.BPS_DEFAULTS", new={"wmsServiceClass": "wms_test_utils.WmsServiceSuccess"}
    )
    def testWmsServiceFromDefaults(self):
        # Should not raise an exception and use default from BPS_DEFAULTS.
        with unittest.mock.patch.dict(os.environ, {}, clear=True):
            report_driver(
                wms_service=None,
                run_id=None,
                user=None,
                hist_days=0,
                pass_thru=None,
            )

    def testWmsServiceFromEnvVar(self):
        # Should not raise an exception.
        with unittest.mock.patch.dict(
            os.environ, {"BPS_WMS_SERVICE_CLASS": "wms_test_utils.WmsServiceSuccess"}
        ):
            report_driver(
                wms_service=None,
                run_id=None,
                user=None,
                hist_days=0.0,
                pass_thru=None,
            )

    @unittest.mock.patch("lsst.ctrl.bps.drivers.retrieve_report")
    @unittest.mock.patch("lsst.ctrl.bps.drivers.display_report")
    def testHistDefault(self, mock_display, mock_retrieve):
        mock_retrieve.return_value = ([], [])

        report_driver(
            wms_service="wms_test_utils.WmsServiceSuccess",
            run_id="123",
            user=None,
            hist_days=0.0,
            pass_thru=None,
        )

        # A zero hist_days should be replaced by the driver's default (2 days).
        kwargs = self._last_retrieve_kwargs(mock_retrieve)
        self.assertAlmostEqual(kwargs["hist"], 2.0)

    @unittest.mock.patch("lsst.ctrl.bps.drivers.retrieve_report")
    @unittest.mock.patch("lsst.ctrl.bps.drivers.display_report")
    def testHistCustom(self, mock_display, mock_retrieve):
        mock_retrieve.return_value = ([], [])

        report_driver(
            wms_service="wms_test_utils.WmsServiceSuccess",
            run_id="123",
            user=None,
            hist_days=4.0,
            pass_thru=None,
        )

        # An explicit hist_days should be forwarded unchanged.
        kwargs = self._last_retrieve_kwargs(mock_retrieve)
        self.assertAlmostEqual(kwargs["hist"], 4.0)

    @unittest.mock.patch("lsst.ctrl.bps.drivers.retrieve_report")
    @unittest.mock.patch("lsst.ctrl.bps.drivers.display_report")
    def testPostprocessorsWithoutExitCodes(self, mock_display, mock_retrieve):
        mock_retrieve.return_value = ([], [])

        report_driver(
            wms_service="wms_test_utils.WmsServiceSuccess",
            run_id="123",
            user=None,
            hist_days=0.0,
            pass_thru=None,
            return_exit_codes=False,
        )

        # Only the job summary postprocessor should be requested.
        kwargs = self._last_retrieve_kwargs(mock_retrieve)
        self.assertEqual(len(kwargs["postprocessors"]), 1)
        self.assertIn(compile_job_summary, kwargs["postprocessors"])

    @unittest.mock.patch("lsst.ctrl.bps.drivers.retrieve_report")
    @unittest.mock.patch("lsst.ctrl.bps.drivers.display_report")
    def testPostprocessorsWithExitCodes(self, mock_display, mock_retrieve):
        mock_retrieve.return_value = ([], [])

        report_driver(
            wms_service="wms_test_utils.WmsServiceSuccess",
            run_id="123",
            user=None,
            hist_days=0.0,
            pass_thru=None,
            return_exit_codes=True,
        )

        # Both the exit-code and job summary postprocessors should be requested.
        kwargs = self._last_retrieve_kwargs(mock_retrieve)
        self.assertEqual(len(kwargs["postprocessors"]), 2)
        self.assertIn(compile_code_summary, kwargs["postprocessors"])
        self.assertIn(compile_job_summary, kwargs["postprocessors"])

    @unittest.mock.patch("lsst.ctrl.bps.drivers.retrieve_report")
    @unittest.mock.patch("lsst.ctrl.bps.drivers.display_report")
    def testPostprocessorsNoRunId(self, mock_display, mock_retrieve):
        mock_retrieve.return_value = ([], [])

        report_driver(
            wms_service="wms_test_utils.WmsServiceSuccess",
            run_id=None,
            user=None,
            hist_days=0.0,
            pass_thru=None,
        )

        # Without a run id (summary report), no postprocessors are requested.
        kwargs = self._last_retrieve_kwargs(mock_retrieve)
        self.assertIsNone(kwargs["postprocessors"])

    @unittest.mock.patch("lsst.ctrl.bps.drivers.retrieve_report")
    @unittest.mock.patch("lsst.ctrl.bps.drivers.display_report")
    def testDisplayCalledIfRuns(self, mock_display, mock_retrieve):
        mock_runs = [WmsRunReport(wms_id="1", state=WmsStates.SUCCEEDED)]
        mock_retrieve.return_value = (mock_runs, [])

        report_driver(
            wms_service="wms_test_utils.WmsServiceSuccess",
            run_id=None,
            user=None,
            hist_days=0,
            pass_thru=None,
        )

        # Verify display_report was called with the runs.
        mock_display.assert_called_once()
        args = mock_display.call_args.args
        self.assertEqual(args[0], mock_runs)

    @unittest.mock.patch("lsst.ctrl.bps.drivers.retrieve_report")
    @unittest.mock.patch("lsst.ctrl.bps.drivers.display_report")
    def testDisplayCalledIfMessages(self, mock_display, mock_retrieve):
        mock_messages = ["Warning message 1", "Warning message 2"]
        mock_retrieve.return_value = ([], mock_messages)

        report_driver(
            wms_service="wms_test_utils.WmsServiceSuccess",
            run_id=None,
            user=None,
            hist_days=0,
            pass_thru=None,
        )

        # Verify display_report was called with the messages.
        mock_display.assert_called_once()
        args = mock_display.call_args.args
        self.assertEqual(args[1], mock_messages)

    @unittest.mock.patch("lsst.ctrl.bps.drivers.retrieve_report")
    @unittest.mock.patch("lsst.ctrl.bps.drivers.display_report")
    @unittest.mock.patch("builtins.print")
    def testNoRecordsFoundMessage(self, mock_print, mock_display, mock_retrieve):
        mock_retrieve.return_value = ([], [])

        report_driver(
            wms_service="wms_test_utils.WmsServiceSuccess",
            run_id="123",
            user=None,
            hist_days=1.5,
            pass_thru=None,
        )

        # Verify display_report() was NOT called.
        mock_display.assert_not_called()

        # Verify that a helpful message naming the run id was printed.
        mock_print.assert_called_once()
        printed = mock_print.call_args.args[0]
        self.assertIn("No records found", printed)
        self.assertIn("123", printed)

326 

327 

if __name__ == "__main__":
    # Allow running this file directly (``python test_drivers.py``); the
    # explicit module argument matches unittest.main's default.
    unittest.main(module="__main__")