Coverage for tests/test_executors.py: 16%

419 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-10 03:29 -0700

1# This file is part of ctrl_mpexec. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <https://www.gnu.org/licenses/>. 

27 

28"""Simple unit tests for the MPGraphExecutor and SingleQuantumExecutor classes. 

29""" 

30 

31import faulthandler 

32import logging 

33import multiprocessing 

34import os 

35import signal 

36import sys 

37import time 

38import unittest 

39import warnings 

40from multiprocessing import Manager 

41 

42import networkx as nx 

43import psutil 

44from lsst.ctrl.mpexec import ( 

45 ExecutionStatus, 

46 MPGraphExecutor, 

47 MPGraphExecutorError, 

48 MPTimeoutError, 

49 QuantumExecutor, 

50 QuantumReport, 

51 SingleQuantumExecutor, 

52) 

53from lsst.ctrl.mpexec.execFixupDataId import ExecFixupDataId 

54from lsst.daf.butler.tests.utils import makeTestTempDir, removeTestTempDir 

55from lsst.pipe.base import NodeId 

56from lsst.pipe.base.tests.simpleQGraph import AddTaskFactoryMock, makeSimpleQGraph 

57 

# Emit everything at DEBUG level so executor activity is visible in test logs.
logging.basicConfig(level=logging.DEBUG)

# Module-level logger shared by the mock classes in this file.
_LOG = logging.getLogger(__name__)

# Absolute path of the directory containing this test module.
TESTDIR = os.path.dirname(os.path.abspath(__file__))

63 

64 

class QuantumExecutorMock(QuantumExecutor):
    """Stand-in for a `QuantumExecutor` that records the quanta it executes.

    Parameters
    ----------
    mp : `bool`
        If `True` the executed quanta are collected in a
        `multiprocessing.Manager` list instead of a plain `list`.
    """

    def __init__(self, mp=False):
        self.quanta = []
        if mp:
            # A manager-backed list is used in multiprocess mode.
            shared = Manager()
            self.quanta = shared.list()
        self.report = None
        self._execute_called = False

    def execute(self, task_node, /, quantum):
        """Run the mock task for one quantum and remember the quantum.

        Parameters
        ----------
        task_node : `TaskNodeMock`
            Task definition; its ``task_class`` is instantiated and its
            ``runQuantum`` method is called when ``task_class`` is truthy.
        quantum : `QuantumMock`
            Quantum to "execute".

        Returns
        -------
        quantum : `QuantumMock`
            The same quantum that was passed in.

        Raises
        ------
        Exception
            Any exception raised while running the task is re-raised after a
            failure report has been stored in ``self.report``.
        """
        _LOG.debug("QuantumExecutorMock.execute: task_node=%s dataId=%s", task_node, quantum.dataId)
        self._execute_called = True
        task_class = task_node.task_class
        if task_class:
            try:
                # Only the TaskMock* classes defined in this module provide a
                # no-argument constructor and a no-argument runQuantum().
                task = task_class()
                task.runQuantum()
                self.report = QuantumReport(dataId=quantum.dataId, taskLabel=task_node.label)
            except Exception as exc:
                self.report = QuantumReport.from_exception(
                    exception=exc,
                    dataId=quantum.dataId,
                    taskLabel=task_node.label,
                )
                raise
        # Reached only when the task did not raise.
        self.quanta.append(quantum)
        return quantum

    def getReport(self):
        """Return the report stored by the most recent `execute` call.

        Raises
        ------
        RuntimeError
            Raised if `execute` has not been called yet.
        """
        if self._execute_called:
            return self.report
        raise RuntimeError("getReport called before execute")

    def getDataIds(self, field):
        """Return values for dataId field for each visited quanta.

        Parameters
        ----------
        field : `str`
            Field to select.

        Returns
        -------
        values : `list`
            Value of ``dataId[field]`` for every recorded quantum, in the
            order the quanta were recorded.
        """
        values = []
        for executed in self.quanta:
            values.append(executed.dataId[field])
        return values

