Coverage for tests/test_executors.py: 15%

429 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-08 02:55 -0700

1# This file is part of ctrl_mpexec. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <https://www.gnu.org/licenses/>. 

27 

28"""Unit tests for the MPGraphExecutor and SingleQuantumExecutor classes. 

29""" 

30 

31import faulthandler 

32import logging 

33import multiprocessing 

34import os 

35import signal 

36import sys 

37import time 

38import unittest 

39import warnings 

40from multiprocessing import Manager 

41 

42import networkx as nx 

43import psutil 

44from lsst.ctrl.mpexec import ( 

45 ExecutionStatus, 

46 MPGraphExecutor, 

47 MPGraphExecutorError, 

48 MPTimeoutError, 

49 QuantumExecutor, 

50 QuantumReport, 

51 SingleQuantumExecutor, 

52) 

53from lsst.ctrl.mpexec.execFixupDataId import ExecFixupDataId 

54from lsst.daf.butler.tests.utils import makeTestTempDir, removeTestTempDir 

55from lsst.pipe.base import NodeId, QgraphSummary, QgraphTaskSummary 

56from lsst.pipe.base.tests.simpleQGraph import AddTaskFactoryMock, makeSimpleQGraph 

57 

# Configure the root logger at DEBUG level so debug messages emitted during
# the tests are not filtered out.
logging.basicConfig(level=logging.DEBUG)

# Module-level logger used by the mock classes defined in this file.
_LOG = logging.getLogger(__name__)

# Absolute path of the directory that contains this test file; passed to
# makeTestTempDir() when a test needs a scratch directory.
TESTDIR = os.path.abspath(os.path.dirname(__file__))

63 

64 

class QuantumExecutorMock(QuantumExecutor):
    """Stand-in `QuantumExecutor` that records the quanta it is given.

    Parameters
    ----------
    mp : `bool`
        If `True`, executed quanta are collected in a shared
        `multiprocessing.Manager` list instead of a plain `list`.
    """

    def __init__(self, mp=False):
        # In multiprocess mode use a shared list, otherwise a plain one.
        self.quanta = Manager().list() if mp else []
        self.report = None
        self._execute_called = False

    def execute(self, task_node, /, quantum):
        """Run the mocked task for one quantum and remember the quantum.

        Parameters
        ----------
        task_node : `TaskNodeMock`
            Task definition; its ``task_class`` is instantiated and
            ``runQuantum()`` is called on the instance.
        quantum : `QuantumMock`
            Quantum to "execute".

        Returns
        -------
        quantum : `QuantumMock`
            The same quantum that was passed in.

        Raises
        ------
        Exception
            Anything raised by the task is re-raised after a failure
            report has been stored.
        """
        data_id = quantum.dataId
        _LOG.debug("QuantumExecutorMock.execute: task_node=%s dataId=%s", task_node, data_id)
        self._execute_called = True
        task_class = task_node.task_class
        if task_class:
            try:
                # only works for one of the TaskMock classes below
                task_class().runQuantum()
                self.report = QuantumReport(dataId=data_id, taskLabel=task_node.label)
            except Exception as exc:
                self.report = QuantumReport.from_exception(
                    exception=exc,
                    dataId=data_id,
                    taskLabel=task_node.label,
                )
                raise
        self.quanta.append(quantum)
        return quantum

    def getReport(self):
        """Return the report stored by the most recent `execute` call.

        Raises
        ------
        RuntimeError
            Raised if `execute` has not been called on this instance.
        """
        if self._execute_called:
            return self.report
        raise RuntimeError("getReport called before execute")

    def getDataIds(self, field):
        """Return values for dataId field for each visited quanta.

        Parameters
        ----------
        field : `str`
            Field to select.

        Returns
        -------
        values : `list`
            One value per recorded quantum, in recording order.
        """
        values = []
        for visited in self.quanta:
            values.append(visited.dataId[field])
        return values

115 

116 

class QuantumMock:
    """Mock equivalent of a `~lsst.daf.butler.Quantum`.

    Parameters
    ----------
    dataId : `dict`
        The Data ID of this quantum.
    """

    def __init__(self, dataId):
        self.dataId = dataId

    def __eq__(self, other):
        """Compare by data ID.

        Objects without a ``dataId`` attribute are not comparable, so
        `NotImplemented` is returned for them and Python falls back to its
        default handling instead of raising `AttributeError`.
        """
        try:
            other_data_id = other.dataId
        except AttributeError:
            return NotImplemented
        return self.dataId == other_data_id

    def __hash__(self):
        """Hash the data ID so that equal quanta hash equally."""
        # dict.__eq__ is order-insensitive, so sort the items to make the
        # hash independent of insertion order as well.
        return hash(tuple(sorted(self.dataId.items())))

