Coverage for python/lsst/ctrl/mpexec/mpGraphExecutor.py: 14%

Shortcuts on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

274 statements  

1# This file is part of ctrl_mpexec. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22__all__ = ["MPGraphExecutor", "MPGraphExecutorError", "MPTimeoutError"] 

23 

24import gc 

25import logging 

26import multiprocessing 

27import pickle 

28import signal 

29import sys 

30import time 

31from enum import Enum 

32from typing import Optional 

33 

34from lsst.daf.butler.cli.cliLog import CliLog 

35from lsst.pipe.base import InvalidQuantumError 

36from lsst.pipe.base.graph.graph import QuantumGraph 

37from lsst.utils.threads import disable_implicit_threading 

38 

39# ----------------------------- 

40# Imports for other modules -- 

41# ----------------------------- 

42from .quantumGraphExecutor import QuantumGraphExecutor 

43from .reports import ExecutionStatus, QuantumReport, Report 

44 

45_LOG = logging.getLogger(__name__) 

46 

47 

class JobState(Enum):
    """Possible states for the executing task."""

    PENDING = 1  # job has not started yet
    RUNNING = 2  # job is currently executing
    FINISHED = 3  # job finished successfully
    FAILED = 4  # job execution failed (process returned non-zero status)
    TIMED_OUT = 5  # job is killed due to too long execution time
    FAILED_DEP = 6  # one of the dependencies of this job has failed/timed out

56 

57 

class _Job:
    """Class representing a job running single task.

    A job wraps one `multiprocessing.Process` executing a single quantum,
    plus a one-way pipe used by the child to send its `QuantumReport` back
    to the parent.

    Parameters
    ----------
    qnode : `~lsst.pipe.base.QuantumNode`
        Quantum and some associated information.
    """

    def __init__(self, qnode):
        self.qnode = qnode
        # Sub-process running the task; `None` until start() and again after
        # cleanup().
        self.process = None
        self._state = JobState.PENDING
        # Wall-clock time (time.time()) at which the process was started.
        self.started = None
        # Parent (receiving) end of the report pipe created in start().
        self._rcv_conn = None

    @property
    def state(self):
        """Job processing state (`JobState`)."""
        return self._state

    def start(self, butler, quantumExecutor, startMethod=None):
        """Start process which runs the task.

        Parameters
        ----------
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        quantumExecutor : `QuantumExecutor`
            Executor for single quantum.
        startMethod : `str`, optional
            Start method from `multiprocessing` module.
        """
        # Unpickling of quantum has to happen after butler, this is why
        # it is pickled manually here.
        quantum_pickle = pickle.dumps(self.qnode.quantum)
        taskDef = self.qnode.taskDef
        # duplex=False: the child writes the report, the parent only reads.
        self._rcv_conn, snd_conn = multiprocessing.Pipe(False)
        # Capture current logging configuration so a spawned (non-forked)
        # child can replay it in _executeJob().
        logConfigState = CliLog.configState
        mp_ctx = multiprocessing.get_context(startMethod)
        self.process = mp_ctx.Process(
            target=_Job._executeJob,
            args=(quantumExecutor, taskDef, quantum_pickle, butler, logConfigState, snd_conn),
            name=f"task-{self.qnode.quantum.dataId}",
        )
        self.process.start()
        self.started = time.time()
        self._state = JobState.RUNNING

    @staticmethod
    def _executeJob(quantumExecutor, taskDef, quantum_pickle, butler, logConfigState, snd_conn):
        """Execute a job with arguments.

        This is the entry point of the child process.

        Parameters
        ----------
        quantumExecutor : `QuantumExecutor`
            Executor for single quantum.
        taskDef : `~lsst.pipe.base.TaskDef`
            Task definition structure.
        quantum_pickle : `bytes`
            Quantum for this task execution in pickled form.
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        logConfigState : `list`
            Logging configuration state captured in the parent process
            (`CliLog.configState`), replayed in spawned child processes.
        snd_conn : `multiprocessing.Connection`
            Connection to send job report to parent process.
        """
        if logConfigState and not CliLog.configState:
            # means that we are in a new spawned Python process and we have to
            # re-initialize logging
            CliLog.replayConfigState(logConfigState)

        # have to reset connection pool to avoid sharing connections
        if butler is not None:
            butler.registry.resetConnectionPool()

        quantum = pickle.loads(quantum_pickle)
        try:
            quantumExecutor.execute(taskDef, quantum, butler)
        finally:
            # If sending fails we do not want this new exception to be exposed.
            try:
                report = quantumExecutor.getReport()
                snd_conn.send(report)
            except Exception:
                pass

    def stop(self):
        """Stop the process.

        Sends SIGTERM first, and escalates to SIGKILL if the process is still
        alive after roughly one second.
        """
        self.process.terminate()
        # give it 1 second to finish or KILL
        for i in range(10):
            time.sleep(0.1)
            if not self.process.is_alive():
                break
        else:
            _LOG.debug("Killing process %s", self.process.name)
            self.process.kill()

    def cleanup(self):
        """Release processes resources, has to be called for each finished
        process.

        No-op if the process has not been started yet or is still alive.
        """
        if self.process and not self.process.is_alive():
            self.process.close()
            self.process = None
            self._rcv_conn = None

    def report(self) -> QuantumReport:
        """Return task report, should be called after process finishes and
        before cleanup().
        """
        try:
            report = self._rcv_conn.recv()
            report.exitCode = self.process.exitcode
        except Exception:
            # Likely due to the process killed, but there may be other
            # reasons; synthesize a report from the exit code alone.
            report = QuantumReport.from_exit_code(
                exitCode=self.process.exitcode,
                dataId=self.qnode.quantum.dataId,
                taskLabel=self.qnode.taskDef.label,
            )
        return report

    def failMessage(self):
        """Return a message describing task failure.

        Returns `None` if the process exited with code 0. Must be called
        before cleanup() since it reads ``self.process.exitcode``.
        """
        exitcode = self.process.exitcode
        if exitcode < 0:
            # Negative exit code means it is killed by signal
            signum = -exitcode
            msg = f"Task {self} failed, killed by signal {signum}"
            # Just in case this is some very odd signal, expect ValueError
            try:
                strsignal = signal.strsignal(signum)
                msg = f"{msg} ({strsignal})"
            except ValueError:
                pass
        elif exitcode > 0:
            msg = f"Task {self} failed, exit code={exitcode}"
        else:
            msg = None
        return msg

    def __str__(self):
        return f"<{self.qnode.taskDef} dataId={self.qnode.quantum.dataId}>"