115 

116 

class QuantumMock:
    """Mock equivalent of a `~lsst.daf.butler.Quantum`.

    Instances compare equal when their data IDs are equal, and hash
    consistently with that equality.

    Parameters
    ----------
    dataId : `dict`
        The Data ID of this quantum.
    """

    def __init__(self, dataId):
        self.dataId = dataId

    def __eq__(self, other):
        # Defer to the other operand (and ultimately to identity comparison)
        # for unrelated types instead of failing with AttributeError.
        if not isinstance(other, QuantumMock):
            return NotImplemented
        return self.dataId == other.dataId

    def __hash__(self):
        # dict.__eq__ is order-insensitive, so sort the items to make the
        # hash independent of insertion order as well.
        return hash(tuple(sorted(self.dataId.items())))

135 

136 

class QuantumIterDataMock:
    """Minimal replacement for QuantumIterData used by the tests.

    Parameters
    ----------
    index : `int`
        The index of this mock.
    task_node : `TaskNodeMock`
        Mocked task definition.
    **dataId : `~typing.Any`
        The data ID of the mocked quantum.
    """

    def __init__(self, index, task_node, **dataId):
        self.index = index
        # The same object is exposed under both attribute names.
        self.taskDef = self.task_node = task_node
        self.quantum = QuantumMock(dataId)
        # Starts empty; the tests add the ``index`` values of upstream nodes.
        self.dependencies = set()
        self.nodeId = NodeId(index, "DummyBuildString")

157 

158 

class QuantumGraphMock:
    """Mock for quantum graph.

    The nodes are linked into a linear chain in the order given, so that
    iteration (a topological sort) yields them in that same order.

    Parameters
    ----------
    qdata : `~collections.abc.Iterable` of `QuantumIterDataMock`
        The nodes of the graph. May be empty or contain a single node.
    """

    def __init__(self, qdata):
        # Materialize once so that any iterable (not only sequences) works.
        nodes = list(qdata)
        self._graph = nx.DiGraph()
        # Add nodes explicitly: relying on add_edge alone would drop the only
        # node of a single-node graph and fail on an empty input.
        self._graph.add_nodes_from(nodes)
        self._graph.add_edges_from(zip(nodes, nodes[1:]))

    def __iter__(self):
        yield from nx.topological_sort(self._graph)

    def __len__(self):
        return len(self._graph)

    def findTaskDefByLabel(self, label):
        """Return ``taskDef`` of the first node whose task has this label.

        Returns `None` if no node matches.
        """
        for q in self:
            if q.task_node.label == label:
                return q.taskDef
        return None

    def getQuantaForTask(self, taskDef):
        """Return the set of quanta of all nodes matching ``taskDef.label``."""
        return {q.quantum for q in self.getNodesForTask(taskDef)}

    def getNodesForTask(self, taskDef):
        """Return the set of nodes whose task label equals ``taskDef.label``."""
        return {q for q in self if q.task_node.label == taskDef.label}

    @property
    def graph(self):
        """The underlying `networkx.DiGraph`."""
        return self._graph

    def findCycle(self):
        """Return an empty list; the mock never reports a cycle."""
        return []

    def determineInputsToQuantumNode(self, node):
        """Return the set of nodes whose ``index`` is in ``node.dependencies``.

        Parameters
        ----------
        node : `QuantumIterDataMock`
            Node whose declared dependencies are looked up.
        """
        wanted = node.dependencies
        # Single pass over the graph instead of one pass per dependency.
        return {other for other in self if other.index in wanted}

211 

212 

class TaskMockMP:
    """Mock task that declares multiprocessing support and does nothing."""

    canMultiprocess = True

    def runQuantum(self):
        """Log the call; there is no other effect."""
        _LOG.debug("TaskMockMP.runQuantum")