135 

136 

class QuantumIterDataMock:
    """Minimal stand-in for QuantumIterData (a node of the quantum graph).

    Parameters
    ----------
    index : `int`
        The index of this mock.
    task_node : `TaskNodeMock`
        Mocked task definition.
    **dataId : `~typing.Any`
        The data ID of the mocked quantum.
    """

    def __init__(self, index, task_node, **dataId):
        self.index = index
        # The same object is exposed under both names: ``taskDef`` and
        # ``task_node``.
        self.taskDef = self.task_node = task_node
        self.quantum = QuantumMock(dataId)
        # Indices of the nodes this one depends on; filled in by tests.
        self.dependencies = set()
        self.nodeId = NodeId(index, "DummyBuildString")

157 

158 

class QuantumGraphMock:
    """Mock for quantum graph.

    Parameters
    ----------
    qdata : `~collections.abc.Iterable` of `QuantumIterDataMock`
        The nodes of the graph. Consecutive nodes are linked by an edge so
        that iteration follows the order in which they were given.
    """

    def __init__(self, qdata):
        # Materialize once so that any iterable (not only a sequence) works.
        nodes = list(qdata)
        self._graph = nx.DiGraph()
        # Add the nodes explicitly: building the graph from edges alone
        # would leave a single-node input with an empty graph.
        self._graph.add_nodes_from(nodes)
        # Chain the nodes in input order.
        self._graph.add_edges_from(zip(nodes, nodes[1:]))

    def __iter__(self):
        yield from nx.topological_sort(self._graph)

    def __len__(self):
        return len(self._graph)

    def findTaskDefByLabel(self, label):
        """Return the task definition of the first node with this label.

        Returns `None` if no node carries the label.
        """
        for q in self:
            if q.task_node.label == label:
                return q.taskDef
        return None

    def getQuantaForTask(self, taskDef):
        """Return the set of quanta of all nodes that belong to the task."""
        return {q.quantum for q in self.getNodesForTask(taskDef)}

    def getNodesForTask(self, taskDef):
        """Return the set of nodes whose task label matches ``taskDef``."""
        return {q for q in self if q.task_node.label == taskDef.label}

    @property
    def graph(self):
        """The underlying `networkx.DiGraph`."""
        return self._graph

    def findCycle(self):
        """Return an empty list; this mock never reports a cycle."""
        return []

    def determineInputsToQuantumNode(self, node):
        """Return the set of nodes listed in ``node.dependencies``.

        Dependencies are stored as node indices; they are resolved against
        the ``index`` attribute of the nodes of this graph in one pass.
        """
        wanted = node.dependencies
        return {other for other in self if other.index in wanted}

    def getSummary(self):
        """Build a `QgraphSummary` with fixed metadata and per-task counts.

        For every node the quantum count of its task is incremented by one,
        as is each of the fixed input and output dataset counters.
        """
        summary = QgraphSummary(
            graphID="1712445133.605479-3902002",
            cmdLine="mock_pipetask -a 1 -b 2 -c 3 4 5 6",
            pipeBaseVersion="1.1.1",
            creationUTC="",
            inputCollection=["mock_input"],
            outputCollection="mock_output",
            outputRun="mock_run",
        )
        for q in self:
            qts = summary.qgraphTaskSummaries.setdefault(
                q.taskDef.label, QgraphTaskSummary(taskLabel=q.taskDef.label)
            )
            qts.numQuanta += 1

            for k in ["in1", "in2", "in3"]:
                qts.numInputs[k] += 1

            for k in ["out1", "out2", "out3"]:
                qts.numOutputs[k] += 1

        return summary

235 

236 

class TaskMockMP:
    """Mock task that declares multiprocessing support and does no work."""

    canMultiprocess = True

    def runQuantum(self):
        """Log the call and return."""
        _LOG.debug("TaskMockMP.runQuantum")

245 

246 

class TaskMockFail:
    """Mock task whose ``runQuantum`` always raises `ValueError`."""

    canMultiprocess = True

    def runQuantum(self):
        """Log the call, then raise the expected failure."""
        _LOG.debug("TaskMockFail.runQuantum")
        failure = ValueError("expected failure")
        raise failure

