Coverage for tests/test_executors.py: 16%

418 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-25 10:56 +0000

1# This file is part of ctrl_mpexec. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <https://www.gnu.org/licenses/>. 

27 

28"""Unit tests for the MPGraphExecutor and SingleQuantumExecutor classes.

29""" 

30 

31import faulthandler 

32import logging 

33import multiprocessing 

34import os 

35import signal 

36import sys 

37import time 

38import unittest 

39import warnings 

40from multiprocessing import Manager 

41 

42import networkx as nx 

43import psutil 

44from lsst.ctrl.mpexec import ( 

45 ExecutionStatus, 

46 MPGraphExecutor, 

47 MPGraphExecutorError, 

48 MPTimeoutError, 

49 QuantumExecutor, 

50 QuantumReport, 

51 SingleQuantumExecutor, 

52) 

53from lsst.ctrl.mpexec.execFixupDataId import ExecFixupDataId 

54from lsst.daf.butler.tests.utils import makeTestTempDir, removeTestTempDir 

55from lsst.pipe.base import NodeId 

56from lsst.pipe.base.tests.simpleQGraph import AddTaskFactoryMock, makeSimpleQGraph 

57 

# Emit DEBUG-level records so that the mock tasks' log messages are visible.
logging.basicConfig(level=logging.DEBUG)

# Logger shared by the mock classes defined in this module.
_LOG = logging.getLogger(__name__)

# Absolute path of the directory that contains this test module.
TESTDIR = os.path.dirname(os.path.abspath(__file__))

63 

64 

class QuantumExecutorMock(QuantumExecutor):
    """Stand-in quantum executor that records every quantum it is given.

    Parameters
    ----------
    mp : `bool`
        If `True`, executed quanta are collected in a list obtained from a
        `multiprocessing.Manager`; otherwise a plain `list` is used.
    """

    def __init__(self, mp=False):
        if mp:
            # A manager-backed list is used so that appends made by
            # subprocesses are visible here.
            self.quanta = Manager().list()
        else:
            self.quanta = []
        self.report = None
        self._execute_called = False

    def execute(self, taskDef, quantum):
        _LOG.debug("QuantumExecutorMock.execute: taskDef=%s dataId=%s", taskDef, quantum.dataId)
        self._execute_called = True
        task_class = taskDef.taskClass
        if task_class:
            try:
                # only works for one of the TaskMock classes below
                task_class().runQuantum()
                self.report = QuantumReport(dataId=quantum.dataId, taskLabel=taskDef.label)
            except Exception as exc:
                # Remember the failure for getReport() and let it propagate.
                self.report = QuantumReport.from_exception(
                    exception=exc,
                    dataId=quantum.dataId,
                    taskLabel=taskDef.label,
                )
                raise
        # Reached only when the task did not raise.
        self.quanta.append(quantum)
        return quantum

    def getReport(self):
        if self._execute_called:
            return self.report
        raise RuntimeError("getReport called before execute")

    def getDataIds(self, field):
        """Collect one data ID value from each quantum executed so far.

        Parameters
        ----------
        field : `str`
            Key to look up in each quantum's ``dataId``.

        Returns
        -------
        values : `list`
            Values of ``field``, in the order the quanta were recorded.
        """
        values = []
        for executed in self.quanta:
            values.append(executed.dataId[field])
        return values

115 

116 

class QuantumMock:
    """Mock equivalent of a `~lsst.daf.butler.Quantum`.

    Instances compare equal when their data IDs are equal and are hashable,
    so they can be stored in sets.

    Parameters
    ----------
    dataId : `dict`
        The Data ID of this quantum.
    """

    def __init__(self, dataId):
        self.dataId = dataId

    def __eq__(self, other):
        # Defer to the other operand (and ultimately to identity comparison)
        # when it does not look like a quantum, instead of raising
        # AttributeError.
        try:
            otherDataId = other.dataId
        except AttributeError:
            return NotImplemented
        return self.dataId == otherDataId

    def __hash__(self):
        # dict.__eq__ is order-insensitive, so sort the items to make the
        # hash order-insensitive as well.
        return hash(tuple(sorted(self.dataId.items())))