202 

203 

class _JobList:
    """Simple list of _Job instances with few convenience methods.

    Parameters
    ----------
    iterable : iterable of `~lsst.pipe.base.QuantumNode`
        Sequence of Quanta to execute. This has to be ordered according to
        task dependencies.
    """

    def __init__(self, iterable):
        self.jobs = [_Job(qnode) for qnode in iterable]
        # Every job starts out pending; it moves to ``running`` on submit()
        # and leaves ``running`` once a terminal state is recorded.
        self.pending = list(self.jobs)
        self.running = []
        # Terminal bookkeeping is keyed by quantum node rather than by job.
        self.finishedNodes = set()
        self.failedNodes = set()
        self.timedOutNodes = set()

    def submit(self, job, butler, quantumExecutor, startMethod=None):
        """Submit one more job for execution.

        Parameters
        ----------
        job : `_Job`
            Job to submit; must currently be in the pending list.
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        quantumExecutor : `QuantumExecutor`
            Executor for single quantum.
        startMethod : `str`, optional
            Start method from `multiprocessing` module.
        """
        # Raises ValueError if the job was never pending.
        self.pending.remove(job)
        job.start(butler, quantumExecutor, startMethod)
        self.running.append(job)

    def setJobState(self, job, state):
        """Update job state.

        Parameters
        ----------
        job : `_Job`
            Job whose state is to be updated.
        state : `JobState`
            New job state, note that only FINISHED, FAILED, TIMED_OUT, or
            FAILED_DEP state is acceptable.
        """
        allowedStates = (JobState.FINISHED, JobState.FAILED, JobState.TIMED_OUT, JobState.FAILED_DEP)
        assert state in allowedStates, f"State {state} not allowed here"

        # Drop the job from whichever active queue currently holds it.
        activeQueues = {JobState.PENDING: self.pending, JobState.RUNNING: self.running}
        queue = activeQueues.get(job.state)
        if queue is not None:
            queue.remove(job)

        qnode = job.qnode
        # it should not be in any of these, but just in case
        self.finishedNodes.discard(qnode)
        self.failedNodes.discard(qnode)
        self.timedOutNodes.discard(qnode)

        job._state = state
        if state == JobState.FINISHED:
            self.finishedNodes.add(qnode)
        elif state in (JobState.FAILED, JobState.FAILED_DEP):
            self.failedNodes.add(qnode)
        elif state == JobState.TIMED_OUT:
            # Timed-out jobs count as failed and are tracked separately too.
            self.failedNodes.add(qnode)
            self.timedOutNodes.add(qnode)
        else:
            raise ValueError(f"Unexpected state value: {state}")

    def cleanup(self):
        """Do periodic cleanup for jobs that did not finish correctly.

        If timed out jobs are killed but take too long to stop then regular
        cleanup will not work for them. Here we check all timed out jobs
        periodically and do cleanup if they managed to die by this time.
        """
        for job in self.jobs:
            if job.process is not None and job.state == JobState.TIMED_OUT:
                job.cleanup()