255 

256 

class TaskMockCrash:
    """Mock task which crashes by raising ``SIGILL`` in its own process."""

    canMultiprocess = True

    def runQuantum(self):
        """Log the call, then send ``SIGILL`` to the current process."""
        _LOG.debug("TaskMockCrash.runQuantum")
        # Disable fault handler to suppress long scary traceback.
        faulthandler.disable()
        fatal_signal = signal.SIGILL
        signal.raise_signal(fatal_signal)

267 

268 

class TaskMockLongSleep:
    """Mock task that simulates a very long run by sleeping."""

    canMultiprocess = True

    def runQuantum(self):
        """Log the call, then sleep for 100 seconds."""
        _LOG.debug("TaskMockLongSleep.runQuantum")
        sleep_seconds = 100.0
        time.sleep(sleep_seconds)

277 

278 

class TaskMockNoMP:
    """Mock task that declares it cannot be run with multiprocessing.

    It only carries the ``canMultiprocess`` flag and defines no
    ``runQuantum`` method.
    """

    canMultiprocess = False

283 

284 

class TaskNodeMock:
    """Lightweight stand-in for a task definition in a pipeline graph.

    Parameters
    ----------
    label : `str`
        Label of the task in the pipeline.
    task_class : `type`
        Subclass of `lsst.pipe.base.PipelineTask`.
    config : `PipelineTaskConfig`, optional
        Configuration for the task.
    """

    def __init__(self, label="task1", task_class=TaskMockMP, config=None):
        self.label = label
        self.config = config
        # taskClass to look like TaskDef, task_class to look like TaskNode.
        self.taskClass = self.task_class = task_class

    def __str__(self):
        class_name = self.taskClass.__name__
        return f"TaskNodeMock({self.label}, {class_name})"

307 

308 

309def _count_status(report, status): 

310 """Count number of quanta witha a given status.""" 

311 return len([qrep for qrep in report.quantaReports if qrep.status is status]) 

312 

313 

