Coverage for python/lsst/ctrl/mpexec/mpGraphExecutor.py: 14%

287 statements  

# This file is part of ctrl_mpexec.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

__all__ = ["MPGraphExecutor", "MPGraphExecutorError", "MPTimeoutError"]

import gc
import importlib
import logging
import multiprocessing
import pickle
import signal
import sys
import time
from enum import Enum
from typing import Optional

from lsst.daf.butler.cli.cliLog import CliLog
from lsst.pipe.base import InvalidQuantumError
from lsst.pipe.base.graph.graph import QuantumGraph
from lsst.utils.threads import disable_implicit_threading

# -----------------------------
# Imports for other modules --
# -----------------------------
from .quantumGraphExecutor import QuantumGraphExecutor
from .reports import ExecutionStatus, QuantumReport, Report

_LOG = logging.getLogger(__name__)


# Possible states for the executing task:
# - PENDING: job has not started yet
# - RUNNING: job is currently executing
# - FINISHED: job finished successfully
# - FAILED: job execution failed (process returned non-zero status)
# - TIMED_OUT: job was killed because it exceeded the execution time limit
# - FAILED_DEP: one of the dependencies of this job has failed/timed out
JobState = Enum("JobState", "PENDING RUNNING FINISHED FAILED TIMED_OUT FAILED_DEP")


class _Job:
    """Class representing a job running a single task.

    Parameters
    ----------
    qnode : `~lsst.pipe.base.QuantumNode`
        Quantum and some associated information.
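
    Examples
    --------
    The lifecycle expected by the executor code in this module (a sketch
    for orientation, not a tested snippet)::

        job = _Job(qnode)
        job.start(butler, quantumExecutor)
        while job.process.is_alive():
            time.sleep(0.1)
        report = job.report()  # must be called before cleanup()
        job.cleanup()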

66 """ 

67 

68 def __init__(self, qnode): 

69 self.qnode = qnode 

70 self.process = None 

71 self._state = JobState.PENDING 

72 self.started = None 

73 self._rcv_conn = None 

74 self._terminated = False 

75 

76 @property 

77 def state(self): 

78 """Job processing state (JobState)""" 

79 return self._state 

80 

81 @property 

82 def terminated(self): 

83 """Return True if job was killed by stop() method and negative exit 

84 code is returned from child process. (`bool`)""" 

85 return self._terminated and self.process.exitcode < 0 

86 

87 def start(self, butler, quantumExecutor, startMethod=None): 

88 """Start process which runs the task. 

89 

90 Parameters 

91 ---------- 

92 butler : `lsst.daf.butler.Butler` 

93 Data butler instance. 

94 quantumExecutor : `QuantumExecutor` 

95 Executor for single quantum. 

96 startMethod : `str`, optional 

97 Start method from `multiprocessing` module. 

98 """ 

99 # Unpickling of quantum has to happen after butler, this is why 

100 # it is pickled manually here. 

101 quantum_pickle = pickle.dumps(self.qnode.quantum) 

102 taskDef = self.qnode.taskDef 
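        # Pipe(False) makes a one-way pipe: this process keeps the receiving
        # end, the child gets the sending end for the job report.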
        self._rcv_conn, snd_conn = multiprocessing.Pipe(False)
        logConfigState = CliLog.configState
        mp_ctx = multiprocessing.get_context(startMethod)
        self.process = mp_ctx.Process(
            target=_Job._executeJob,
            args=(quantumExecutor, taskDef, quantum_pickle, butler, logConfigState, snd_conn),
            name=f"task-{self.qnode.quantum.dataId}",
        )
        self.process.start()
        self.started = time.time()
        self._state = JobState.RUNNING

    @staticmethod
    def _executeJob(quantumExecutor, taskDef, quantum_pickle, butler, logConfigState, snd_conn):
        """Execute a job with arguments.

        Parameters
        ----------
        quantumExecutor : `QuantumExecutor`
            Executor for a single quantum.
        taskDef : `~lsst.pipe.base.TaskDef`
            Task definition structure.
        quantum_pickle : `bytes`
            Quantum for this task execution in pickled form.
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        logConfigState : `list`
            Logging configuration state from the parent process, replayed if
            this is a freshly spawned process.
        snd_conn : `multiprocessing.Connection`
            Connection to send the job report to the parent process.
        """
        if logConfigState and not CliLog.configState:
            # Means that we are in a newly spawned Python process and have to
            # re-initialize logging.
            CliLog.replayConfigState(logConfigState)

        # Have to reset the connection pool to avoid sharing connections.
        if butler is not None:
            butler.registry.resetConnectionPool()

        quantum = pickle.loads(quantum_pickle)
        try:
            quantumExecutor.execute(taskDef, quantum, butler)
        finally:
            # If sending fails we do not want this new exception to be
            # exposed.
            try:
                report = quantumExecutor.getReport()
                snd_conn.send(report)
            except Exception:
                pass

    def stop(self):
        """Stop the process."""
        self.process.terminate()
        # Give it one second to finish, then KILL.
        for i in range(10):
            time.sleep(0.1)
            if not self.process.is_alive():
                break
        else:
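            # The "else" of a for loop runs only when the loop was not broken
            # out of, i.e. the process is still alive after ~1 second.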
            _LOG.debug("Killing process %s", self.process.name)
            self.process.kill()
        self._terminated = True

    def cleanup(self):
        """Release process resources; has to be called for each finished
        process.
        """
        if self.process and not self.process.is_alive():
            self.process.close()
            self.process = None
            self._rcv_conn = None

    def report(self) -> QuantumReport:
        """Return the task report; must be called after the process finishes
        and before cleanup().
        """
        try:
            report = self._rcv_conn.recv()
            report.exitCode = self.process.exitcode
        except Exception:
            # Likely the process was killed, but there may be other reasons.
            report = QuantumReport.from_exit_code(
                exitCode=self.process.exitcode,
                dataId=self.qnode.quantum.dataId,
                taskLabel=self.qnode.taskDef.label,
            )
        if self.terminated:
            # Means it was killed; assume it was due to a timeout.
            report.status = ExecutionStatus.TIMEOUT
        return report

    def failMessage(self):
        """Return a message describing the task failure."""
        exitcode = self.process.exitcode
        if exitcode < 0:
            # A negative exit code means the process was killed by a signal.
            signum = -exitcode
            msg = f"Task {self} failed, killed by signal {signum}"
            # Just in case this is some very odd signal, expect ValueError.
            try:
                strsignal = signal.strsignal(signum)
                msg = f"{msg} ({strsignal})"
            except ValueError:
                pass
        elif exitcode > 0:
            msg = f"Task {self} failed, exit code={exitcode}"
        else:
            msg = None
        return msg

    def __str__(self):
        return f"<{self.qnode.taskDef} dataId={self.qnode.quantum.dataId}>"


class _JobList:
    """Simple list of _Job instances with a few convenience methods.

    Parameters
    ----------
    iterable : iterable of `~lsst.pipe.base.QuantumNode`
        Sequence of Quanta to execute. This has to be ordered according to
        task dependencies.
    """

    def __init__(self, iterable):
        self.jobs = [_Job(qnode) for qnode in iterable]
        self.pending = self.jobs[:]
        self.running = []
        self.finishedNodes = set()
        self.failedNodes = set()
        self.timedOutNodes = set()

    def submit(self, job, butler, quantumExecutor, startMethod=None):
        """Submit one more job for execution.

        Parameters
        ----------
        job : `_Job`
            Job to submit.
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        quantumExecutor : `QuantumExecutor`
            Executor for a single quantum.
        startMethod : `str`, optional
            Start method from the `multiprocessing` module.
        """
        # This will raise if the job is not in the pending list.
        self.pending.remove(job)
        job.start(butler, quantumExecutor, startMethod)
        self.running.append(job)

    def setJobState(self, job, state):
        """Update job state.

        Parameters
        ----------
        job : `_Job`
            Job to update.
        state : `JobState`
            New job state; note that only FINISHED, FAILED, TIMED_OUT, or
            FAILED_DEP is acceptable.
        """
        allowedStates = (JobState.FINISHED, JobState.FAILED, JobState.TIMED_OUT, JobState.FAILED_DEP)
        assert state in allowedStates, f"State {state} not allowed here"

        # Remove the job from the pending/running lists.
        if job.state == JobState.PENDING:
            self.pending.remove(job)
        elif job.state == JobState.RUNNING:
            self.running.remove(job)

        qnode = job.qnode
        # It should not be in any of these sets, but discard just in case.
        self.finishedNodes.discard(qnode)
        self.failedNodes.discard(qnode)
        self.timedOutNodes.discard(qnode)

        job._state = state
        if state == JobState.FINISHED:
            self.finishedNodes.add(qnode)
        elif state == JobState.FAILED:
            self.failedNodes.add(qnode)
        elif state == JobState.FAILED_DEP:
            self.failedNodes.add(qnode)
        elif state == JobState.TIMED_OUT:
            self.failedNodes.add(qnode)
            self.timedOutNodes.add(qnode)
        else:
            raise ValueError(f"Unexpected state value: {state}")

    def cleanup(self):
        """Do periodic cleanup for jobs that did not finish correctly.

        If timed-out jobs were killed but take too long to stop then the
        regular cleanup will not work for them. Here we check all timed-out
        jobs periodically and do cleanup for those that managed to die by
        this time.
        """
        for job in self.jobs:
            if job.state == JobState.TIMED_OUT and job.process is not None:
                job.cleanup()


class MPGraphExecutorError(Exception):
    """Exception class for errors raised by MPGraphExecutor."""

    pass


class MPTimeoutError(MPGraphExecutorError):
    """Exception raised when task execution times out."""

    pass


class MPGraphExecutor(QuantumGraphExecutor):
    """Implementation of QuantumGraphExecutor using same-host multiprocess
    execution of Quanta.

    Parameters
    ----------
    numProc : `int`
        Number of processes to use for executing tasks.
    timeout : `float`
        Time in seconds to wait for tasks to finish.
    quantumExecutor : `QuantumExecutor`
        Executor for a single quantum. For multiprocess-style execution when
        ``numProc`` is greater than one this instance must support pickle.
    startMethod : `str`, optional
        Start method from the `multiprocessing` module; `None` selects the
        best one for the current platform.
    failFast : `bool`, optional
        If set to ``True`` then stop processing on the first error from any
        task.
    pdb : `str`, optional
        Debugger to import and use (via the ``post_mortem`` function) in the
        event of an exception.
    executionGraphFixup : `ExecutionGraphFixup`, optional
        Instance used for modification of the execution graph.
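
    Examples
    --------
    A minimal usage sketch; ``my_quantum_executor``, ``quantum_graph``, and
    ``butler`` are assumed stand-ins for a concrete `QuantumExecutor`, a
    `QuantumGraph`, and a data butler::

        executor = MPGraphExecutor(
            numProc=4, timeout=3600.0, quantumExecutor=my_quantum_executor
        )
        executor.execute(quantum_graph, butler)
        report = executor.getReport()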

339 """ 

340 

341 def __init__( 

342 self, 

343 numProc, 

344 timeout, 

345 quantumExecutor, 

346 *, 

347 startMethod=None, 

348 failFast=False, 

349 pdb=None, 

350 executionGraphFixup=None, 

351 ): 

352 self.numProc = numProc 

353 self.timeout = timeout 

354 self.quantumExecutor = quantumExecutor 

355 self.failFast = failFast 

356 self.pdb = pdb 

357 self.executionGraphFixup = executionGraphFixup 

358 self.report: Optional[Report] = None 

359 

360 # We set default start method as spawn for MacOS and fork for Linux; 

361 # None for all other platforms to use multiprocessing default. 

362 if startMethod is None: 

363 methods = dict(linux="fork", darwin="spawn") 

364 startMethod = methods.get(sys.platform) 

365 self.startMethod = startMethod 

366 

367 def execute(self, graph, butler): 

368 # Docstring inherited from QuantumGraphExecutor.execute 

369 graph = self._fixupQuanta(graph) 

370 self.report = Report() 

371 try: 

372 if self.numProc > 1: 

373 self._executeQuantaMP(graph, butler) 

374 else: 

375 self._executeQuantaInProcess(graph, butler) 

376 except Exception as exc: 

377 self.report.set_exception(exc) 

378 raise 

379 

380 def _fixupQuanta(self, graph: QuantumGraph): 

381 """Call fixup code to modify execution graph. 

382 

383 Parameters 

384 ---------- 

385 graph : `QuantumGraph` 

386 `QuantumGraph` to modify 

387 

388 Returns 

389 ------- 

390 graph : `QuantumGraph` 

391 Modified `QuantumGraph`. 

392 

393 Raises 

394 ------ 

395 MPGraphExecutorError 

396 Raised if execution graph cannot be ordered after modification, 

397 i.e. it has dependency cycles. 
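
        Examples
        --------
        A sketch of what a fixup might look like; ``MyFixup`` is an
        illustration only, and ``fixupQuanta`` is the hook this method
        relies on::

            class MyFixup(ExecutionGraphFixup):
                def fixupQuanta(self, graph):
                    # A real fixup would add or reorder dependencies between
                    # quanta here; returning the graph unchanged is trivial.
                    return graph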

398 """ 

399 if not self.executionGraphFixup: 

400 return graph 

401 

402 _LOG.debug("Call execution graph fixup method") 

403 graph = self.executionGraphFixup.fixupQuanta(graph) 

404 

405 # Detect if there is now a cycle created within the graph 

406 if graph.findCycle(): 

407 raise MPGraphExecutorError("Updated execution graph has dependency cycle.") 

408 

409 return graph 

410 

411 def _executeQuantaInProcess(self, graph, butler): 

412 """Execute all Quanta in current process. 

413 

414 Parameters 

415 ---------- 

416 graph : `QuantumGraph` 

417 `QuantumGraph` that is to be executed 

418 butler : `lsst.daf.butler.Butler` 

419 Data butler instance 

420 """ 

421 successCount, totalCount = 0, len(graph) 

422 failedNodes = set() 

423 for qnode in graph: 

424 

425 # Any failed inputs mean that the quantum has to be skipped. 

426 inputNodes = graph.determineInputsToQuantumNode(qnode) 

427 if inputNodes & failedNodes: 

428 _LOG.error( 

429 "Upstream job failed for task <%s dataId=%s>, skipping this task.", 

430 qnode.taskDef, 

431 qnode.quantum.dataId, 

432 ) 

433 failedNodes.add(qnode) 

434 quantum_report = QuantumReport( 

435 status=ExecutionStatus.SKIPPED, dataId=qnode.quantum.dataId, taskLabel=qnode.taskDef.label 

436 ) 

437 self.report.quantaReports.append(quantum_report) 

438 continue 

439 

440 _LOG.debug("Executing %s", qnode) 

441 try: 

442 self.quantumExecutor.execute(qnode.taskDef, qnode.quantum, butler) 

443 successCount += 1 

444 except Exception as exc: 

445 if self.pdb and sys.stdin.isatty() and sys.stdout.isatty(): 

446 _LOG.error( 

447 "Task <%s dataId=%s> failed; dropping into pdb.", 

448 qnode.taskDef, 

449 qnode.quantum.dataId, 

450 exc_info=exc, 

451 ) 
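                    # Import the debugger module named by self.pdb; it is
                    # expected to provide a post_mortem() function.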
                    try:
                        pdb = importlib.import_module(self.pdb)
                    except ImportError as imp_exc:
                        raise MPGraphExecutorError(
                            f"Unable to import specified debugger module ({self.pdb}): {imp_exc}"
                        ) from exc
                    if not hasattr(pdb, "post_mortem"):
                        raise MPGraphExecutorError(
                            f"Specified debugger module ({self.pdb}) can't debug with post_mortem",
                        ) from exc
                    pdb.post_mortem(exc.__traceback__)
                failedNodes.add(qnode)
                self.report.status = ExecutionStatus.FAILURE
                if self.failFast:
                    raise MPGraphExecutorError(
                        f"Task <{qnode.taskDef} dataId={qnode.quantum.dataId}> failed."
                    ) from exc
                else:
                    # Note that there could be exception safety issues, which
                    # we presently ignore.
                    _LOG.error(
                        "Task <%s dataId=%s> failed; processing will continue for remaining tasks.",
                        qnode.taskDef,
                        qnode.quantum.dataId,
                        exc_info=exc,
                    )
            finally:
                # sqlalchemy has some objects that can last until a garbage
                # collection cycle is run, which can happen at unpredictable
                # times; run a collection loop here explicitly.
                gc.collect()

            quantum_report = self.quantumExecutor.getReport()
            if quantum_report:
                self.report.quantaReports.append(quantum_report)

        _LOG.info(
            "Executed %d quanta successfully, %d failed and %d remain out of total %d quanta.",
            successCount,
            len(failedNodes),
            totalCount - successCount - len(failedNodes),
            totalCount,
        )

        # Raise an exception if there were any failures.
        if failedNodes:
            raise MPGraphExecutorError("One or more tasks failed during execution.")

    def _executeQuantaMP(self, graph, butler):
        """Execute all Quanta in separate processes.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` that is to be executed.
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        """

        disable_implicit_threading()  # To prevent thread contention

        _LOG.debug("Using %r for multiprocessing start method", self.startMethod)

        # Re-pack the input quantum data into a job list.
        jobs = _JobList(graph)

        # Check that all tasks can run in a sub-process.
        for job in jobs.jobs:
            taskDef = job.qnode.taskDef
            if not taskDef.taskClass.canMultiprocess:
                raise MPGraphExecutorError(
                    f"Task {taskDef.taskName} does not support multiprocessing; use single process"
                )

        finishedCount, failedCount = 0, 0
        while jobs.pending or jobs.running:

            _LOG.debug("#pendingJobs: %s", len(jobs.pending))
            _LOG.debug("#runningJobs: %s", len(jobs.running))

            # See if any jobs have finished.
            for job in jobs.running:
                if not job.process.is_alive():
                    _LOG.debug("finished: %s", job)
                    # Finished.
                    exitcode = job.process.exitcode
                    quantum_report = job.report()
                    self.report.quantaReports.append(quantum_report)
                    if exitcode == 0:
                        jobs.setJobState(job, JobState.FINISHED)
                        job.cleanup()
                        _LOG.debug("success: %s took %.3f seconds", job, time.time() - job.started)
                    else:
                        if job.terminated:
                            # Was killed due to timeout.
                            if self.report.status == ExecutionStatus.SUCCESS:
                                # Do not override global FAILURE status.
                                self.report.status = ExecutionStatus.TIMEOUT
                            message = f"Timeout ({self.timeout} sec) for task {job}, task is killed"
                            jobs.setJobState(job, JobState.TIMED_OUT)
                        else:
                            self.report.status = ExecutionStatus.FAILURE
                            # failMessage() has to be called before cleanup().
                            message = job.failMessage()
                            jobs.setJobState(job, JobState.FAILED)

                        job.cleanup()
                        _LOG.debug("failed: %s", job)
                        if self.failFast or exitcode == InvalidQuantumError.EXIT_CODE:
                            # Stop all running jobs.
                            for stopJob in jobs.running:
                                if stopJob is not job:
                                    stopJob.stop()
                            if job.state is JobState.TIMED_OUT:
                                raise MPTimeoutError(f"Timeout ({self.timeout} sec) for task {job}.")
                            else:
                                raise MPGraphExecutorError(message)
                        else:
                            _LOG.error("%s; processing will continue for remaining tasks.", message)
                else:
                    # Check for timeout.
                    now = time.time()
                    if now - job.started > self.timeout:
                        # Try to kill it; there is a chance that it finishes
                        # successfully before it gets killed. Exit status is
                        # handled by the code above on the next iteration.
                        _LOG.debug("Terminating job %s due to timeout", job)
                        job.stop()

            # Fail jobs whose inputs failed. This may need several iterations
            # if the ordering is not right; that will be handled on the next
            # loop iteration.
            if jobs.failedNodes:
                for job in jobs.pending:
                    jobInputNodes = graph.determineInputsToQuantumNode(job.qnode)
                    if jobInputNodes & jobs.failedNodes:
                        quantum_report = QuantumReport(
                            status=ExecutionStatus.SKIPPED,
                            dataId=job.qnode.quantum.dataId,
                            taskLabel=job.qnode.taskDef.label,
                        )
                        self.report.quantaReports.append(quantum_report)
                        jobs.setJobState(job, JobState.FAILED_DEP)
                        _LOG.error("Upstream job failed for task %s, skipping this task.", job)

            # See if we can start more jobs.
            if len(jobs.running) < self.numProc:
                for job in jobs.pending:
                    jobInputNodes = graph.determineInputsToQuantumNode(job.qnode)
                    if jobInputNodes <= jobs.finishedNodes:
                        # All dependencies have completed ("<=" is a subset
                        # test); we can start a new job.
                        if len(jobs.running) < self.numProc:
                            _LOG.debug("Submitting %s", job)
                            jobs.submit(job, butler, self.quantumExecutor, self.startMethod)
                        if len(jobs.running) >= self.numProc:
                            # Cannot start any more jobs, wait until something
                            # finishes.
                            break

            # Do cleanup for timed-out jobs if necessary.
            jobs.cleanup()

            # Print a progress message if something changed.
            newFinished, newFailed = len(jobs.finishedNodes), len(jobs.failedNodes)
            if (finishedCount, failedCount) != (newFinished, newFailed):
                finishedCount, failedCount = newFinished, newFailed
                totalCount = len(jobs.jobs)
                _LOG.info(
                    "Executed %d quanta successfully, %d failed and %d remain out of total %d quanta.",
                    finishedCount,
                    failedCount,
                    totalCount - finishedCount - failedCount,
                    totalCount,
                )

            # Here we want to wait until one of the running jobs completes,
            # but multiprocessing does not provide an API for that; for now
            # just sleep a little bit and go back to the loop.
            if jobs.running:
                time.sleep(0.1)

        if jobs.failedNodes:
            # Print the list of failed jobs.
            _LOG.error("Failed jobs:")
            for job in jobs.jobs:
                if job.state != JobState.FINISHED:
                    _LOG.error(" - %s: %s", job.state.name, job)

            # If any job failed, raise an exception. Since timedOutNodes is a
            # subset of failedNodes, equality means every failure was a
            # timeout.
            if jobs.failedNodes == jobs.timedOutNodes:
                raise MPTimeoutError("One or more tasks timed out during execution.")
            else:
                raise MPGraphExecutorError("One or more tasks failed or timed out during execution.")

    def getReport(self) -> Optional[Report]:
        # Docstring inherited from base class.
        if self.report is None:
            raise RuntimeError("getReport() called before execute()")
        return self.report