Coverage for tests/test_task.py: 26%

207 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-05-03 09:57 +0000

#
# LSST Data Management System
# Copyright 2008, 2009, 2010 LSST Corporation.
#
# This product includes software developed by the
# LSST Project (http://www.lsst.org/).
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the LSST License Statement and
# the GNU General Public License along with this program. If not,
# see <http://www.lsstcorp.org/LegalNotices/>.
#

22import json 

23import logging 

24import numbers 

25import time 

26import unittest 

27 

28import lsst.pex.config as pexConfig 

29import lsst.pipe.base as pipeBase 

30import lsst.utils.tests 

31import yaml 

32 

33# Whilst in transition the test can't tell which type is 

34# going to be used for metadata. 

35from lsst.pipe.base.task import _TASK_METADATA_TYPE 

36from lsst.utils.timer import timeMethod 

37 

38 

class AddConfig(pexConfig.Config):
    """Configuration for `AddTask`."""

    # Constant that AddTask.run adds to its input value.
    addend = pexConfig.Field(
        dtype=float,
        default=3.1,
        doc="amount to add",
    )

43 

44 

class AddTask(pipeBase.Task):
    """Example task that adds a configurable constant to its input."""

    ConfigClass = AddConfig

    @timeMethod
    def run(self, val):
        """Return ``val`` plus the configured addend.

        The addend is also recorded in the task metadata under ``"add"``.

        Parameters
        ----------
        val : `float`
            Input value.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            Struct whose ``val`` attribute is ``val + config.addend``.
        """
        addend = self.config.addend
        self.metadata.add("add", addend)
        total = val + addend
        return pipeBase.Struct(val=total)

56 

57 

class MultConfig(pexConfig.Config):
    """Configuration for `MultTask`."""

    # Factor by which MultTask.run scales its input value.
    multiplicand = pexConfig.Field(
        dtype=float,
        default=2.5,
        doc="amount by which to multiply",
    )

62 

63 

class MultTask(pipeBase.Task):
    """Example task that multiplies its input by a configurable factor."""

    ConfigClass = MultConfig

    @timeMethod
    def run(self, val):
        """Return ``val`` times the configured multiplicand.

        The multiplicand is also recorded in the task metadata under
        ``"mult"``.

        Parameters
        ----------
        val : `float`
            Input value.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            Struct whose ``val`` attribute is ``val * config.multiplicand``.
        """
        factor = self.config.multiplicand
        self.metadata.add("mult", factor)
        product = val * factor
        return pipeBase.Struct(val=product)

75 

76 

# Registry fields can hold subtasks too: MultTask is registered here under
# the name "stdMult" so that AddMultConfig can select it through a registry
# field instead of a plain configurable field.
multRegistry = pexConfig.makeRegistry("Registry for Mult-like tasks")
multRegistry.register("stdMult", MultTask)

81 

82 

class AddMultConfig(pexConfig.Config):
    """Configuration for `AddMultTask`."""

    # "add" subtask, declared via AddTask.makeField.
    add = AddTask.makeField("add task")
    # "mult" subtask, chosen from multRegistry; "stdMult" (MultTask) by default.
    mult = multRegistry.makeField("mult task", default="stdMult")

88 

89 