135 

136 

class QuantumIterDataMock:
    """Minimal stand-in for a quantum graph node.

    Parameters
    ----------
    index : `int`
        Position of this node; also used to build its ``nodeId``.
    taskDef : `TaskDefMock`
        Mocked task definition attached to the node.
    **dataId : `~typing.Any`
        Key/value pairs forming the data ID of the mocked quantum.
    """

    def __init__(self, index, taskDef, **dataId):
        self.index = index
        self.nodeId = NodeId(index, "DummyBuildString")
        self.taskDef = taskDef
        self.quantum = QuantumMock(dataId)
        # Indices of the nodes this node depends on; filled in by tests.
        self.dependencies = set()

156 

157 

class QuantumGraphMock:
    """Mock for quantum graph.

    The nodes are chained into a linear directed graph in the order given,
    which fixes the iteration order. Logical dependencies between nodes are
    taken from each node's ``dependencies`` attribute instead (see
    `determineInputsToQuantumNode`).

    Parameters
    ----------
    qdata : `~collections.abc.Iterable` of `QuantumIterDataMock`
        The nodes of the graph.
    """

    def __init__(self, qdata):
        # Materialize once so that any iterable (not only sequences) works.
        nodes = list(qdata)
        self._graph = nx.DiGraph()
        # Add nodes explicitly: relying on add_edge alone would lose the
        # node of a single-element graph and fail on an empty one.
        self._graph.add_nodes_from(nodes)
        self._graph.add_edges_from(zip(nodes, nodes[1:]))

    def __iter__(self):
        yield from nx.topological_sort(self._graph)

    def __len__(self):
        return len(self._graph)

    def findTaskDefByLabel(self, label):
        """Return the task definition of the first node whose task has the
        given label, or `None` if there is no such node.
        """
        for q in self:
            if q.taskDef.label == label:
                return q.taskDef
        return None

    def getQuantaForTask(self, taskDef):
        """Return the set of quanta of all nodes using ``taskDef``."""
        return {q.quantum for q in self.getNodesForTask(taskDef)}

    def getNodesForTask(self, taskDef):
        """Return the set of nodes whose task definition equals ``taskDef``."""
        return {q for q in self if q.taskDef == taskDef}

    @property
    def graph(self):
        """The underlying `networkx.DiGraph`."""
        return self._graph

    def findCycle(self):
        """Return an empty list; this mock never reports a cycle."""
        return []

    def determineInputsToQuantumNode(self, node):
        """Return the set of nodes whose ``index`` is listed in
        ``node.dependencies``.
        """
        wanted = node.dependencies
        return {otherNode for otherNode in self if otherNode.index in wanted}

210 

211 

class TaskMockMP:
    """Mock task that declares multiprocessing support and does nothing."""

    canMultiprocess = True

    def runQuantum(self):
        # Only leaves a trace in the log; there is no work to do.
        _LOG.debug("TaskMockMP.runQuantum")

220 

221 

class TaskMockFail:
    """Mock task whose execution always raises `ValueError`."""

    canMultiprocess = True

    def runQuantum(self):
        _LOG.debug("TaskMockFail.runQuantum")
        failure = ValueError("expected failure")
        raise failure

230 

231 

class TaskMockCrash:
    """Mock task that crashes by raising ``SIGILL`` in its own process."""

    canMultiprocess = True

    def runQuantum(self):
        _LOG.debug("TaskMockCrash.runQuantum")
        # Turn the fault handler off first, otherwise it prints a long and
        # alarming traceback when the signal arrives.
        faulthandler.disable()
        crash_signal = signal.SIGILL
        signal.raise_signal(crash_signal)

242 

243 