221 

222 

class TaskMockFail:
    """Mock task whose execution always raises an exception."""

    canMultiprocess = True

    def runQuantum(self):
        """Log the call, then fail.

        Raises
        ------
        ValueError
            Always raised.
        """
        _LOG.debug("TaskMockFail.runQuantum")
        failure = ValueError("expected failure")
        raise failure

231 

232 

class TaskMockCrash:
    """Mock task which crashes by raising ``SIGILL`` in its own process."""

    canMultiprocess = True

    def runQuantum(self):
        """Log the call and send ``SIGILL`` to the current process."""
        _LOG.debug("TaskMockCrash.runQuantum")
        # Switch off faulthandler first, otherwise the signal produces a long
        # and alarming traceback in the test output.
        faulthandler.disable()
        crash_signal = signal.SIGILL
        signal.raise_signal(crash_signal)

243 

244 

class TaskMockLongSleep:
    """Mock task that blocks for 100 seconds, long enough to hit timeouts."""

    canMultiprocess = True

    def runQuantum(self):
        """Log the call and sleep for 100 seconds."""
        _LOG.debug("TaskMockLongSleep.runQuantum")
        duration = 100.0
        time.sleep(duration)

253 

254 

class TaskMockNoMP:
    """Mock task that declares it cannot be run with multiprocessing."""

    # Flag read from the task class; `False` marks the task as not
    # supporting multiprocessing.
    canMultiprocess = False

259 

260 

class TaskNodeMock:
    """Mock of a task definition in a pipeline graph.

    Parameters
    ----------
    label : `str`
        Label of the task in the pipeline.
    task_class : `type`
        Subclass of `lsst.pipe.base.PipelineTask`.
    config : `PipelineTaskConfig`, optional
        Configuration for the task.
    """

    def __init__(self, label="task1", task_class=TaskMockMP, config=None):
        self.label = label
        # The class is published under two names: ``taskClass`` so the mock
        # looks like a TaskDef, ``task_class`` so it looks like a TaskNode.
        self.taskClass = self.task_class = task_class
        self.config = config

    def __str__(self):
        class_name = self.taskClass.__name__
        return f"TaskNodeMock({self.label}, {class_name})"

283 

284 

def _count_status(report, status):
    """Count the quantum reports in ``report`` having the given status.

    Parameters
    ----------
    report : `object`
        Object with a ``quantaReports`` iterable attribute.
    status : `object`
        Status to look for; compared by identity (``is``).

    Returns
    -------
    count : `int`
        Number of entries of ``report.quantaReports`` whose ``status``
        attribute is ``status``.
    """
    count = 0
    for quantum_report in report.quantaReports:
        if quantum_report.status is status:
            count += 1
    return count

288 

289 