class AddMultTask(pipeBase.Task):
    """Test Task with subtasks: first add, then multiply.

    The ``add`` subtask comes from a field made by ``AddTask.makeField`` and
    the ``mult`` subtask from a ``multRegistry`` registry field, so both ways
    of declaring subtasks are exercised.
    """

    ConfigClass = AddMultConfig
    _DefaultName = "addMult"
    # Keep the logger name free of a module prefix; the tests expect the
    # logger of a default-constructed instance to be named plain "addMult".
    _add_module_logger_prefix = False

    def __init__(self, **keyArgs):
        super().__init__(**keyArgs)
        self.makeSubtask("add")
        self.makeSubtask("mult")

    @timeMethod
    def run(self, val):
        """Feed ``val`` through the ``add`` subtask, then the ``mult`` subtask.

        Both subtask calls happen inside a timer named ``"context"``, and the
        final value is recorded in the metadata under ``"addmult"``.

        Parameters
        ----------
        val : `float`
            Input value.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            Struct whose ``val`` attribute is the output of the ``mult``
            subtask applied to the output of the ``add`` subtask.
        """
        with self.timer("context"):
            addRet = self.add.run(val)
            multRet = self.mult.run(addRet.val)
            self.metadata.add("addmult", multRet.val)
            return pipeBase.Struct(
                val=multRet.val,
            )

    @timeMethod
    def failDec(self):
        """Fail with a decorator.

        Raises
        ------
        RuntimeError
            Always; used to check that ``timeMethod`` still records timing
            metadata when the decorated method raises.
        """
        raise RuntimeError("failDec intentional error")

    def failCtx(self):
        """Fail inside a context manager.

        Raises
        ------
        RuntimeError
            Always; used to check that the ``timer`` context manager still
            records timing metadata when its body raises.
        """
        with self.timer("failCtx"):
            raise RuntimeError("failCtx intentional error")

123 

124 

class AddMultTask2(AddMultTask):
    """Variant of `AddMultTask` whose logger name is automatically prefixed
    with the defining module's name.
    """

    # Turn the module prefix back on (AddMultTask disables it).
    _add_module_logger_prefix = True

129 

130 

class AddTwiceTask(AddTask):
    """Variant of AddTask that adds twice the addend.

    This override of ``run`` is not wrapped by ``timeMethod`` and does not
    write anything to the task metadata; it only returns the new value.
    """

    def run(self, val):
        doubled = 2 * self.config.addend
        return pipeBase.Struct(val=val + doubled)

137 

138 