class TaskMockLongSleep:
    """Mock task that sleeps for 100 seconds to simulate a long run."""

    canMultiprocess = True

    def runQuantum(self):
        _LOG.debug("TaskMockLongSleep.runQuantum")
        seconds = 100.0
        time.sleep(seconds)

252 

253 

class TaskMockNoMP:
    """Mock task that declares it cannot run under multiprocessing."""

    # Flag marking this task as lacking multiprocessing support.
    canMultiprocess = False

258 

259 

class TaskDefMock:
    """Simple mock class for task definition in a pipeline.

    Parameters
    ----------
    taskName : `str`, optional
        The name of the task.
    config : `PipelineTaskConfig`, optional
        Config to use for this task.
    taskClass : `type`, optional
        The class of the task.
    label : `str`, optional
        Task label.
    """

    def __init__(self, taskName="Task", config=None, taskClass=TaskMockMP, label="task1"):
        self.taskName = taskName
        self.config = config
        self.taskClass = taskClass
        self.label = label

    def __str__(self):
        # QuantumExecutorMock.execute logs this object and tolerates a falsy
        # taskClass, so formatting must not fail when taskClass has no
        # ``__name__`` (e.g. `None`); fall back to the value itself.
        className = getattr(self.taskClass, "__name__", self.taskClass)
        return f"TaskDefMock(taskName={self.taskName}, taskClass={className})"

283 

284 

285def _count_status(report, status): 

286 """Count number of quanta witha a given status.""" 

287 return len([qrep for qrep in report.quantaReports if qrep.status is status]) 

288 

289 