class MPGraphExecutorTestCase(unittest.TestCase):
    """A test case for MPGraphExecutor class."""

    @staticmethod
    def _any_exit_code(report, predicate):
        """Check whether any quantum report has a matching exit code.

        Parameters
        ----------
        report
            Execution report whose ``quantaReports`` are inspected.
        predicate : `~collections.abc.Callable`
            Called with each exit code that is not `None`.

        Returns
        -------
        found : `bool`
            `True` if ``predicate`` holds for at least one exit code.

        Notes
        -----
        Reports whose ``exitCode`` is `None` are skipped, so ordering
        comparisons in ``predicate`` cannot raise `TypeError`.
        """
        return any(
            qrep.exitCode is not None and predicate(qrep.exitCode) for qrep in report.quantaReports
        )

    def test_mpexec_nomp(self):
        """Make simple graph and execute it in a single process."""
        task_node = TaskNodeMock()
        qgraph = QuantumGraphMock(
            [QuantumIterDataMock(index=i, task_node=task_node, detector=i) for i in range(3)]
        )

        # run in single-process mode
        qexec = QuantumExecutorMock()
        mpexec = MPGraphExecutor(numProc=1, timeout=100, quantumExecutor=qexec)
        mpexec.execute(qgraph)
        self.assertEqual(qexec.getDataIds("detector"), [0, 1, 2])
        report = mpexec.getReport()
        self.assertEqual(report.status, ExecutionStatus.SUCCESS)
        self.assertIsNone(report.exitCode)
        self.assertIsNone(report.exceptionInfo)
        self.assertEqual(len(report.quantaReports), 3)
        self.assertTrue(all(qrep.status == ExecutionStatus.SUCCESS for qrep in report.quantaReports))
        self.assertTrue(all(qrep.exitCode is None for qrep in report.quantaReports))
        self.assertTrue(all(qrep.exceptionInfo is None for qrep in report.quantaReports))
        self.assertTrue(all(qrep.taskLabel == "task1" for qrep in report.quantaReports))

    def test_mpexec_mp(self):
        """Make simple graph and execute it with several processes."""
        task_node = TaskNodeMock()
        qgraph = QuantumGraphMock(
            [QuantumIterDataMock(index=i, task_node=task_node, detector=i) for i in range(3)]
        )

        methods = ["spawn"]
        if sys.platform == "linux":
            methods.append("forkserver")

        for method in methods:
            with self.subTest(startMethod=method):
                # Run in multi-process mode, the order of results is not
                # defined.
                qexec = QuantumExecutorMock(mp=True)
                mpexec = MPGraphExecutor(numProc=3, timeout=100, quantumExecutor=qexec, startMethod=method)
                mpexec.execute(qgraph)
                self.assertCountEqual(qexec.getDataIds("detector"), [0, 1, 2])
                report = mpexec.getReport()
                self.assertEqual(report.status, ExecutionStatus.SUCCESS)
                self.assertIsNone(report.exitCode)
                self.assertIsNone(report.exceptionInfo)
                self.assertEqual(len(report.quantaReports), 3)
                self.assertTrue(all(qrep.status == ExecutionStatus.SUCCESS for qrep in report.quantaReports))
                self.assertTrue(all(qrep.exitCode == 0 for qrep in report.quantaReports))
                self.assertTrue(all(qrep.exceptionInfo is None for qrep in report.quantaReports))
                self.assertTrue(all(qrep.taskLabel == "task1" for qrep in report.quantaReports))

    def test_mpexec_nompsupport(self):
        """Try to run MP for task that has no MP support which should fail."""
        task_node = TaskNodeMock(task_class=TaskMockNoMP)
        qgraph = QuantumGraphMock(
            [QuantumIterDataMock(index=i, task_node=task_node, detector=i) for i in range(3)]
        )

        # run in multi-process mode
        qexec = QuantumExecutorMock()
        mpexec = MPGraphExecutor(numProc=3, timeout=100, quantumExecutor=qexec)
        with self.assertRaisesRegex(MPGraphExecutorError, "Task 'task1' does not support multiprocessing"):
            mpexec.execute(qgraph)

    def test_mpexec_fixup(self):
        """Make simple graph and execute, add dependencies by executing fixup
        code.
        """
        task_node = TaskNodeMock()

        for reverse in (False, True):
            # Report each ordering separately so that a failure in one does
            # not hide the result of the other.
            with self.subTest(reverse=reverse):
                qgraph = QuantumGraphMock(
                    [QuantumIterDataMock(index=i, task_node=task_node, detector=i) for i in range(3)]
                )

                qexec = QuantumExecutorMock()
                fixup = ExecFixupDataId("task1", "detector", reverse=reverse)
                mpexec = MPGraphExecutor(
                    numProc=1, timeout=100, quantumExecutor=qexec, executionGraphFixup=fixup
                )
                mpexec.execute(qgraph)

                expected = [0, 1, 2]
                if reverse:
                    expected = list(reversed(expected))
                self.assertEqual(qexec.getDataIds("detector"), expected)

    def test_mpexec_timeout(self):
        """Fail due to timeout."""
        task_node = TaskNodeMock()
        task_nodeSleep = TaskNodeMock(task_class=TaskMockLongSleep)
        qgraph = QuantumGraphMock(
            [
                QuantumIterDataMock(index=0, task_node=task_node, detector=0),
                QuantumIterDataMock(index=1, task_node=task_nodeSleep, detector=1),
                QuantumIterDataMock(index=2, task_node=task_node, detector=2),
            ]
        )

        # with failFast we'll get immediate MPTimeoutError
        qexec = QuantumExecutorMock(mp=True)
        mpexec = MPGraphExecutor(numProc=3, timeout=1, quantumExecutor=qexec, failFast=True)
        with self.assertRaises(MPTimeoutError):
            mpexec.execute(qgraph)
        report = mpexec.getReport()
        self.assertEqual(report.status, ExecutionStatus.TIMEOUT)
        self.assertEqual(report.exceptionInfo.className, "lsst.ctrl.mpexec.mpGraphExecutor.MPTimeoutError")
        self.assertGreater(len(report.quantaReports), 0)
        self.assertEqual(_count_status(report, ExecutionStatus.TIMEOUT), 1)
        self.assertTrue(self._any_exit_code(report, lambda code: code < 0))
        self.assertTrue(all(qrep.exceptionInfo is None for qrep in report.quantaReports))

        # with failFast=False exception happens after last task finishes
        qexec = QuantumExecutorMock(mp=True)
        mpexec = MPGraphExecutor(numProc=3, timeout=3, quantumExecutor=qexec, failFast=False)
        with self.assertRaises(MPTimeoutError):
            mpexec.execute(qgraph)
        # We expect two tasks (0 and 2) to finish successfully and one task to
        # timeout. Unfortunately on busy CPU there is no guarantee that tasks
        # finish on time, so expect more timeouts and issue a warning.
        detectorIds = set(qexec.getDataIds("detector"))
        self.assertLess(len(detectorIds), 3)
        if detectorIds != {0, 2}:
            warnings.warn(f"Possibly timed out tasks, expected [0, 2], received {detectorIds}")
        report = mpexec.getReport()
        self.assertEqual(report.status, ExecutionStatus.TIMEOUT)
        self.assertEqual(report.exceptionInfo.className, "lsst.ctrl.mpexec.mpGraphExecutor.MPTimeoutError")
        self.assertGreater(len(report.quantaReports), 0)
        self.assertGreater(_count_status(report, ExecutionStatus.TIMEOUT), 0)
        self.assertTrue(self._any_exit_code(report, lambda code: code < 0))
        self.assertTrue(all(qrep.exceptionInfo is None for qrep in report.quantaReports))

    def test_mpexec_failure(self):
        """Failure in one task should not stop other tasks."""
        task_node = TaskNodeMock()
        task_node_fail = TaskNodeMock(task_class=TaskMockFail)
        qgraph = QuantumGraphMock(
            [
                QuantumIterDataMock(index=0, task_node=task_node, detector=0),
                QuantumIterDataMock(index=1, task_node=task_node_fail, detector=1),
                QuantumIterDataMock(index=2, task_node=task_node, detector=2),
            ]
        )

        qexec = QuantumExecutorMock(mp=True)
        mpexec = MPGraphExecutor(numProc=3, timeout=100, quantumExecutor=qexec)
        with self.assertRaisesRegex(MPGraphExecutorError, "One or more tasks failed"):
            mpexec.execute(qgraph)
        self.assertCountEqual(qexec.getDataIds("detector"), [0, 2])
        report = mpexec.getReport()
        self.assertEqual(report.status, ExecutionStatus.FAILURE)
        self.assertEqual(
            report.exceptionInfo.className, "lsst.ctrl.mpexec.mpGraphExecutor.MPGraphExecutorError"
        )
        self.assertGreater(len(report.quantaReports), 0)
        self.assertEqual(_count_status(report, ExecutionStatus.FAILURE), 1)
        self.assertEqual(_count_status(report, ExecutionStatus.SUCCESS), 2)
        self.assertTrue(self._any_exit_code(report, lambda code: code > 0))
        self.assertTrue(any(qrep.exceptionInfo is not None for qrep in report.quantaReports))

    def test_mpexec_failure_dep(self):
        """Failure in one task should skip dependents."""
        task_node = TaskNodeMock()
        task_node_fail = TaskNodeMock(task_class=TaskMockFail)
        qdata = [
            QuantumIterDataMock(index=0, task_node=task_node, detector=0),
            QuantumIterDataMock(index=1, task_node=task_node_fail, detector=1),
            QuantumIterDataMock(index=2, task_node=task_node, detector=2),
            QuantumIterDataMock(index=3, task_node=task_node, detector=3),
            QuantumIterDataMock(index=4, task_node=task_node, detector=4),
        ]
        qdata[2].dependencies.add(1)
        qdata[4].dependencies.add(3)
        qdata[4].dependencies.add(2)

        qgraph = QuantumGraphMock(qdata)

        qexec = QuantumExecutorMock(mp=True)
        mpexec = MPGraphExecutor(numProc=3, timeout=100, quantumExecutor=qexec)
        with self.assertRaisesRegex(MPGraphExecutorError, "One or more tasks failed"):
            mpexec.execute(qgraph)
        self.assertCountEqual(qexec.getDataIds("detector"), [0, 3])
        report = mpexec.getReport()
        self.assertEqual(report.status, ExecutionStatus.FAILURE)
        self.assertEqual(
            report.exceptionInfo.className, "lsst.ctrl.mpexec.mpGraphExecutor.MPGraphExecutorError"
        )
        # Quanta 2 and 4 depend (directly or indirectly) on the failed
        # quantum 1 and are expected to be reported as SKIPPED.
        self.assertGreater(len(report.quantaReports), 0)
        self.assertEqual(_count_status(report, ExecutionStatus.FAILURE), 1)
        self.assertEqual(_count_status(report, ExecutionStatus.SUCCESS), 2)
        self.assertEqual(_count_status(report, ExecutionStatus.SKIPPED), 2)
        self.assertTrue(self._any_exit_code(report, lambda code: code > 0))
        self.assertTrue(any(qrep.exceptionInfo is not None for qrep in report.quantaReports))

    def test_mpexec_failure_dep_nomp(self):
        """Failure in one task should skip dependents, in-process version."""
        task_node = TaskNodeMock()
        task_node_fail = TaskNodeMock(task_class=TaskMockFail)
        qdata = [
            QuantumIterDataMock(index=0, task_node=task_node, detector=0),
            QuantumIterDataMock(index=1, task_node=task_node_fail, detector=1),
            QuantumIterDataMock(index=2, task_node=task_node, detector=2),
            QuantumIterDataMock(index=3, task_node=task_node, detector=3),
            QuantumIterDataMock(index=4, task_node=task_node, detector=4),
        ]
        qdata[2].dependencies.add(1)
        qdata[4].dependencies.add(3)
        qdata[4].dependencies.add(2)

        qgraph = QuantumGraphMock(qdata)

        qexec = QuantumExecutorMock()
        mpexec = MPGraphExecutor(numProc=1, timeout=100, quantumExecutor=qexec)
        with self.assertRaisesRegex(MPGraphExecutorError, "One or more tasks failed"):
            mpexec.execute(qgraph)
        self.assertCountEqual(qexec.getDataIds("detector"), [0, 3])
        report = mpexec.getReport()
        self.assertEqual(report.status, ExecutionStatus.FAILURE)
        self.assertEqual(
            report.exceptionInfo.className, "lsst.ctrl.mpexec.mpGraphExecutor.MPGraphExecutorError"
        )
        # Quanta 2 and 4 depend (directly or indirectly) on the failed
        # quantum 1 and are expected to be reported as SKIPPED.
        self.assertGreater(len(report.quantaReports), 0)
        self.assertEqual(_count_status(report, ExecutionStatus.FAILURE), 1)
        self.assertEqual(_count_status(report, ExecutionStatus.SUCCESS), 2)
        self.assertEqual(_count_status(report, ExecutionStatus.SKIPPED), 2)
        self.assertTrue(all(qrep.exitCode is None for qrep in report.quantaReports))
        self.assertTrue(any(qrep.exceptionInfo is not None for qrep in report.quantaReports))

    def test_mpexec_failure_failfast(self):
        """Fast fail stops quickly.

        Timing delay of task #3 should be sufficient to process
        failure and raise exception.
        """
        task_node = TaskNodeMock()
        task_node_fail = TaskNodeMock(task_class=TaskMockFail)
        task_nodeLongSleep = TaskNodeMock(task_class=TaskMockLongSleep)
        qdata = [
            QuantumIterDataMock(index=0, task_node=task_node, detector=0),
            QuantumIterDataMock(index=1, task_node=task_node_fail, detector=1),
            QuantumIterDataMock(index=2, task_node=task_node, detector=2),
            QuantumIterDataMock(index=3, task_node=task_nodeLongSleep, detector=3),
            QuantumIterDataMock(index=4, task_node=task_node, detector=4),
        ]
        qdata[1].dependencies.add(0)
        qdata[2].dependencies.add(1)
        qdata[4].dependencies.add(3)
        qdata[4].dependencies.add(2)

        qgraph = QuantumGraphMock(qdata)

        qexec = QuantumExecutorMock(mp=True)
        mpexec = MPGraphExecutor(numProc=3, timeout=100, quantumExecutor=qexec, failFast=True)
        with self.assertRaisesRegex(MPGraphExecutorError, "failed, exit code=1"):
            mpexec.execute(qgraph)
        self.assertCountEqual(qexec.getDataIds("detector"), [0])
        report = mpexec.getReport()
        self.assertEqual(report.status, ExecutionStatus.FAILURE)
        self.assertEqual(
            report.exceptionInfo.className, "lsst.ctrl.mpexec.mpGraphExecutor.MPGraphExecutorError"
        )
        # Only require a non-empty report that contains exactly one failure.
        self.assertGreater(len(report.quantaReports), 0)
        self.assertEqual(_count_status(report, ExecutionStatus.FAILURE), 1)
        self.assertTrue(self._any_exit_code(report, lambda code: code > 0))
        self.assertTrue(any(qrep.exceptionInfo is not None for qrep in report.quantaReports))

    def test_mpexec_crash(self):
        """Check task crash due to signal."""
        task_node = TaskNodeMock()
        task_node_crash = TaskNodeMock(task_class=TaskMockCrash)
        qgraph = QuantumGraphMock(
            [
                QuantumIterDataMock(index=0, task_node=task_node, detector=0),
                QuantumIterDataMock(index=1, task_node=task_node_crash, detector=1),
                QuantumIterDataMock(index=2, task_node=task_node, detector=2),
            ]
        )

        qexec = QuantumExecutorMock(mp=True)
        mpexec = MPGraphExecutor(numProc=3, timeout=100, quantumExecutor=qexec)
        with self.assertRaisesRegex(MPGraphExecutorError, "One or more tasks failed"):
            mpexec.execute(qgraph)
        report = mpexec.getReport()
        self.assertEqual(report.status, ExecutionStatus.FAILURE)
        self.assertEqual(
            report.exceptionInfo.className, "lsst.ctrl.mpexec.mpGraphExecutor.MPGraphExecutorError"
        )
        # The crashed quantum is expected to be reported as a failure and
        # the other two as successes.
        self.assertGreater(len(report.quantaReports), 0)
        self.assertEqual(_count_status(report, ExecutionStatus.FAILURE), 1)
        self.assertEqual(_count_status(report, ExecutionStatus.SUCCESS), 2)
        self.assertTrue(any(qrep.exitCode == -signal.SIGILL for qrep in report.quantaReports))
        self.assertTrue(all(qrep.exceptionInfo is None for qrep in report.quantaReports))

    def test_mpexec_crash_failfast(self):
        """Check task crash due to signal with --fail-fast."""
        task_node = TaskNodeMock()
        task_node_crash = TaskNodeMock(task_class=TaskMockCrash)
        qgraph = QuantumGraphMock(
            [
                QuantumIterDataMock(index=0, task_node=task_node, detector=0),
                QuantumIterDataMock(index=1, task_node=task_node_crash, detector=1),
                QuantumIterDataMock(index=2, task_node=task_node, detector=2),
            ]
        )

        qexec = QuantumExecutorMock(mp=True)
        mpexec = MPGraphExecutor(numProc=3, timeout=100, quantumExecutor=qexec, failFast=True)
        with self.assertRaisesRegex(MPGraphExecutorError, "failed, killed by signal 4 .Illegal instruction"):
            mpexec.execute(qgraph)
        report = mpexec.getReport()
        self.assertEqual(report.status, ExecutionStatus.FAILURE)
        self.assertEqual(
            report.exceptionInfo.className, "lsst.ctrl.mpexec.mpGraphExecutor.MPGraphExecutorError"
        )
        self.assertEqual(_count_status(report, ExecutionStatus.FAILURE), 1)
        self.assertTrue(any(qrep.exitCode == -signal.SIGILL for qrep in report.quantaReports))
        self.assertTrue(all(qrep.exceptionInfo is None for qrep in report.quantaReports))

    def test_mpexec_num_fd(self):
        """Check that number of open files stays reasonable."""
        task_node = TaskNodeMock()
        qgraph = QuantumGraphMock(
            [QuantumIterDataMock(index=i, task_node=task_node, detector=i) for i in range(20)]
        )

        this_proc = psutil.Process()
        num_fds_0 = this_proc.num_fds()

        # run in multi-process mode, the order of results is not defined
        qexec = QuantumExecutorMock(mp=True)
        mpexec = MPGraphExecutor(numProc=3, timeout=100, quantumExecutor=qexec)
        mpexec.execute(qgraph)

        num_fds_1 = this_proc.num_fds()
        # They should be the same but allow small growth just in case.
        # Without DM-26728 fix the difference would be equal to number of
        # quanta (20).
        self.assertLess(num_fds_1 - num_fds_0, 5)