class TaskTestCase(unittest.TestCase):
    """A test case for Task."""

    def testBasics(self):
        """Test basic construction and use of a task."""
        for addend in (1.1, -3.5):
            for multiplicand in (0.9, -45.0):
                config = AddMultTask.ConfigClass()
                config.add.addend = addend
                config.mult["stdMult"].multiplicand = multiplicand
                # make sure both ways of accessing the registry work and give
                # the same result
                self.assertEqual(config.mult.active.multiplicand, multiplicand)
                addMultTask = AddMultTask(config=config)
                for val in (-1.0, 0.0, 17.5):
                    ret = addMultTask.run(val=val)
                    self.assertAlmostEqual(ret.val, (val + addend) * multiplicand)

    def testNames(self):
        """Test getName() and getFullName()."""
        addMultTask = AddMultTask()
        self.assertEqual(addMultTask.getName(), "addMult")
        self.assertEqual(addMultTask.add.getName(), "add")
        self.assertEqual(addMultTask.mult.getName(), "mult")

        self.assertEqual(addMultTask._name, "addMult")
        self.assertEqual(addMultTask.add._name, "add")
        self.assertEqual(addMultTask.mult._name, "mult")

        self.assertEqual(addMultTask.getFullName(), "addMult")
        self.assertEqual(addMultTask.add.getFullName(), "addMult.add")
        self.assertEqual(addMultTask.mult.getFullName(), "addMult.mult")

        self.assertEqual(addMultTask._fullName, "addMult")
        self.assertEqual(addMultTask.add._fullName, "addMult.add")
        self.assertEqual(addMultTask.mult._fullName, "addMult.mult")

    def testLog(self):
        """Test the Task's logger."""
        addMultTask = AddMultTask()
        self.assertEqual(addMultTask.log.name, "addMult")
        self.assertEqual(addMultTask.add.log.name, "addMult.add")

        log = logging.getLogger("tester")
        addMultTask = AddMultTask(log=log)
        self.assertEqual(addMultTask.log.name, "tester.addMult")
        self.assertEqual(addMultTask.add.log.name, "tester.addMult.add")

        addMultTask2 = AddMultTask2()
        self.assertEqual(addMultTask2.log.name, f"{__name__}.addMult")

    def testGetFullMetadata(self):
        """Test getFullMetadata()."""
        addMultTask = AddMultTask()
        addMultTask.run(val=1.234)  # Add some metadata
        fullMetadata = addMultTask.getFullMetadata()
        self.assertIsInstance(fullMetadata["addMult"], _TASK_METADATA_TYPE)
        self.assertIsInstance(fullMetadata["addMult:add"], _TASK_METADATA_TYPE)
        self.assertIsInstance(fullMetadata["addMult:mult"], _TASK_METADATA_TYPE)
        self.assertEqual(set(fullMetadata), {"addMult", "addMult:add", "addMult:mult"})

        all_names = fullMetadata.names()
        self.assertIn("addMult", all_names)
        self.assertIn("addMult.runStartUtc", all_names)

        param_names = fullMetadata.paramNames(topLevelOnly=True)
        # No top level keys without hierarchy
        self.assertEqual(set(param_names), set())

        param_names = fullMetadata.paramNames(topLevelOnly=False)
        self.assertNotIn("addMult", param_names)
        self.assertIn("addMult.runStartUtc", param_names)
        self.assertIn("addMult:add.runStartCpuTime", param_names)

    def testEmptyMetadata(self):
        """Test that emptyMetadata() clears the metadata of the task and of
        its subtasks.
        """
        task = AddMultTask()
        task.run(val=1.2345)
        task.emptyMetadata()
        fullMetadata = task.getFullMetadata()
        self.assertEqual(len(fullMetadata["addMult"]), 0)
        self.assertEqual(len(fullMetadata["addMult:add"]), 0)
        self.assertEqual(len(fullMetadata["addMult:mult"]), 0)

    def testReplace(self):
        """Test replacing one subtask with another."""
        for addend in (1.1, -3.5):
            for multiplicand in (0.9, -45.0):
                config = AddMultTask.ConfigClass()
                config.add.retarget(AddTwiceTask)
                config.add.addend = addend
                config.mult["stdMult"].multiplicand = multiplicand
                addMultTask = AddMultTask(config=config)
                for val in (-1.0, 0.0, 17.5):
                    ret = addMultTask.run(val=val)
                    self.assertAlmostEqual(ret.val, (val + (2 * addend)) * multiplicand)

    def testFail(self):
        """Test timers when the code they are timing fails.

        The end-of-timer metadata must still be written when the timed code
        raises, both for the decorator and for the context manager.
        """
        addMultTask = AddMultTask()
        with self.assertRaises(RuntimeError):
            addMultTask.failDec()
        self.assertIn("failDecEndCpuTime", addMultTask.metadata)

        with self.assertRaises(RuntimeError):
            addMultTask.failCtx()
        self.assertIn("failCtxEndCpuTime", addMultTask.metadata)

    def testTimeMethod(self):
        """Test that the timer is adding the right metadata."""
        addMultTask = AddMultTask()

        # Run twice to ensure we are additive.
        addMultTask.run(val=1.1)
        addMultTask.run(val=2.0)
        # Check existence and type
        for key, keyType in (
            ("Utc", str),
            ("CpuTime", float),
            ("UserTime", float),
            ("SystemTime", float),
            ("MaxResidentSetSize", numbers.Integral),
            ("MinorPageFaults", numbers.Integral),
            ("MajorPageFaults", numbers.Integral),
            ("BlockInputs", numbers.Integral),
            ("BlockOutputs", numbers.Integral),
            ("VoluntaryContextSwitches", numbers.Integral),
            ("InvoluntaryContextSwitches", numbers.Integral),
        ):
            for when in ("Start", "End"):
                for method in ("run", "context"):
                    name = method + when + key
                    self.assertIn(name, addMultTask.metadata, name + " is missing from task metadata")
                    self.assertIsInstance(
                        addMultTask.metadata.getScalar(name),
                        keyType,
                        f"{name} is not of the right type "
                        f"({keyType} vs {type(addMultTask.metadata.getScalar(name))})",
                    )
        # Some basic sanity checks
        currCpuTime = time.process_time()
        self.assertLessEqual(
            addMultTask.metadata.getScalar("runStartCpuTime"),
            addMultTask.metadata.getScalar("runEndCpuTime"),
        )
        self.assertLessEqual(addMultTask.metadata.getScalar("runEndCpuTime"), currCpuTime)
        self.assertLessEqual(
            addMultTask.metadata.getScalar("contextStartCpuTime"),
            addMultTask.metadata.getScalar("contextEndCpuTime"),
        )
        self.assertLessEqual(addMultTask.metadata.getScalar("contextEndCpuTime"), currCpuTime)
        self.assertLessEqual(
            addMultTask.add.metadata.getScalar("runStartCpuTime"),
            addMultTask.metadata.getScalar("runEndCpuTime"),
        )
        self.assertLessEqual(addMultTask.add.metadata.getScalar("runEndCpuTime"), currCpuTime)

        # Add some explicit values for serialization test.
        addMultTask.metadata["comment"] = "A comment"
        addMultTask.metadata["integer"] = 5
        addMultTask.metadata["float"] = 3.14
        addMultTask.metadata["bool"] = False
        # Use distinct values so the round trip also checks list order.
        addMultTask.metadata.add("commentList", "comment1")
        addMultTask.metadata.add("commentList", "comment2")
        addMultTask.metadata.add("intList", 6)
        addMultTask.metadata.add("intList", 7)
        addMultTask.metadata.add("boolList", False)
        addMultTask.metadata.add("boolList", True)
        addMultTask.metadata.add("floatList", 6.6)
        addMultTask.metadata.add("floatList", 7.8)

        # TaskMetadata can serialize to JSON but not YAML
        # and PropertySet can serialize to YAML and not JSON.
        # Probe for the method that is actually called below.
        if hasattr(addMultTask.metadata, "model_dump_json"):
            j = addMultTask.metadata.model_dump_json()
            new_meta = pipeBase.TaskMetadata.model_validate(json.loads(j))
        else:
            y = yaml.dump(addMultTask.metadata)
            new_meta = yaml.safe_load(y)
        self.assertEqual(new_meta, addMultTask.metadata)

    def test_annotate_exception(self):
        """Test annotating failures in the task metadata when a non-Task
        exception is raised (when there is no `metadata` on the exception).
        """
        task = AddMultTask()
        msg = "something failed!"
        error = ValueError(msg)
        with self.assertLogs("addMult", level="ERROR") as cm:
            pipeBase.AnnotatedPartialOutputsError.annotate(error, task, log=task.log)
        self.assertIn(msg, "\n".join(cm.output))
        self.assertEqual(task.metadata["failure"]["message"], msg)
        self.assertEqual(task.metadata["failure"]["type"], "ValueError")
        self.assertNotIn("metadata", task.metadata["failure"])

    def test_annotate_task_exception(self):
        """Test annotating failures in the task metadata when a Task-specific
        exception is raised.
        """

        class TestError(pipeBase.AlgorithmError):
            @property
            def metadata(self):
                return {"something": 12345}

        task = AddMultTask()
        msg = "something failed!"
        error = TestError(msg)
        with self.assertLogs("addMult", level="ERROR") as cm:
            pipeBase.AnnotatedPartialOutputsError.annotate(error, task, log=task.log)
        self.assertIn(msg, "\n".join(cm.output))
        self.assertEqual(task.metadata["failure"]["message"], msg)
        # Build the expected type name from the actual module name rather than
        # hard-coding "test_task", so the test does not depend on how this
        # module was imported (cf. testLog, which also uses __name__).
        result = f"{__name__}.{TestError.__qualname__}"
        self.assertEqual(task.metadata["failure"]["type"], result)
        self.assertEqual(task.metadata["failure"]["metadata"]["something"], 12345)

356 

357 

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """File-leak checks, inherited unchanged from
    `lsst.utils.tests.MemoryTestCase`.
    """

360 

361 

def setup_module(module):
    """pytest module-setup hook: initialize the LSST test utilities.

    Parameters
    ----------
    module : `types.ModuleType`
        The module being set up; not used.
    """
    lsst.utils.tests.init()

365 

366 

if __name__ == "__main__":
    # Executed directly rather than under pytest: initialize the LSST test
    # helpers first, then hand over to unittest's command-line runner.
    lsst.utils.tests.init()
    unittest.main()