class MPGraphExecutorTestCase(unittest.TestCase):
    """A test case for MPGraphExecutor class.

    All tests run against the mock graph, task and quantum-executor classes
    defined in this module.
    """

    # Fully-qualified exception class names that the tests expect to find in
    # the execution report.
    _GRAPH_ERROR_CLASS = "lsst.ctrl.mpexec.mpGraphExecutor.MPGraphExecutorError"
    _TIMEOUT_ERROR_CLASS = "lsst.ctrl.mpexec.mpGraphExecutor.MPTimeoutError"

    @staticmethod
    def _make_qdata(taskDefs):
        """Make one mock graph node per task definition.

        Parameters
        ----------
        taskDefs : `~collections.abc.Iterable` of `TaskDefMock`
            Task definition for each node, in graph order.

        Returns
        -------
        qdata : `list` of `QuantumIterDataMock`
            Nodes whose ``index`` and ``detector`` data ID value both equal
            their position in ``taskDefs``.
        """
        return [
            QuantumIterDataMock(index=index, taskDef=taskDef, detector=index)
            for index, taskDef in enumerate(taskDefs)
        ]

    def _assert_run_outcome(self, report, status, className):
        """Check the overall status and recorded exception class of a report.

        Parameters
        ----------
        report : `object`
            Report returned by ``MPGraphExecutor.getReport()``.
        status : `ExecutionStatus`
            Expected value of ``report.status``.
        className : `str`
            Expected value of ``report.exceptionInfo.className``.
        """
        self.assertEqual(report.status, status)
        self.assertEqual(report.exceptionInfo.className, className)

    def test_mpexec_nomp(self):
        """Make simple graph and execute in a single process."""
        taskDef = TaskDefMock()
        qgraph = QuantumGraphMock(self._make_qdata([taskDef] * 3))

        # run in single-process mode
        qexec = QuantumExecutorMock()
        mpexec = MPGraphExecutor(numProc=1, timeout=100, quantumExecutor=qexec)
        mpexec.execute(qgraph)
        self.assertEqual(qexec.getDataIds("detector"), [0, 1, 2])
        report = mpexec.getReport()
        self.assertEqual(report.status, ExecutionStatus.SUCCESS)
        self.assertIsNone(report.exitCode)
        self.assertIsNone(report.exceptionInfo)
        self.assertEqual(len(report.quantaReports), 3)
        self.assertTrue(all(qrep.status == ExecutionStatus.SUCCESS for qrep in report.quantaReports))
        self.assertTrue(all(qrep.exitCode is None for qrep in report.quantaReports))
        self.assertTrue(all(qrep.exceptionInfo is None for qrep in report.quantaReports))
        self.assertTrue(all(qrep.taskLabel == "task1" for qrep in report.quantaReports))

    def test_mpexec_mp(self):
        """Make simple graph and execute with multiple processes."""
        taskDef = TaskDefMock()
        qgraph = QuantumGraphMock(self._make_qdata([taskDef] * 3))

        methods = ["spawn"]
        if sys.platform == "linux":
            methods.append("forkserver")

        for method in methods:
            with self.subTest(startMethod=method):
                # Run in multi-process mode, the order of results is not
                # defined.
                qexec = QuantumExecutorMock(mp=True)
                mpexec = MPGraphExecutor(numProc=3, timeout=100, quantumExecutor=qexec, startMethod=method)
                mpexec.execute(qgraph)
                self.assertCountEqual(qexec.getDataIds("detector"), [0, 1, 2])
                report = mpexec.getReport()
                self.assertEqual(report.status, ExecutionStatus.SUCCESS)
                self.assertIsNone(report.exitCode)
                self.assertIsNone(report.exceptionInfo)
                self.assertEqual(len(report.quantaReports), 3)
                self.assertTrue(all(qrep.status == ExecutionStatus.SUCCESS for qrep in report.quantaReports))
                self.assertTrue(all(qrep.exitCode == 0 for qrep in report.quantaReports))
                self.assertTrue(all(qrep.exceptionInfo is None for qrep in report.quantaReports))
                self.assertTrue(all(qrep.taskLabel == "task1" for qrep in report.quantaReports))

    def test_mpexec_nompsupport(self):
        """Try to run MP for task that has no MP support which should fail."""
        taskDef = TaskDefMock(taskClass=TaskMockNoMP)
        qgraph = QuantumGraphMock(self._make_qdata([taskDef] * 3))

        # run in multi-process mode
        qexec = QuantumExecutorMock()
        mpexec = MPGraphExecutor(numProc=3, timeout=100, quantumExecutor=qexec)
        with self.assertRaisesRegex(MPGraphExecutorError, "Task Task does not support multiprocessing"):
            mpexec.execute(qgraph)

    def test_mpexec_fixup(self):
        """Make simple graph and execute, add dependencies by executing fixup
        code.
        """
        taskDef = TaskDefMock()

        for reverse in (False, True):
            # Report each ordering separately, as test_mpexec_mp does for
            # start methods.
            with self.subTest(reverse=reverse):
                qgraph = QuantumGraphMock(self._make_qdata([taskDef] * 3))

                qexec = QuantumExecutorMock()
                fixup = ExecFixupDataId("task1", "detector", reverse=reverse)
                mpexec = MPGraphExecutor(
                    numProc=1, timeout=100, quantumExecutor=qexec, executionGraphFixup=fixup
                )
                mpexec.execute(qgraph)

                expected = [0, 1, 2]
                if reverse:
                    expected = list(reversed(expected))
                self.assertEqual(qexec.getDataIds("detector"), expected)

    def test_mpexec_timeout(self):
        """Fail due to timeout."""
        taskDef = TaskDefMock()
        taskDefSleep = TaskDefMock(taskClass=TaskMockLongSleep)
        qgraph = QuantumGraphMock(self._make_qdata([taskDef, taskDefSleep, taskDef]))

        # with failFast we'll get immediate MPTimeoutError
        qexec = QuantumExecutorMock(mp=True)
        mpexec = MPGraphExecutor(numProc=3, timeout=1, quantumExecutor=qexec, failFast=True)
        with self.assertRaises(MPTimeoutError):
            mpexec.execute(qgraph)
        report = mpexec.getReport()
        self._assert_run_outcome(report, ExecutionStatus.TIMEOUT, self._TIMEOUT_ERROR_CLASS)
        self.assertGreater(len(report.quantaReports), 0)
        self.assertEqual(_count_status(report, ExecutionStatus.TIMEOUT), 1)
        self.assertTrue(any(qrep.exitCode < 0 for qrep in report.quantaReports))
        self.assertTrue(all(qrep.exceptionInfo is None for qrep in report.quantaReports))

        # with failFast=False exception happens after last task finishes
        qexec = QuantumExecutorMock(mp=True)
        mpexec = MPGraphExecutor(numProc=3, timeout=3, quantumExecutor=qexec, failFast=False)
        with self.assertRaises(MPTimeoutError):
            mpexec.execute(qgraph)
        # We expect two tasks (0 and 2) to finish successfully and one task to
        # timeout. Unfortunately on busy CPU there is no guarantee that tasks
        # finish on time, so expect more timeouts and issue a warning.
        detectorIds = set(qexec.getDataIds("detector"))
        self.assertLess(len(detectorIds), 3)
        if detectorIds != {0, 2}:
            warnings.warn(f"Possibly timed out tasks, expected [0, 2], received {detectorIds}")
        report = mpexec.getReport()
        self._assert_run_outcome(report, ExecutionStatus.TIMEOUT, self._TIMEOUT_ERROR_CLASS)
        self.assertGreater(len(report.quantaReports), 0)
        self.assertGreater(_count_status(report, ExecutionStatus.TIMEOUT), 0)
        self.assertTrue(any(qrep.exitCode < 0 for qrep in report.quantaReports))
        self.assertTrue(all(qrep.exceptionInfo is None for qrep in report.quantaReports))

    def test_mpexec_failure(self):
        """Failure in one task should not stop other tasks."""
        taskDef = TaskDefMock()
        taskDefFail = TaskDefMock(taskClass=TaskMockFail)
        qgraph = QuantumGraphMock(self._make_qdata([taskDef, taskDefFail, taskDef]))

        qexec = QuantumExecutorMock(mp=True)
        mpexec = MPGraphExecutor(numProc=3, timeout=100, quantumExecutor=qexec)
        with self.assertRaisesRegex(MPGraphExecutorError, "One or more tasks failed"):
            mpexec.execute(qgraph)
        self.assertCountEqual(qexec.getDataIds("detector"), [0, 2])
        report = mpexec.getReport()
        self._assert_run_outcome(report, ExecutionStatus.FAILURE, self._GRAPH_ERROR_CLASS)
        self.assertGreater(len(report.quantaReports), 0)
        self.assertEqual(_count_status(report, ExecutionStatus.FAILURE), 1)
        self.assertEqual(_count_status(report, ExecutionStatus.SUCCESS), 2)
        self.assertTrue(any(qrep.exitCode > 0 for qrep in report.quantaReports))
        self.assertTrue(any(qrep.exceptionInfo is not None for qrep in report.quantaReports))

    def test_mpexec_failure_dep(self):
        """Failure in one task should skip dependents."""
        taskDef = TaskDefMock()
        taskDefFail = TaskDefMock(taskClass=TaskMockFail)
        qdata = self._make_qdata([taskDef, taskDefFail, taskDef, taskDef, taskDef])
        qdata[2].dependencies.add(1)
        qdata[4].dependencies.add(3)
        qdata[4].dependencies.add(2)

        qgraph = QuantumGraphMock(qdata)

        qexec = QuantumExecutorMock(mp=True)
        mpexec = MPGraphExecutor(numProc=3, timeout=100, quantumExecutor=qexec)
        with self.assertRaisesRegex(MPGraphExecutorError, "One or more tasks failed"):
            mpexec.execute(qgraph)
        self.assertCountEqual(qexec.getDataIds("detector"), [0, 3])
        report = mpexec.getReport()
        self._assert_run_outcome(report, ExecutionStatus.FAILURE, self._GRAPH_ERROR_CLASS)
        # Quanta 2 and 4 depend (directly or indirectly) on the failed
        # quantum 1 and are expected to be reported as SKIPPED.
        self.assertGreater(len(report.quantaReports), 0)
        self.assertEqual(_count_status(report, ExecutionStatus.FAILURE), 1)
        self.assertEqual(_count_status(report, ExecutionStatus.SUCCESS), 2)
        self.assertEqual(_count_status(report, ExecutionStatus.SKIPPED), 2)
        self.assertTrue(any(qrep.exitCode > 0 for qrep in report.quantaReports))
        self.assertTrue(any(qrep.exceptionInfo is not None for qrep in report.quantaReports))

    def test_mpexec_failure_dep_nomp(self):
        """Failure in one task should skip dependents, in-process version."""
        taskDef = TaskDefMock()
        taskDefFail = TaskDefMock(taskClass=TaskMockFail)
        qdata = self._make_qdata([taskDef, taskDefFail, taskDef, taskDef, taskDef])
        qdata[2].dependencies.add(1)
        qdata[4].dependencies.add(3)
        qdata[4].dependencies.add(2)

        qgraph = QuantumGraphMock(qdata)

        qexec = QuantumExecutorMock()
        mpexec = MPGraphExecutor(numProc=1, timeout=100, quantumExecutor=qexec)
        with self.assertRaisesRegex(MPGraphExecutorError, "One or more tasks failed"):
            mpexec.execute(qgraph)
        self.assertCountEqual(qexec.getDataIds("detector"), [0, 3])
        report = mpexec.getReport()
        self._assert_run_outcome(report, ExecutionStatus.FAILURE, self._GRAPH_ERROR_CLASS)
        # Quanta 2 and 4 depend (directly or indirectly) on the failed
        # quantum 1 and are expected to be reported as SKIPPED.
        self.assertGreater(len(report.quantaReports), 0)
        self.assertEqual(_count_status(report, ExecutionStatus.FAILURE), 1)
        self.assertEqual(_count_status(report, ExecutionStatus.SUCCESS), 2)
        self.assertEqual(_count_status(report, ExecutionStatus.SKIPPED), 2)
        self.assertTrue(all(qrep.exitCode is None for qrep in report.quantaReports))
        self.assertTrue(any(qrep.exceptionInfo is not None for qrep in report.quantaReports))

    def test_mpexec_failure_failfast(self):
        """Fast fail stops quickly.

        Timing delay of task #3 should be sufficient to process
        failure and raise exception.
        """
        taskDef = TaskDefMock()
        taskDefFail = TaskDefMock(taskClass=TaskMockFail)
        taskDefLongSleep = TaskDefMock(taskClass=TaskMockLongSleep)
        qdata = self._make_qdata([taskDef, taskDefFail, taskDef, taskDefLongSleep, taskDef])
        qdata[1].dependencies.add(0)
        qdata[2].dependencies.add(1)
        qdata[4].dependencies.add(3)
        qdata[4].dependencies.add(2)

        qgraph = QuantumGraphMock(qdata)

        qexec = QuantumExecutorMock(mp=True)
        mpexec = MPGraphExecutor(numProc=3, timeout=100, quantumExecutor=qexec, failFast=True)
        with self.assertRaisesRegex(MPGraphExecutorError, "failed, exit code=1"):
            mpexec.execute(qgraph)
        self.assertCountEqual(qexec.getDataIds("detector"), [0])
        report = mpexec.getReport()
        self._assert_run_outcome(report, ExecutionStatus.FAILURE, self._GRAPH_ERROR_CLASS)
        self.assertGreater(len(report.quantaReports), 0)
        self.assertEqual(_count_status(report, ExecutionStatus.FAILURE), 1)
        self.assertTrue(any(qrep.exitCode > 0 for qrep in report.quantaReports))
        self.assertTrue(any(qrep.exceptionInfo is not None for qrep in report.quantaReports))

    def test_mpexec_crash(self):
        """Check task crash due to signal."""
        taskDef = TaskDefMock()
        taskDefCrash = TaskDefMock(taskClass=TaskMockCrash)
        qgraph = QuantumGraphMock(self._make_qdata([taskDef, taskDefCrash, taskDef]))

        qexec = QuantumExecutorMock(mp=True)
        mpexec = MPGraphExecutor(numProc=3, timeout=100, quantumExecutor=qexec)
        with self.assertRaisesRegex(MPGraphExecutorError, "One or more tasks failed"):
            mpexec.execute(qgraph)
        report = mpexec.getReport()
        self._assert_run_outcome(report, ExecutionStatus.FAILURE, self._GRAPH_ERROR_CLASS)
        self.assertGreater(len(report.quantaReports), 0)
        self.assertEqual(_count_status(report, ExecutionStatus.FAILURE), 1)
        self.assertEqual(_count_status(report, ExecutionStatus.SUCCESS), 2)
        # A process killed by a signal reports the negated signal number.
        self.assertTrue(any(qrep.exitCode == -signal.SIGILL for qrep in report.quantaReports))
        self.assertTrue(all(qrep.exceptionInfo is None for qrep in report.quantaReports))

    def test_mpexec_crash_failfast(self):
        """Check task crash due to signal with --fail-fast."""
        taskDef = TaskDefMock()
        taskDefCrash = TaskDefMock(taskClass=TaskMockCrash)
        qgraph = QuantumGraphMock(self._make_qdata([taskDef, taskDefCrash, taskDef]))

        qexec = QuantumExecutorMock(mp=True)
        mpexec = MPGraphExecutor(numProc=3, timeout=100, quantumExecutor=qexec, failFast=True)
        with self.assertRaisesRegex(MPGraphExecutorError, "failed, killed by signal 4 .Illegal instruction"):
            mpexec.execute(qgraph)
        report = mpexec.getReport()
        self._assert_run_outcome(report, ExecutionStatus.FAILURE, self._GRAPH_ERROR_CLASS)
        self.assertEqual(_count_status(report, ExecutionStatus.FAILURE), 1)
        self.assertTrue(any(qrep.exitCode == -signal.SIGILL for qrep in report.quantaReports))
        self.assertTrue(all(qrep.exceptionInfo is None for qrep in report.quantaReports))

    def test_mpexec_num_fd(self):
        """Check that number of open files stays reasonable."""
        taskDef = TaskDefMock()
        qgraph = QuantumGraphMock(self._make_qdata([taskDef] * 20))

        this_proc = psutil.Process()
        num_fds_0 = this_proc.num_fds()

        # run in multi-process mode, the order of results is not defined
        qexec = QuantumExecutorMock(mp=True)
        mpexec = MPGraphExecutor(numProc=3, timeout=100, quantumExecutor=qexec)
        mpexec.execute(qgraph)

        num_fds_1 = this_proc.num_fds()
        # They should be the same but allow small growth just in case.
        # Without DM-26728 fix the difference would be equal to number of
        # quanta (20).
        self.assertLess(num_fds_1 - num_fds_0, 5)

