Coverage for python/lsst/ctrl/mpexec/mpGraphExecutor.py: 15%

276 statements  

# This file is part of ctrl_mpexec.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

__all__ = ["MPGraphExecutor", "MPGraphExecutorError", "MPTimeoutError"]

import gc
import logging
import multiprocessing
import pickle
import signal
import sys
import time
from enum import Enum
from typing import Optional

from lsst.daf.butler.cli.cliLog import CliLog
from lsst.pipe.base import InvalidQuantumError
from lsst.pipe.base.graph.graph import QuantumGraph
from lsst.utils.threads import disable_implicit_threading

# -----------------------------
# Imports for other modules --
# -----------------------------
from .quantumGraphExecutor import QuantumGraphExecutor
from .reports import ExecutionStatus, QuantumReport, Report

_LOG = logging.getLogger(__name__)


# Possible states for the executing task:
# - PENDING: job has not started yet
# - RUNNING: job is currently executing
# - FINISHED: job finished successfully
# - FAILED: job execution failed (process returned non-zero status)
# - TIMED_OUT: job is killed due to too long execution time
# - FAILED_DEP: one of the dependencies of this job has failed/timed out
JobState = Enum("JobState", "PENDING RUNNING FINISHED FAILED TIMED_OUT FAILED_DEP")


class _Job:
    """Class representing a job running a single task.

    Parameters
    ----------
    qnode : `~lsst.pipe.base.QuantumNode`
        Quantum and some associated information.
    """

    def __init__(self, qnode):
        self.qnode = qnode
        self.process = None
        self._state = JobState.PENDING
        self.started = None
        self._rcv_conn = None
        self._terminated = False


    @property
    def state(self):
        """Job processing state (JobState)"""
        return self._state

    @property
    def terminated(self):
        """Return True if job was killed by stop() method and negative exit
        code is returned from child process. (`bool`)"""
        return self._terminated and self.process.exitcode < 0

    def start(self, butler, quantumExecutor, startMethod=None):
        """Start process which runs the task.

        Parameters
        ----------
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        quantumExecutor : `QuantumExecutor`
            Executor for single quantum.
        startMethod : `str`, optional
            Start method from `multiprocessing` module.
        """
        # Unpickling of quantum has to happen after butler, this is why
        # it is pickled manually here.
        quantum_pickle = pickle.dumps(self.qnode.quantum)
        taskDef = self.qnode.taskDef
        self._rcv_conn, snd_conn = multiprocessing.Pipe(False)
        logConfigState = CliLog.configState
        mp_ctx = multiprocessing.get_context(startMethod)
        self.process = mp_ctx.Process(
            target=_Job._executeJob,
            args=(quantumExecutor, taskDef, quantum_pickle, butler, logConfigState, snd_conn),
            name=f"task-{self.qnode.quantum.dataId}",
        )
        self.process.start()
        self.started = time.time()
        self._state = JobState.RUNNING


    @staticmethod
    def _executeJob(quantumExecutor, taskDef, quantum_pickle, butler, logConfigState, snd_conn):
        """Execute a job with arguments.

        Parameters
        ----------
        quantumExecutor : `QuantumExecutor`
            Executor for single quantum.
        taskDef : `~lsst.pipe.base.TaskDef`
            Task definition structure.
        quantum_pickle : `bytes`
            Quantum for this task execution in pickled form.
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        logConfigState : `list`
            Logging configuration state captured in the parent process;
            replayed here if this is a freshly spawned process.
        snd_conn : `multiprocessing.Connection`
            Connection to send job report to parent process.
        """
        if logConfigState and not CliLog.configState:
            # means that we are in a new spawned Python process and we have to
            # re-initialize logging
            CliLog.replayConfigState(logConfigState)

        # have to reset connection pool to avoid sharing connections
        if butler is not None:
            butler.registry.resetConnectionPool()

        quantum = pickle.loads(quantum_pickle)
        try:
            quantumExecutor.execute(taskDef, quantum, butler)
        finally:
            # If sending fails we do not want this new exception to be exposed.
            try:
                report = quantumExecutor.getReport()
                snd_conn.send(report)
            except Exception:
                pass


    def stop(self):
        """Stop the process."""
        self.process.terminate()
        # give it 1 second to finish or KILL
        for i in range(10):
            time.sleep(0.1)
            if not self.process.is_alive():
                break
        else:
            _LOG.debug("Killing process %s", self.process.name)
            self.process.kill()
        self._terminated = True


    def cleanup(self):
        """Release process resources; has to be called for each finished
        process.
        """
        if self.process and not self.process.is_alive():
            self.process.close()
            self.process = None
            self._rcv_conn = None


    def report(self) -> QuantumReport:
        """Return task report, should be called after process finishes and
        before cleanup().
        """
        try:
            report = self._rcv_conn.recv()
            report.exitCode = self.process.exitcode
        except Exception:
            # Likely due to the process killed, but there may be other reasons.
            report = QuantumReport.from_exit_code(
                exitCode=self.process.exitcode,
                dataId=self.qnode.quantum.dataId,
                taskLabel=self.qnode.taskDef.label,
            )
        if self.terminated:
            # Means it was killed, assume it's due to timeout
            report.status = ExecutionStatus.TIMEOUT
        return report

    def failMessage(self):
        """Return a message describing task failure"""
        exitcode = self.process.exitcode
        if exitcode < 0:
            # Negative exit code means it is killed by signal
            signum = -exitcode
            msg = f"Task {self} failed, killed by signal {signum}"
            # Just in case this is some very odd signal, expect ValueError
            try:
                strsignal = signal.strsignal(signum)
                msg = f"{msg} ({strsignal})"
            except ValueError:
                pass
        elif exitcode > 0:
            msg = f"Task {self} failed, exit code={exitcode}"
        else:
            msg = None
        return msg

    def __str__(self):
        return f"<{self.qnode.taskDef} dataId={self.qnode.quantum.dataId}>"



class _JobList:
    """Simple list of _Job instances with a few convenience methods.

    Parameters
    ----------
    iterable : iterable of `~lsst.pipe.base.QuantumNode`
        Sequence of Quanta to execute. This has to be ordered according to
        task dependencies.
    """

    def __init__(self, iterable):
        self.jobs = [_Job(qnode) for qnode in iterable]
        self.pending = self.jobs[:]
        self.running = []
        self.finishedNodes = set()
        self.failedNodes = set()
        self.timedOutNodes = set()


    def submit(self, job, butler, quantumExecutor, startMethod=None):
        """Submit one more job for execution

        Parameters
        ----------
        job : `_Job`
            Job to submit.
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        quantumExecutor : `QuantumExecutor`
            Executor for single quantum.
        startMethod : `str`, optional
            Start method from `multiprocessing` module.
        """
        # this will raise if job is not in pending list
        self.pending.remove(job)
        job.start(butler, quantumExecutor, startMethod)
        self.running.append(job)


    def setJobState(self, job, state):
        """Update job state.

        Parameters
        ----------
        job : `_Job`
            Job to update.
        state : `JobState`
            New job state, note that only FINISHED, FAILED, TIMED_OUT, or
            FAILED_DEP state is acceptable.
        """
        allowedStates = (JobState.FINISHED, JobState.FAILED, JobState.TIMED_OUT, JobState.FAILED_DEP)
        assert state in allowedStates, f"State {state} not allowed here"

        # remove job from pending/running lists
        if job.state == JobState.PENDING:
            self.pending.remove(job)
        elif job.state == JobState.RUNNING:
            self.running.remove(job)

        qnode = job.qnode
        # it should not be in any of these, but just in case
        self.finishedNodes.discard(qnode)
        self.failedNodes.discard(qnode)
        self.timedOutNodes.discard(qnode)

        job._state = state
        if state == JobState.FINISHED:
            self.finishedNodes.add(qnode)
        elif state == JobState.FAILED:
            self.failedNodes.add(qnode)
        elif state == JobState.FAILED_DEP:
            self.failedNodes.add(qnode)
        elif state == JobState.TIMED_OUT:
            self.failedNodes.add(qnode)
            self.timedOutNodes.add(qnode)
        else:
            raise ValueError(f"Unexpected state value: {state}")


    def cleanup(self):
        """Do periodic cleanup for jobs that did not finish correctly.

        If timed out jobs are killed but take too long to stop then regular
        cleanup will not work for them. Here we check all timed out jobs
        periodically and do cleanup if they managed to die by this time.
        """
        for job in self.jobs:
            if job.state == JobState.TIMED_OUT and job.process is not None:
                job.cleanup()


class MPGraphExecutorError(Exception):
    """Exception class for errors raised by MPGraphExecutor."""

    pass


class MPTimeoutError(MPGraphExecutorError):
    """Exception raised when task execution times out."""

    pass



class MPGraphExecutor(QuantumGraphExecutor):
    """Implementation of QuantumGraphExecutor using same-host multiprocess
    execution of Quanta.

    Parameters
    ----------
    numProc : `int`
        Number of processes to use for executing tasks.
    timeout : `float`
        Time in seconds to wait for tasks to finish.
    quantumExecutor : `QuantumExecutor`
        Executor for single quantum. For multiprocess-style execution when
        ``numProc`` is greater than one this instance must support pickle.
    startMethod : `str`, optional
        Start method from `multiprocessing` module, `None` selects the best
        one for current platform.
    failFast : `bool`, optional
        If set to ``True`` then stop processing on first error from any task.
    executionGraphFixup : `ExecutionGraphFixup`, optional
        Instance used for modification of execution graph.
    """

    def __init__(
        self,
        numProc,
        timeout,
        quantumExecutor,
        *,
        startMethod=None,
        failFast=False,
        executionGraphFixup=None,
    ):
        self.numProc = numProc
        self.timeout = timeout
        self.quantumExecutor = quantumExecutor
        self.failFast = failFast
        self.executionGraphFixup = executionGraphFixup
        self.report: Optional[Report] = None

        # We set default start method as spawn for MacOS and fork for Linux;
        # None for all other platforms to use multiprocessing default.
        if startMethod is None:
            methods = dict(linux="fork", darwin="spawn")
            startMethod = methods.get(sys.platform)
        self.startMethod = startMethod


    def execute(self, graph, butler):
        # Docstring inherited from QuantumGraphExecutor.execute
        graph = self._fixupQuanta(graph)
        self.report = Report()
        try:
            if self.numProc > 1:
                self._executeQuantaMP(graph, butler)
            else:
                self._executeQuantaInProcess(graph, butler)
        except Exception as exc:
            self.report.set_exception(exc)
            raise

    def _fixupQuanta(self, graph: QuantumGraph):
        """Call fixup code to modify execution graph.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` to modify

        Returns
        -------
        graph : `QuantumGraph`
            Modified `QuantumGraph`.

        Raises
        ------
        MPGraphExecutorError
            Raised if execution graph cannot be ordered after modification,
            i.e. it has dependency cycles.
        """
        if not self.executionGraphFixup:
            return graph

        _LOG.debug("Call execution graph fixup method")
        graph = self.executionGraphFixup.fixupQuanta(graph)

        # Detect if there is now a cycle created within the graph
        if graph.findCycle():
            raise MPGraphExecutorError("Updated execution graph has dependency cycle.")

        return graph


    def _executeQuantaInProcess(self, graph, butler):
        """Execute all Quanta in current process.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` that is to be executed
        butler : `lsst.daf.butler.Butler`
            Data butler instance
        """
        successCount, totalCount = 0, len(graph)
        failedNodes = set()
        for qnode in graph:

            # Any failed inputs mean that the quantum has to be skipped.
            inputNodes = graph.determineInputsToQuantumNode(qnode)
            if inputNodes & failedNodes:
                _LOG.error(
                    "Upstream job failed for task <%s dataId=%s>, skipping this task.",
                    qnode.taskDef,
                    qnode.quantum.dataId,
                )
                failedNodes.add(qnode)
                quantum_report = QuantumReport(
                    status=ExecutionStatus.SKIPPED, dataId=qnode.quantum.dataId, taskLabel=qnode.taskDef.label
                )
                self.report.quantaReports.append(quantum_report)
                continue

            _LOG.debug("Executing %s", qnode)
            try:
                self.quantumExecutor.execute(qnode.taskDef, qnode.quantum, butler)
                successCount += 1
            except Exception as exc:
                failedNodes.add(qnode)
                self.report.status = ExecutionStatus.FAILURE
                if self.failFast:
                    raise MPGraphExecutorError(
                        f"Task <{qnode.taskDef} dataId={qnode.quantum.dataId}> failed."
                    ) from exc
                else:
                    # Note that there could be exception safety issues, which
                    # we presently ignore.
                    _LOG.error(
                        "Task <%s dataId=%s> failed; processing will continue for remaining tasks.",
                        qnode.taskDef,
                        qnode.quantum.dataId,
                        exc_info=exc,
                    )
            finally:
                # sqlalchemy has some objects that can last until a garbage
                # collection cycle is run, which can happen at unpredictable
                # times, run a collection loop here explicitly.
                gc.collect()

            quantum_report = self.quantumExecutor.getReport()
            if quantum_report:
                self.report.quantaReports.append(quantum_report)

        _LOG.info(
            "Executed %d quanta successfully, %d failed and %d remain out of total %d quanta.",
            successCount,
            len(failedNodes),
            totalCount - successCount - len(failedNodes),
            totalCount,
        )

        # Raise an exception if there were any failures.
        if failedNodes:
            raise MPGraphExecutorError("One or more tasks failed during execution.")


    def _executeQuantaMP(self, graph, butler):
        """Execute all Quanta in separate processes.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` that is to be executed.
        butler : `lsst.daf.butler.Butler`
            Data butler instance
        """

        disable_implicit_threading()  # To prevent thread contention

        _LOG.debug("Using %r for multiprocessing start method", self.startMethod)

        # re-pack input quantum data into jobs list
        jobs = _JobList(graph)

        # check that all tasks can run in sub-process
        for job in jobs.jobs:
            taskDef = job.qnode.taskDef
            if not taskDef.taskClass.canMultiprocess:
                raise MPGraphExecutorError(
                    f"Task {taskDef.taskName} does not support multiprocessing; use single process"
                )

        finishedCount, failedCount = 0, 0
        while jobs.pending or jobs.running:

            _LOG.debug("#pendingJobs: %s", len(jobs.pending))
            _LOG.debug("#runningJobs: %s", len(jobs.running))

            # See if any jobs have finished
            for job in jobs.running:
                if not job.process.is_alive():
                    _LOG.debug("finished: %s", job)
                    # finished
                    exitcode = job.process.exitcode
                    quantum_report = job.report()
                    self.report.quantaReports.append(quantum_report)
                    if exitcode == 0:
                        jobs.setJobState(job, JobState.FINISHED)
                        job.cleanup()
                        _LOG.debug("success: %s took %.3f seconds", job, time.time() - job.started)
                    else:
                        if job.terminated:
                            # Was killed due to timeout.
                            if self.report.status == ExecutionStatus.SUCCESS:
                                # Do not override global FAILURE status
                                self.report.status = ExecutionStatus.TIMEOUT
                            message = f"Timeout ({self.timeout} sec) for task {job}, task is killed"
                            jobs.setJobState(job, JobState.TIMED_OUT)
                        else:
                            self.report.status = ExecutionStatus.FAILURE
                            # failMessage() has to be called before cleanup()
                            message = job.failMessage()
                            jobs.setJobState(job, JobState.FAILED)

                        job.cleanup()
                        _LOG.debug("failed: %s", job)
                        if self.failFast or exitcode == InvalidQuantumError.EXIT_CODE:
                            # stop all running jobs
                            for stopJob in jobs.running:
                                if stopJob is not job:
                                    stopJob.stop()
                            if job.state is JobState.TIMED_OUT:
                                raise MPTimeoutError(f"Timeout ({self.timeout} sec) for task {job}.")
                            else:
                                raise MPGraphExecutorError(message)
                        else:
                            _LOG.error("%s; processing will continue for remaining tasks.", message)
                else:
                    # check for timeout
                    now = time.time()
                    if now - job.started > self.timeout:
                        # Try to kill it, and there is a chance that it
                        # finishes successfully before it gets killed. Exit
                        # status is handled by the code above on next
                        # iteration.
                        _LOG.debug("Terminating job %s due to timeout", job)
                        job.stop()

            # Fail jobs whose inputs failed, this may need several iterations
            # if the order is not right, will be done in the next loop.
            if jobs.failedNodes:
                for job in jobs.pending:
                    jobInputNodes = graph.determineInputsToQuantumNode(job.qnode)
                    if jobInputNodes & jobs.failedNodes:
                        quantum_report = QuantumReport(
                            status=ExecutionStatus.SKIPPED,
                            dataId=job.qnode.quantum.dataId,
                            taskLabel=job.qnode.taskDef.label,
                        )
                        self.report.quantaReports.append(quantum_report)
                        jobs.setJobState(job, JobState.FAILED_DEP)
                        _LOG.error("Upstream job failed for task %s, skipping this task.", job)

            # see if we can start more jobs
            if len(jobs.running) < self.numProc:
                for job in jobs.pending:
                    jobInputNodes = graph.determineInputsToQuantumNode(job.qnode)
                    if jobInputNodes <= jobs.finishedNodes:
                        # all dependencies have completed, can start new job
                        if len(jobs.running) < self.numProc:
                            _LOG.debug("Submitting %s", job)
                            jobs.submit(job, butler, self.quantumExecutor, self.startMethod)
                        if len(jobs.running) >= self.numProc:
                            # Cannot start any more jobs, wait until something
                            # finishes.
                            break

            # Do cleanup for timed out jobs if necessary.
            jobs.cleanup()

            # Print progress message if something changed.
            newFinished, newFailed = len(jobs.finishedNodes), len(jobs.failedNodes)
            if (finishedCount, failedCount) != (newFinished, newFailed):
                finishedCount, failedCount = newFinished, newFailed
                totalCount = len(jobs.jobs)
                _LOG.info(
                    "Executed %d quanta successfully, %d failed and %d remain out of total %d quanta.",
                    finishedCount,
                    failedCount,
                    totalCount - finishedCount - failedCount,
                    totalCount,
                )

            # Here we want to wait until one of the running jobs completes
            # but multiprocessing does not provide an API for that, for now
            # just sleep a little bit and go back to the loop.
            if jobs.running:
                time.sleep(0.1)

        if jobs.failedNodes:
            # print list of failed jobs
            _LOG.error("Failed jobs:")
            for job in jobs.jobs:
                if job.state != JobState.FINISHED:
                    _LOG.error(" - %s: %s", job.state.name, job)

            # if any job failed raise an exception
            if jobs.failedNodes == jobs.timedOutNodes:
                raise MPTimeoutError("One or more tasks timed out during execution.")
            else:
                raise MPGraphExecutorError("One or more tasks failed or timed out during execution.")


    def getReport(self) -> Optional[Report]:
        # Docstring inherited from base class
        if self.report is None:
            raise RuntimeError("getReport() called before execute()")
        return self.report
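
For readers of this coverage report who want to see the module in context, the snippet below is a minimal usage sketch and is not part of the file above. It assumes that a QuantumGraph, a Butler, and a QuantumExecutor implementation (for example SingleQuantumExecutor from this package) have already been constructed elsewhere as ``graph``, ``butler``, and ``quantum_executor``; the function name and the numeric settings are illustrative only.

# Hypothetical usage sketch; graph, butler and quantum_executor are assumed
# to be built elsewhere with the usual pipe_base / daf_butler machinery.
from lsst.ctrl.mpexec.mpGraphExecutor import MPGraphExecutor, MPGraphExecutorError, MPTimeoutError


def run_graph(graph, butler, quantum_executor):
    """Run a quantum graph with four worker processes and a one-hour timeout."""
    executor = MPGraphExecutor(
        numProc=4,  # numProc > 1 selects the multiprocess code path (_executeQuantaMP)
        timeout=3600.0,  # seconds a single quantum may run before it is killed
        quantumExecutor=quantum_executor,
        failFast=False,  # keep executing remaining quanta after individual failures
    )
    try:
        executor.execute(graph, butler)
    except MPTimeoutError:
        print("one or more quanta timed out")
    except MPGraphExecutorError as exc:
        print(f"execution failed: {exc}")
    # getReport() is valid only after execute(); it returns the per-quantum status report.
    return executor.getReport()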