class MPGraphExecutorTestCase(unittest.TestCase):
    """A test case for MPGraphExecutor class."""

    # Fully-qualified class names expected in the top-level report's
    # exception info.
    _EXECUTOR_ERROR = "lsst.ctrl.mpexec.mpGraphExecutor.MPGraphExecutorError"
    _TIMEOUT_ERROR = "lsst.ctrl.mpexec.mpGraphExecutor.MPTimeoutError"

    @staticmethod
    def _make_qdata(task_nodes):
        """Return one `QuantumIterDataMock` per task node.

        The position of each task node in ``task_nodes`` is used both as the
        node index and as the ``detector`` value of its data ID.
        """
        return [
            QuantumIterDataMock(index=i, task_node=node, detector=i) for i, node in enumerate(task_nodes)
        ]

    def _assert_top_level(self, report, status, class_name):
        """Check status and exception class name of the top-level report."""
        self.assertEqual(report.status, status)
        self.assertEqual(report.exceptionInfo.className, class_name)

    def _assert_all_success(self, report, n_quanta, exit_code):
        """Check a fully successful report with ``n_quanta`` quantum reports.

        ``exit_code`` is the value expected for every quantum's ``exitCode``.
        """
        self.assertEqual(report.status, ExecutionStatus.SUCCESS)
        self.assertIsNone(report.exitCode)
        self.assertIsNone(report.exceptionInfo)
        self.assertEqual(len(report.quantaReports), n_quanta)
        self.assertTrue(all(qrep.status == ExecutionStatus.SUCCESS for qrep in report.quantaReports))
        self.assertTrue(all(qrep.exitCode == exit_code for qrep in report.quantaReports))
        self.assertTrue(all(qrep.exceptionInfo is None for qrep in report.quantaReports))
        self.assertTrue(all(qrep.taskLabel == "task1" for qrep in report.quantaReports))

    def test_mpexec_nomp(self):
        """Make simple graph and execute."""
        qgraph = QuantumGraphMock(self._make_qdata([TaskNodeMock()] * 3))

        # run in single-process mode
        qexec = QuantumExecutorMock()
        mpexec = MPGraphExecutor(numProc=1, timeout=100, quantumExecutor=qexec)
        mpexec.execute(qgraph)
        self.assertEqual(qexec.getDataIds("detector"), [0, 1, 2])
        # In-process execution leaves the per-quantum exit code unset.
        self._assert_all_success(mpexec.getReport(), 3, None)

    def test_mpexec_mp(self):
        """Make simple graph and execute."""
        qgraph = QuantumGraphMock(self._make_qdata([TaskNodeMock()] * 3))

        methods = ["spawn"]
        if sys.platform == "linux":
            methods.append("forkserver")

        for method in methods:
            with self.subTest(startMethod=method):
                # Run in multi-process mode, the order of results is not
                # defined.
                qexec = QuantumExecutorMock(mp=True)
                mpexec = MPGraphExecutor(numProc=3, timeout=100, quantumExecutor=qexec, startMethod=method)
                mpexec.execute(qgraph)
                self.assertCountEqual(qexec.getDataIds("detector"), [0, 1, 2])
                self._assert_all_success(mpexec.getReport(), 3, 0)

    def test_mpexec_nompsupport(self):
        """Try to run MP for task that has no MP support which should fail."""
        qgraph = QuantumGraphMock(self._make_qdata([TaskNodeMock(task_class=TaskMockNoMP)] * 3))

        # run in multi-process mode
        qexec = QuantumExecutorMock()
        mpexec = MPGraphExecutor(numProc=3, timeout=100, quantumExecutor=qexec)
        with self.assertRaisesRegex(MPGraphExecutorError, "Task 'task1' does not support multiprocessing"):
            mpexec.execute(qgraph)

    def test_mpexec_fixup(self):
        """Make simple graph and execute, add dependencies by executing fixup
        code.
        """
        task_node = TaskNodeMock()

        for reverse in (False, True):
            # subTest keeps the two orderings independent, so a failure in
            # one is reported without hiding the other.
            with self.subTest(reverse=reverse):
                qgraph = QuantumGraphMock(self._make_qdata([task_node] * 3))

                qexec = QuantumExecutorMock()
                fixup = ExecFixupDataId("task1", "detector", reverse=reverse)
                mpexec = MPGraphExecutor(
                    numProc=1, timeout=100, quantumExecutor=qexec, executionGraphFixup=fixup
                )
                mpexec.execute(qgraph)

                expected = [2, 1, 0] if reverse else [0, 1, 2]
                self.assertEqual(qexec.getDataIds("detector"), expected)

    def test_mpexec_timeout(self):
        """Fail due to timeout."""
        task_node = TaskNodeMock()
        task_node_sleep = TaskNodeMock(task_class=TaskMockLongSleep)
        qgraph = QuantumGraphMock(self._make_qdata([task_node, task_node_sleep, task_node]))

        # with failFast we'll get immediate MPTimeoutError
        qexec = QuantumExecutorMock(mp=True)
        mpexec = MPGraphExecutor(numProc=3, timeout=1, quantumExecutor=qexec, failFast=True)
        with self.assertRaises(MPTimeoutError):
            mpexec.execute(qgraph)
        report = mpexec.getReport()
        self._assert_top_level(report, ExecutionStatus.TIMEOUT, self._TIMEOUT_ERROR)
        self.assertGreater(len(report.quantaReports), 0)
        self.assertEqual(_count_status(report, ExecutionStatus.TIMEOUT), 1)
        self.assertTrue(any(qrep.exitCode < 0 for qrep in report.quantaReports))
        self.assertTrue(all(qrep.exceptionInfo is None for qrep in report.quantaReports))

        # with failFast=False exception happens after last task finishes
        qexec = QuantumExecutorMock(mp=True)
        mpexec = MPGraphExecutor(numProc=3, timeout=3, quantumExecutor=qexec, failFast=False)
        with self.assertRaises(MPTimeoutError):
            mpexec.execute(qgraph)
        # We expect two tasks (0 and 2) to finish successfully and one task to
        # timeout. Unfortunately on busy CPU there is no guarantee that tasks
        # finish on time, so expect more timeouts and issue a warning.
        detector_ids = set(qexec.getDataIds("detector"))
        self.assertLess(len(detector_ids), 3)
        if detector_ids != {0, 2}:
            warnings.warn(f"Possibly timed out tasks, expected [0, 2], received {detector_ids}")
        report = mpexec.getReport()
        self._assert_top_level(report, ExecutionStatus.TIMEOUT, self._TIMEOUT_ERROR)
        self.assertGreater(len(report.quantaReports), 0)
        self.assertGreater(_count_status(report, ExecutionStatus.TIMEOUT), 0)
        self.assertTrue(any(qrep.exitCode < 0 for qrep in report.quantaReports))
        self.assertTrue(all(qrep.exceptionInfo is None for qrep in report.quantaReports))

    def test_mpexec_failure(self):
        """Failure in one task should not stop other tasks."""
        task_node = TaskNodeMock()
        task_node_fail = TaskNodeMock(task_class=TaskMockFail)
        qgraph = QuantumGraphMock(self._make_qdata([task_node, task_node_fail, task_node]))

        qexec = QuantumExecutorMock(mp=True)
        mpexec = MPGraphExecutor(numProc=3, timeout=100, quantumExecutor=qexec)
        with self.assertRaisesRegex(MPGraphExecutorError, "One or more tasks failed"):
            mpexec.execute(qgraph)
        self.assertCountEqual(qexec.getDataIds("detector"), [0, 2])
        report = mpexec.getReport()
        self._assert_top_level(report, ExecutionStatus.FAILURE, self._EXECUTOR_ERROR)
        self.assertGreater(len(report.quantaReports), 0)
        self.assertEqual(_count_status(report, ExecutionStatus.FAILURE), 1)
        self.assertEqual(_count_status(report, ExecutionStatus.SUCCESS), 2)
        self.assertTrue(any(qrep.exitCode > 0 for qrep in report.quantaReports))
        self.assertTrue(any(qrep.exceptionInfo is not None for qrep in report.quantaReports))

    def _make_failure_dep_graph(self):
        """Return a five-node graph where node 1 fails.

        Node 2 depends on node 1, and node 4 depends on nodes 2 and 3.
        """
        task_node = TaskNodeMock()
        task_node_fail = TaskNodeMock(task_class=TaskMockFail)
        qdata = self._make_qdata([task_node, task_node_fail, task_node, task_node, task_node])
        qdata[2].dependencies.add(1)
        qdata[4].dependencies.add(3)
        qdata[4].dependencies.add(2)
        return QuantumGraphMock(qdata)

    def test_mpexec_failure_dep(self):
        """Failure in one task should skip dependents."""
        qgraph = self._make_failure_dep_graph()

        qexec = QuantumExecutorMock(mp=True)
        mpexec = MPGraphExecutor(numProc=3, timeout=100, quantumExecutor=qexec)
        with self.assertRaisesRegex(MPGraphExecutorError, "One or more tasks failed"):
            mpexec.execute(qgraph)
        self.assertCountEqual(qexec.getDataIds("detector"), [0, 3])
        report = mpexec.getReport()
        self._assert_top_level(report, ExecutionStatus.FAILURE, self._EXECUTOR_ERROR)
        # The two quanta downstream of the failure are expected with SKIPPED
        # status.
        self.assertGreater(len(report.quantaReports), 0)
        self.assertEqual(_count_status(report, ExecutionStatus.FAILURE), 1)
        self.assertEqual(_count_status(report, ExecutionStatus.SUCCESS), 2)
        self.assertEqual(_count_status(report, ExecutionStatus.SKIPPED), 2)
        self.assertTrue(any(qrep.exitCode > 0 for qrep in report.quantaReports))
        self.assertTrue(any(qrep.exceptionInfo is not None for qrep in report.quantaReports))

    def test_mpexec_failure_dep_nomp(self):
        """Failure in one task should skip dependents, in-process version."""
        qgraph = self._make_failure_dep_graph()

        qexec = QuantumExecutorMock()
        mpexec = MPGraphExecutor(numProc=1, timeout=100, quantumExecutor=qexec)
        with self.assertRaisesRegex(MPGraphExecutorError, "One or more tasks failed"):
            mpexec.execute(qgraph)
        self.assertCountEqual(qexec.getDataIds("detector"), [0, 3])
        report = mpexec.getReport()
        self._assert_top_level(report, ExecutionStatus.FAILURE, self._EXECUTOR_ERROR)
        # The two quanta downstream of the failure are expected with SKIPPED
        # status.
        self.assertGreater(len(report.quantaReports), 0)
        self.assertEqual(_count_status(report, ExecutionStatus.FAILURE), 1)
        self.assertEqual(_count_status(report, ExecutionStatus.SUCCESS), 2)
        self.assertEqual(_count_status(report, ExecutionStatus.SKIPPED), 2)
        self.assertTrue(all(qrep.exitCode is None for qrep in report.quantaReports))
        self.assertTrue(any(qrep.exceptionInfo is not None for qrep in report.quantaReports))

    def test_mpexec_failure_failfast(self):
        """Fast fail stops quickly.

        Timing delay of task #3 should be sufficient to process
        failure and raise exception.
        """
        task_node = TaskNodeMock()
        task_node_fail = TaskNodeMock(task_class=TaskMockFail)
        task_node_long_sleep = TaskNodeMock(task_class=TaskMockLongSleep)
        qdata = self._make_qdata([task_node, task_node_fail, task_node, task_node_long_sleep, task_node])
        qdata[1].dependencies.add(0)
        qdata[2].dependencies.add(1)
        qdata[4].dependencies.add(3)
        qdata[4].dependencies.add(2)

        qgraph = QuantumGraphMock(qdata)

        qexec = QuantumExecutorMock(mp=True)
        mpexec = MPGraphExecutor(numProc=3, timeout=100, quantumExecutor=qexec, failFast=True)
        with self.assertRaisesRegex(MPGraphExecutorError, "failed, exit code=1"):
            mpexec.execute(qgraph)
        self.assertCountEqual(qexec.getDataIds("detector"), [0])
        report = mpexec.getReport()
        self._assert_top_level(report, ExecutionStatus.FAILURE, self._EXECUTOR_ERROR)
        # Only the presence of the failed quantum's report is checked here.
        self.assertGreater(len(report.quantaReports), 0)
        self.assertEqual(_count_status(report, ExecutionStatus.FAILURE), 1)
        self.assertTrue(any(qrep.exitCode > 0 for qrep in report.quantaReports))
        self.assertTrue(any(qrep.exceptionInfo is not None for qrep in report.quantaReports))

    def _make_crash_graph(self):
        """Return a three-node graph whose middle node crashes."""
        task_node = TaskNodeMock()
        task_node_crash = TaskNodeMock(task_class=TaskMockCrash)
        return QuantumGraphMock(self._make_qdata([task_node, task_node_crash, task_node]))

    def test_mpexec_crash(self):
        """Check task crash due to signal."""
        qgraph = self._make_crash_graph()

        qexec = QuantumExecutorMock(mp=True)
        mpexec = MPGraphExecutor(numProc=3, timeout=100, quantumExecutor=qexec)
        with self.assertRaisesRegex(MPGraphExecutorError, "One or more tasks failed"):
            mpexec.execute(qgraph)
        report = mpexec.getReport()
        self._assert_top_level(report, ExecutionStatus.FAILURE, self._EXECUTOR_ERROR)
        self.assertGreater(len(report.quantaReports), 0)
        self.assertEqual(_count_status(report, ExecutionStatus.FAILURE), 1)
        self.assertEqual(_count_status(report, ExecutionStatus.SUCCESS), 2)
        # A process killed by a signal has the negated signal number as its
        # exit code and leaves no exception info behind.
        self.assertTrue(any(qrep.exitCode == -signal.SIGILL for qrep in report.quantaReports))
        self.assertTrue(all(qrep.exceptionInfo is None for qrep in report.quantaReports))

    def test_mpexec_crash_failfast(self):
        """Check task crash due to signal with --fail-fast."""
        qgraph = self._make_crash_graph()

        qexec = QuantumExecutorMock(mp=True)
        mpexec = MPGraphExecutor(numProc=3, timeout=100, quantumExecutor=qexec, failFast=True)
        with self.assertRaisesRegex(MPGraphExecutorError, "failed, killed by signal 4 .Illegal instruction"):
            mpexec.execute(qgraph)
        report = mpexec.getReport()
        self._assert_top_level(report, ExecutionStatus.FAILURE, self._EXECUTOR_ERROR)
        self.assertEqual(_count_status(report, ExecutionStatus.FAILURE), 1)
        self.assertTrue(any(qrep.exitCode == -signal.SIGILL for qrep in report.quantaReports))
        self.assertTrue(all(qrep.exceptionInfo is None for qrep in report.quantaReports))

    def test_mpexec_num_fd(self):
        """Check that number of open files stays reasonable."""
        qgraph = QuantumGraphMock(self._make_qdata([TaskNodeMock()] * 20))

        this_proc = psutil.Process()
        num_fds_0 = this_proc.num_fds()

        # run in multi-process mode, the order of results is not defined
        qexec = QuantumExecutorMock(mp=True)
        mpexec = MPGraphExecutor(numProc=3, timeout=100, quantumExecutor=qexec)
        mpexec.execute(qgraph)

        num_fds_1 = this_proc.num_fds()
        # They should be the same but allow small growth just in case.
        # Without DM-26728 fix the difference would be equal to number of
        # quanta (20).
        self.assertLess(num_fds_1 - num_fds_0, 5)