657 

658 

class SingleQuantumExecutorTestCase(unittest.TestCase):
    """Exercise `SingleQuantumExecutor` on a graph with a single quantum."""

    instrument = "lsst.pipe.base.tests.simpleQGraph.SimpleInstrument"

    def setUp(self):
        # Temporary directory passed as ``root`` to makeSimpleQGraph();
        # removed again in tearDown().
        self.root = makeTestTempDir(TESTDIR)

    def tearDown(self):
        removeTestTempDir(self.root)

    def _single_output_ref(self, butler):
        """Return the only ``add_dataset1`` ref found in ``butler.run``.

        Fails the test unless the query yields exactly one dataset.
        """
        found = list(butler.registry.queryDatasets("add_dataset1", collections=butler.run))
        self.assertEqual(len(found), 1)
        return found[0]

    def _execute_first_time(self):
        """Build a one-quantum graph and execute its node once.

        Returns
        -------
        butler
            Butler returned by ``makeSimpleQGraph``.
        node
            The single node of the graph.
        task_factory : `AddTaskFactoryMock`
            Factory used for the execution; its ``countExec`` is 1.
        ref
            Ref of the single ``add_dataset1`` output.
        """
        quantum_count = 1
        butler, qgraph = makeSimpleQGraph(quantum_count, root=self.root, instrument=self.instrument)

        graph_nodes = list(qgraph)
        self.assertEqual(len(graph_nodes), quantum_count)
        node = graph_nodes[0]

        task_factory = AddTaskFactoryMock()
        first_executor = SingleQuantumExecutor(butler, task_factory)
        first_executor.execute(node.task_node, node.quantum)
        self.assertEqual(task_factory.countExec, 1)

        # There must be one dataset of task's output connection
        return butler, node, task_factory, self._single_output_ref(butler)

    def test_simple_execute(self) -> None:
        """Run execute() method in simplest setup."""
        self._execute_first_time()

    def test_skip_existing_execute(self) -> None:
        """Run execute() method twice, with skip_existing_in."""
        butler, node, task_factory, first_ref = self._execute_first_time()
        first_id = first_ref.id

        # Re-run it with skipExistingIn, it should not run.
        assert butler.run is not None
        skipping_executor = SingleQuantumExecutor(butler, task_factory, skipExistingIn=[butler.run])
        skipping_executor.execute(node.task_node, node.quantum)
        self.assertEqual(task_factory.countExec, 1)

        second_id = self._single_output_ref(butler).id
        self.assertEqual(first_id, second_id)

    def test_clobber_outputs_execute(self) -> None:
        """Run execute() method twice, with clobber_outputs."""
        butler, node, task_factory, first_ref = self._execute_first_time()
        first_id = first_ref.id

        original_dataset = butler.get(first_ref)

        # Remove the dataset ourself, and replace it with something
        # different so we can check later whether it got replaced.
        butler.pruneDatasets([first_ref], disassociate=False, unstore=True, purge=False)
        replacement = original_dataset + 10
        butler.put(replacement, first_ref)

        # Re-run it with clobberOutputs and skipExistingIn, it should not
        # clobber but should skip instead.
        assert butler.run is not None
        skipping_executor = SingleQuantumExecutor(
            butler, task_factory, skipExistingIn=[butler.run], clobberOutputs=True
        )
        skipping_executor.execute(node.task_node, node.quantum)
        self.assertEqual(task_factory.countExec, 1)

        second_ref = self._single_output_ref(butler)
        self.assertEqual(first_id, second_ref.id)

        second_dataset = butler.get(second_ref)
        self.assertEqual(list(second_dataset), list(replacement))

        # Re-run it with clobberOutputs but without skipExistingIn, it should
        # clobber.
        assert butler.run is not None
        clobbering_executor = SingleQuantumExecutor(butler, task_factory, clobberOutputs=True)
        clobbering_executor.execute(node.task_node, node.quantum)
        self.assertEqual(task_factory.countExec, 2)

        third_ref = self._single_output_ref(butler)
        third_id = third_ref.id

        third_dataset = butler.get(third_ref)
        self.assertEqual(list(third_dataset), list(original_dataset))

        # No change in UUID even after replacement
        self.assertEqual(first_id, third_id)

776 

777 

def setup_module(module):
    """Make ``spawn`` the multiprocessing start method for this module.

    The method is set with ``force=True``, so any start method chosen
    earlier is replaced. This can be removed when Python 3.14 changes the
    default.

    Parameters
    ----------
    module : `~types.ModuleType`
        Module to set up.
    """
    start_method = "spawn"
    multiprocessing.set_start_method(start_method, force=True)

789 

790 

if __name__ == "__main__":
    # Standalone run: select the "spawn" start method before any test
    # runs. Do not need to force start mode when running standalone, so
    # ``force=True`` is not passed here (unlike in setup_module()).
    multiprocessing.set_start_method("spawn")
    unittest.main()