290 

291 

class MPGraphExecutorError(Exception):
    """Exception class for errors raised by MPGraphExecutor."""

296 

297 

class MPTimeoutError(MPGraphExecutorError):
    """Exception raised when task execution times out."""

302 

303 

class MPGraphExecutor(QuantumGraphExecutor):
    """Implementation of QuantumGraphExecutor using same-host multiprocess
    execution of Quanta.

    Parameters
    ----------
    numProc : `int`
        Number of processes to use for executing tasks.
    timeout : `float`
        Time in seconds to wait for tasks to finish.
    quantumExecutor : `QuantumExecutor`
        Executor for single quantum. For multiprocess-style execution when
        ``numProc`` is greater than one this instance must support pickle.
    startMethod : `str`, optional
        Start method from `multiprocessing` module, `None` selects the best
        one for current platform.
    failFast : `bool`, optional
        If set to ``True`` then stop processing on first error from any task.
    executionGraphFixup : `ExecutionGraphFixup`, optional
        Instance used for modification of execution graph.
    """

    def __init__(
        self,
        numProc,
        timeout,
        quantumExecutor,
        *,
        startMethod=None,
        failFast=False,
        executionGraphFixup=None,
    ):
        self.numProc = numProc
        self.timeout = timeout
        self.quantumExecutor = quantumExecutor
        self.failFast = failFast
        self.executionGraphFixup = executionGraphFixup
        self.report: Optional[Report] = None

        # We set default start method as spawn for MacOS and fork for Linux;
        # None for all other platforms to use multiprocessing default.
        if startMethod is None:
            methods = dict(linux="fork", darwin="spawn")
            startMethod = methods.get(sys.platform)
        self.startMethod = startMethod

    def execute(self, graph, butler):
        # Docstring inherited from QuantumGraphExecutor.execute
        graph = self._fixupQuanta(graph)
        self.report = Report()
        try:
            if self.numProc > 1:
                self._executeQuantaMP(graph, butler)
            else:
                self._executeQuantaInProcess(graph, butler)
        except Exception as exc:
            # Record the failure in the report, then let it propagate.
            self.report.set_exception(exc)
            raise

    def _fixupQuanta(self, graph: QuantumGraph):
        """Call fixup code to modify execution graph.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` to modify

        Returns
        -------
        graph : `QuantumGraph`
            Modified `QuantumGraph`.

        Raises
        ------
        MPGraphExecutorError
            Raised if execution graph cannot be ordered after modification,
            i.e. it has dependency cycles.
        """
        if not self.executionGraphFixup:
            return graph

        _LOG.debug("Call execution graph fixup method")
        graph = self.executionGraphFixup.fixupQuanta(graph)

        # Detect if there is now a cycle created within the graph
        if graph.findCycle():
            raise MPGraphExecutorError("Updated execution graph has dependency cycle.")

        return graph

    def _executeQuantaInProcess(self, graph, butler):
        """Execute all Quanta in current process.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` that is to be executed
        butler : `lsst.daf.butler.Butler`
            Data butler instance

        Raises
        ------
        MPGraphExecutorError
            Raised if any task failed (immediately when ``failFast`` is set,
            otherwise after all remaining quanta have been attempted).
        """
        successCount, totalCount = 0, len(graph)
        failedNodes = set()
        for qnode in graph:

            # Any failed inputs mean that the quantum has to be skipped.
            inputNodes = graph.determineInputsToQuantumNode(qnode)
            if inputNodes & failedNodes:
                _LOG.error(
                    "Upstream job failed for task <%s dataId=%s>, skipping this task.",
                    qnode.taskDef,
                    qnode.quantum.dataId,
                )
                failedNodes.add(qnode)
                quantum_report = QuantumReport(
                    status=ExecutionStatus.SKIPPED, dataId=qnode.quantum.dataId, taskLabel=qnode.taskDef.label
                )
                self.report.quantaReports.append(quantum_report)
                continue

            _LOG.debug("Executing %s", qnode)
            try:
                self.quantumExecutor.execute(qnode.taskDef, qnode.quantum, butler)
                successCount += 1
            except Exception as exc:
                failedNodes.add(qnode)
                self.report.status = ExecutionStatus.FAILURE
                if self.failFast:
                    raise MPGraphExecutorError(
                        f"Task <{qnode.taskDef} dataId={qnode.quantum.dataId}> failed."
                    ) from exc
                else:
                    # Note that there could be exception safety issues, which
                    # we presently ignore.
                    _LOG.error(
                        "Task <%s dataId=%s> failed; processing will continue for remaining tasks.",
                        qnode.taskDef,
                        qnode.quantum.dataId,
                        exc_info=exc,
                    )
            finally:
                # sqlalchemy has some objects that can last until a garbage
                # collection cycle is run, which can happen at unpredictable
                # times, run a collection loop here explicitly.
                gc.collect()

            quantum_report = self.quantumExecutor.getReport()
            if quantum_report:
                self.report.quantaReports.append(quantum_report)

        _LOG.info(
            "Executed %d quanta successfully, %d failed and %d remain out of total %d quanta.",
            successCount,
            len(failedNodes),
            totalCount - successCount - len(failedNodes),
            totalCount,
        )

        # Raise an exception if there were any failures.
        if failedNodes:
            raise MPGraphExecutorError("One or more tasks failed during execution.")

    def _executeQuantaMP(self, graph, butler):
        """Execute all Quanta in separate processes.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` that is to be executed.
        butler : `lsst.daf.butler.Butler`
            Data butler instance

        Raises
        ------
        MPGraphExecutorError
            Raised if any task could not run in a subprocess or failed;
            `MPTimeoutError` (a subclass) if every failure was a timeout.
        """

        disable_implicit_threading()  # To prevent thread contention

        _LOG.debug("Using %r for multiprocessing start method", self.startMethod)

        # re-pack input quantum data into jobs list
        jobs = _JobList(graph)

        # check that all tasks can run in sub-process
        for job in jobs.jobs:
            taskDef = job.qnode.taskDef
            if not taskDef.taskClass.canMultiprocess:
                raise MPGraphExecutorError(
                    f"Task {taskDef.taskName} does not support multiprocessing; use single process"
                )

        finishedCount, failedCount = 0, 0
        while jobs.pending or jobs.running:

            _LOG.debug("#pendingJobs: %s", len(jobs.pending))
            _LOG.debug("#runningJobs: %s", len(jobs.running))

            # See if any jobs have finished. Iterate over a copy of the list:
            # setJobState() removes the job from ``jobs.running``, and
            # mutating a list while iterating it silently skips the element
            # that follows each removed one.
            for job in list(jobs.running):
                if not job.process.is_alive():
                    _LOG.debug("finished: %s", job)
                    # finished
                    exitcode = job.process.exitcode
                    quantum_report = job.report()
                    if quantum_report:
                        self.report.quantaReports.append(quantum_report)
                    if exitcode == 0:
                        jobs.setJobState(job, JobState.FINISHED)
                        job.cleanup()
                        _LOG.debug("success: %s took %.3f seconds", job, time.time() - job.started)
                    else:
                        self.report.status = ExecutionStatus.FAILURE
                        # failMessage() has to be called before cleanup()
                        message = job.failMessage()
                        jobs.setJobState(job, JobState.FAILED)
                        job.cleanup()
                        _LOG.debug("failed: %s", job)
                        if self.failFast or exitcode == InvalidQuantumError.EXIT_CODE:
                            # stop all running jobs
                            for stopJob in jobs.running:
                                if stopJob is not job:
                                    stopJob.stop()
                            raise MPGraphExecutorError(message)
                        else:
                            _LOG.error("%s; processing will continue for remaining tasks.", message)
                else:
                    # check for timeout
                    now = time.time()
                    if now - job.started > self.timeout:
                        # Do not override FAILURE status
                        if self.report.status == ExecutionStatus.SUCCESS:
                            self.report.status = ExecutionStatus.TIMEOUT
                        jobs.setJobState(job, JobState.TIMED_OUT)
                        _LOG.debug("Terminating job %s due to timeout", job)
                        job.stop()
                        quantum_report = job.report()
                        if quantum_report:
                            quantum_report.status = ExecutionStatus.TIMEOUT
                            self.report.quantaReports.append(quantum_report)
                        job.cleanup()
                        if self.failFast:
                            raise MPTimeoutError(f"Timeout ({self.timeout} sec) for task {job}.")
                        else:
                            _LOG.error(
                                "Timeout (%s sec) for task %s; task is killed, processing continues "
                                "for remaining tasks.",
                                self.timeout,
                                job,
                            )

            # Fail jobs whose inputs failed; this may still need several
            # passes if the dependency order is unlucky. Iterate over a copy
            # because setJobState() removes the job from ``jobs.pending``.
            if jobs.failedNodes:
                for job in list(jobs.pending):
                    jobInputNodes = graph.determineInputsToQuantumNode(job.qnode)
                    if jobInputNodes & jobs.failedNodes:
                        quantum_report = QuantumReport(
                            status=ExecutionStatus.SKIPPED,
                            dataId=job.qnode.quantum.dataId,
                            taskLabel=job.qnode.taskDef.label,
                        )
                        self.report.quantaReports.append(quantum_report)
                        jobs.setJobState(job, JobState.FAILED_DEP)
                        _LOG.error("Upstream job failed for task %s, skipping this task.", job)

            # See if we can start more jobs; iterate over a copy because
            # submit() removes the job from ``jobs.pending``.
            if len(jobs.running) < self.numProc:
                for job in list(jobs.pending):
                    jobInputNodes = graph.determineInputsToQuantumNode(job.qnode)
                    if jobInputNodes <= jobs.finishedNodes:
                        # all dependencies have completed, can start new job
                        if len(jobs.running) < self.numProc:
                            _LOG.debug("Submitting %s", job)
                            jobs.submit(job, butler, self.quantumExecutor, self.startMethod)
                        if len(jobs.running) >= self.numProc:
                            # Cannot start any more jobs, wait until something
                            # finishes.
                            break

            # Do cleanup for timed out jobs if necessary.
            jobs.cleanup()

            # Print progress message if something changed.
            newFinished, newFailed = len(jobs.finishedNodes), len(jobs.failedNodes)
            if (finishedCount, failedCount) != (newFinished, newFailed):
                finishedCount, failedCount = newFinished, newFailed
                totalCount = len(jobs.jobs)
                _LOG.info(
                    "Executed %d quanta successfully, %d failed and %d remain out of total %d quanta.",
                    finishedCount,
                    failedCount,
                    totalCount - finishedCount - failedCount,
                    totalCount,
                )

            # Here we want to wait until one of the running jobs completes
            # but multiprocessing does not provide an API for that, for now
            # just sleep a little bit and go back to the loop.
            if jobs.running:
                time.sleep(0.1)

        if jobs.failedNodes:
            # print list of failed jobs
            _LOG.error("Failed jobs:")
            for job in jobs.jobs:
                if job.state != JobState.FINISHED:
                    _LOG.error(" - %s: %s", job.state.name, job)

            # if any job failed raise an exception
            if jobs.failedNodes == jobs.timedOutNodes:
                raise MPTimeoutError("One or more tasks timed out during execution.")
            else:
                raise MPGraphExecutorError("One or more tasks failed or timed out during execution.")

    def getReport(self) -> Optional[Report]:
        # Docstring inherited from base class
        if self.report is None:
            raise RuntimeError("getReport() called before execute()")
        return self.report

618 return self.report