633 

634 

class SingleQuantumExecutorTestCase(unittest.TestCase):
    """Exercise `SingleQuantumExecutor` against a one-quantum simple graph."""

    instrument = "lsst.pipe.base.tests.simpleQGraph.SimpleInstrument"

    def setUp(self):
        self.root = makeTestTempDir(TESTDIR)

    def tearDown(self):
        removeTestTempDir(self.root)

    def _execute_first_time(self):
        """Build a one-quantum graph and execute its quantum once.

        Returns
        -------
        butler : `object`
            Butler returned by `makeSimpleQGraph`.
        node : `object`
            The single node of the graph.
        taskFactory : `AddTaskFactoryMock`
            Factory used for the execution; its ``countExec`` has been
            checked to be 1.
        """
        nQuanta = 1
        butler, qgraph = makeSimpleQGraph(nQuanta, root=self.root, instrument=self.instrument)

        nodes = list(qgraph)
        self.assertEqual(len(nodes), nQuanta)
        node = nodes[0]

        taskFactory = AddTaskFactoryMock()
        first_executor = SingleQuantumExecutor(butler, taskFactory)
        first_executor.execute(node.taskDef, node.quantum)
        self.assertEqual(taskFactory.countExec, 1)
        return butler, node, taskFactory

    def _only_output_ref(self, butler):
        """Return the single ``add_dataset1`` ref found in the butler's run
        collection, asserting that exactly one exists.
        """
        refs = list(butler.registry.queryDatasets("add_dataset1", collections=butler.run))
        self.assertEqual(len(refs), 1)
        return refs[0]

    def test_simple_execute(self) -> None:
        """Run execute() method in simplest setup."""
        butler, _, _ = self._execute_first_time()

        # There must be one dataset of task's output connection
        self._only_output_ref(butler)

    def test_skip_existing_execute(self) -> None:
        """Run execute() method twice, with skip_existing_in."""
        butler, node, taskFactory = self._execute_first_time()
        dataset_id_1 = self._only_output_ref(butler).id

        # Re-run it with skipExistingIn, it should not run.
        assert butler.run is not None
        executor = SingleQuantumExecutor(butler, taskFactory, skipExistingIn=[butler.run])
        executor.execute(node.taskDef, node.quantum)
        self.assertEqual(taskFactory.countExec, 1)

        dataset_id_2 = self._only_output_ref(butler).id
        self.assertEqual(dataset_id_1, dataset_id_2)

    def test_clobber_outputs_execute(self) -> None:
        """Run execute() method twice, with clobber_outputs."""
        butler, node, taskFactory = self._execute_first_time()

        first_ref = self._only_output_ref(butler)
        dataset_id_1 = first_ref.id
        original_dataset = butler.get(first_ref)

        # Remove the dataset ourself, and replace it with something
        # different so we can check later whether it got replaced.
        butler.pruneDatasets([first_ref], disassociate=False, unstore=True, purge=False)
        replacement = original_dataset + 10
        butler.put(replacement, first_ref)

        # Re-run it with clobberOutputs and skipExistingIn, it should not
        # clobber but should skip instead.
        assert butler.run is not None
        executor = SingleQuantumExecutor(
            butler, taskFactory, skipExistingIn=[butler.run], clobberOutputs=True
        )
        executor.execute(node.taskDef, node.quantum)
        self.assertEqual(taskFactory.countExec, 1)

        second_ref = self._only_output_ref(butler)
        self.assertEqual(dataset_id_1, second_ref.id)

        second_dataset = butler.get(second_ref)
        self.assertEqual(list(second_dataset), list(replacement))

        # Re-run it with clobberOutputs but without skipExistingIn, it should
        # clobber.
        assert butler.run is not None
        executor = SingleQuantumExecutor(butler, taskFactory, clobberOutputs=True)
        executor.execute(node.taskDef, node.quantum)
        self.assertEqual(taskFactory.countExec, 2)

        third_ref = self._only_output_ref(butler)
        dataset_id_3 = third_ref.id

        third_dataset = butler.get(third_ref)
        self.assertEqual(list(third_dataset), list(original_dataset))

        # No change in UUID even after replacement
        self.assertEqual(dataset_id_1, dataset_id_3)

752 

753 

def setup_module(module):
    """Make ``spawn`` the multiprocessing start method for this module.

    The method is forced, so any previously selected start method is
    overridden. This can be removed when Python 3.14 changes the default.

    Parameters
    ----------
    module : `~types.ModuleType`
        Module to set up.
    """
    start_method = "spawn"
    multiprocessing.set_start_method(start_method, force=True)

765 

766 

if __name__ == "__main__":
    # Standalone run: no start method has been chosen yet, so ``spawn`` can
    # be selected without forcing it.
    multiprocessing.set_start_method("spawn")
    unittest.main()