633 

634 

class SingleQuantumExecutorTestCase(unittest.TestCase):
    """Tests for SingleQuantumExecutor implementation."""

    instrument = "lsst.pipe.base.tests.simpleQGraph.SimpleInstrument"

    def setUp(self):
        # Fresh temporary directory under TESTDIR for every test.
        self.root = makeTestTempDir(TESTDIR)

    def tearDown(self):
        removeTestTempDir(self.root)

    def _make_single_node(self):
        """Build a one-quantum graph and return ``(butler, node)``."""
        n_quanta = 1
        butler, qgraph = makeSimpleQGraph(n_quanta, root=self.root, instrument=self.instrument)

        graph_nodes = list(qgraph)
        self.assertEqual(len(graph_nodes), n_quanta)
        return butler, graph_nodes[0]

    def _single_output_ref(self, butler):
        """Query "add_dataset1" in the butler's run collection.

        Asserts that exactly one dataset is found and returns its ref.
        """
        found = list(butler.registry.queryDatasets("add_dataset1", collections=butler.run))
        self.assertEqual(len(found), 1)
        return found[0]

    def test_simple_execute(self) -> None:
        """Run execute() method in simplest setup."""
        butler, node = self._make_single_node()

        factory = AddTaskFactoryMock()
        SingleQuantumExecutor(butler, factory).execute(node.task_node, node.quantum)
        self.assertEqual(factory.countExec, 1)

        # There must be one dataset of task's output connection
        self._single_output_ref(butler)

    def test_skip_existing_execute(self) -> None:
        """Run execute() method twice, with skip_existing_in."""
        butler, node = self._make_single_node()

        factory = AddTaskFactoryMock()
        SingleQuantumExecutor(butler, factory).execute(node.task_node, node.quantum)
        self.assertEqual(factory.countExec, 1)

        first_id = self._single_output_ref(butler).id

        # Re-run it with skipExistingIn, it should not run.
        assert butler.run is not None
        rerun = SingleQuantumExecutor(butler, factory, skipExistingIn=[butler.run])
        rerun.execute(node.task_node, node.quantum)
        self.assertEqual(factory.countExec, 1)

        second_id = self._single_output_ref(butler).id
        self.assertEqual(first_id, second_id)

    def test_clobber_outputs_execute(self) -> None:
        """Run execute() method twice, with clobber_outputs."""
        butler, node = self._make_single_node()

        factory = AddTaskFactoryMock()
        SingleQuantumExecutor(butler, factory).execute(node.task_node, node.quantum)
        self.assertEqual(factory.countExec, 1)

        first_ref = self._single_output_ref(butler)
        first_id = first_ref.id

        original_dataset = butler.get(first_ref)

        # Remove the dataset ourself, and replace it with something
        # different so we can check later whether it got replaced.
        butler.pruneDatasets([first_ref], disassociate=False, unstore=True, purge=False)
        replacement = original_dataset + 10
        butler.put(replacement, first_ref)

        # Re-run it with clobberOutputs and skipExistingIn, it should not
        # clobber but should skip instead.
        assert butler.run is not None
        skipping = SingleQuantumExecutor(
            butler, factory, skipExistingIn=[butler.run], clobberOutputs=True
        )
        skipping.execute(node.task_node, node.quantum)
        self.assertEqual(factory.countExec, 1)

        second_ref = self._single_output_ref(butler)
        second_id = second_ref.id
        self.assertEqual(first_id, second_id)

        second_dataset = butler.get(second_ref)
        self.assertEqual(list(second_dataset), list(replacement))

        # Re-run it with clobberOutputs but without skipExistingIn, it should
        # clobber.
        assert butler.run is not None
        clobbering = SingleQuantumExecutor(butler, factory, clobberOutputs=True)
        clobbering.execute(node.task_node, node.quantum)
        self.assertEqual(factory.countExec, 2)

        third_ref = self._single_output_ref(butler)
        third_id = third_ref.id

        third_dataset = butler.get(third_ref)
        self.assertEqual(list(third_dataset), list(original_dataset))

        # No change in UUID even after replacement
        self.assertEqual(first_id, third_id)

752 

753 

def setup_module(module):
    """Make "spawn" the multiprocessing start method for this module.

    The method is set with ``force=True``, so it replaces any start method
    that was already selected. This can be removed when Python 3.14 changes
    the default.

    Parameters
    ----------
    module : `~types.ModuleType`
        Module to set up.
    """
    multiprocessing.set_start_method(method="spawn", force=True)

765 

766 

if __name__ == "__main__":
    # Running as a script: the start method is set without force=True.
    multiprocessing.set_start_method(method="spawn")
    unittest.main()