Coverage for python/lsst/ctrl/mpexec/mpGraphExecutor.py: 15% (228 statements)

# This file is part of ctrl_mpexec.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

__all__ = ["MPGraphExecutor", "MPGraphExecutorError", "MPTimeoutError"]

# -------------------------------
#  Imports of standard modules --
# -------------------------------
import gc
import logging
import multiprocessing
import pickle
import signal
import sys
import time
from enum import Enum

# -----------------------------
#  Imports for other modules --
# -----------------------------
from lsst.base import disableImplicitThreading
from lsst.daf.butler.cli.cliLog import CliLog
from lsst.pipe.base import InvalidQuantumError
from lsst.pipe.base.graph.graph import QuantumGraph

from .quantumGraphExecutor import QuantumGraphExecutor

_LOG = logging.getLogger(__name__)


# Possible states for the executing task:
#  - PENDING: job has not started yet
#  - RUNNING: job is currently executing
#  - FINISHED: job finished successfully
#  - FAILED: job execution failed (process returned non-zero status)
#  - TIMED_OUT: job was killed because it exceeded the allowed execution time
#  - FAILED_DEP: one of the dependencies of this job has failed/timed out
JobState = Enum("JobState", "PENDING RUNNING FINISHED FAILED TIMED_OUT FAILED_DEP")
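# Typical transitions, as implemented by the job-management code below (a
# job whose upstream fails is marked FAILED_DEP without ever starting):
#   PENDING -> RUNNING -> FINISHED | FAILED | TIMED_OUT
#   PENDING -> FAILED_DEP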


class _Job:
    """Class representing a job running a single task.

    Parameters
    ----------
    qnode : `~lsst.pipe.base.QuantumNode`
        Quantum and some associated information.
    """

    def __init__(self, qnode):
        self.qnode = qnode
        self.process = None
        self._state = JobState.PENDING
        self.started = None

    @property
    def state(self):
        """Job processing state (`JobState`)."""
        return self._state

    def start(self, butler, quantumExecutor, startMethod=None):
        """Start process which runs the task.

        Parameters
        ----------
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        quantumExecutor : `QuantumExecutor`
            Executor for single quantum.
        startMethod : `str`, optional
            Start method from `multiprocessing` module.
        """
        # Unpickling of the quantum has to happen after the butler is
        # unpickled, which is why the quantum is pickled manually here.
        quantum_pickle = pickle.dumps(self.qnode.quantum)
        taskDef = self.qnode.taskDef
        logConfigState = CliLog.configState
        mp_ctx = multiprocessing.get_context(startMethod)
        self.process = mp_ctx.Process(
            target=_Job._executeJob,
            args=(quantumExecutor, taskDef, quantum_pickle, butler, logConfigState),
            name=f"task-{self.qnode.quantum.dataId}",
        )
        self.process.start()
        self.started = time.time()
        self._state = JobState.RUNNING

    @staticmethod
    def _executeJob(quantumExecutor, taskDef, quantum_pickle, butler, logConfigState):
        """Execute a job with arguments.

        Parameters
        ----------
        quantumExecutor : `QuantumExecutor`
            Executor for single quantum.
        taskDef : `~lsst.pipe.base.TaskDef`
            Task definition structure.
        quantum_pickle : `bytes`
            Quantum for this task execution in pickled form.
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        logConfigState : `list`
            Logging configuration state from the parent process, replayed
            here if this is a freshly spawned process.
        """
        if logConfigState and not CliLog.configState:
            # means that we are in a new spawned Python process and we have to
            # re-initialize logging
            CliLog.replayConfigState(logConfigState)

        # have to reset connection pool to avoid sharing connections
        if butler is not None:
            butler.registry.resetConnectionPool()

        quantum = pickle.loads(quantum_pickle)
        quantumExecutor.execute(taskDef, quantum, butler)

    def stop(self):
        """Stop the process."""
        self.process.terminate()
        # give it 1 second to finish or KILL
        for i in range(10):
            time.sleep(0.1)
            if not self.process.is_alive():
                break
        else:
            _LOG.debug("Killing process %s", self.process.name)
            self.process.kill()

    def cleanup(self):
        """Release process resources; has to be called for each finished
        process.
        """
        if self.process and not self.process.is_alive():
            self.process.close()
            self.process = None

    def failMessage(self):
        """Return a message describing the task failure."""
        exitcode = self.process.exitcode
        if exitcode < 0:
            # Negative exit code means it was killed by a signal
            signum = -exitcode
            msg = f"Task {self} failed, killed by signal {signum}"
            # Just in case this is some very odd signal, expect ValueError
            try:
                strsignal = signal.strsignal(signum)
                msg = f"{msg} ({strsignal})"
            except ValueError:
                pass
        elif exitcode > 0:
            msg = f"Task {self} failed, exit code={exitcode}"
        else:
            msg = None
        return msg

    def __str__(self):
        return f"<{self.qnode.taskDef} dataId={self.qnode.quantum.dataId}>"


class _JobList:
    """Simple list of _Job instances with a few convenience methods.

    Parameters
    ----------
    iterable : iterable of `~lsst.pipe.base.QuantumNode`
        Sequence of Quanta to execute. This has to be ordered according to
        task dependencies.
    """

    def __init__(self, iterable):
        self.jobs = [_Job(qnode) for qnode in iterable]
        self.pending = self.jobs[:]
        self.running = []
        self.finishedNodes = set()
        self.failedNodes = set()
        self.timedOutNodes = set()

    def submit(self, job, butler, quantumExecutor, startMethod=None):
        """Submit one more job for execution.

        Parameters
        ----------
        job : `_Job`
            Job to submit.
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        quantumExecutor : `QuantumExecutor`
            Executor for single quantum.
        startMethod : `str`, optional
            Start method from `multiprocessing` module.
        """
        # this will raise if job is not in pending list
        self.pending.remove(job)
        job.start(butler, quantumExecutor, startMethod)
        self.running.append(job)

    def setJobState(self, job, state):
        """Update job state.

        Parameters
        ----------
        job : `_Job`
            Job to update.
        state : `JobState`
            New job state; note that only FINISHED, FAILED, TIMED_OUT, or
            FAILED_DEP states are acceptable.
        """
        allowedStates = (JobState.FINISHED, JobState.FAILED, JobState.TIMED_OUT, JobState.FAILED_DEP)
        assert state in allowedStates, f"State {state} not allowed here"

        # remove job from pending/running lists
        if job.state == JobState.PENDING:
            self.pending.remove(job)
        elif job.state == JobState.RUNNING:
            self.running.remove(job)

        qnode = job.qnode
        # it should not be in any of these, but just in case
        self.finishedNodes.discard(qnode)
        self.failedNodes.discard(qnode)
        self.timedOutNodes.discard(qnode)

        job._state = state
        if state == JobState.FINISHED:
            self.finishedNodes.add(qnode)
        elif state == JobState.FAILED:
            self.failedNodes.add(qnode)
        elif state == JobState.FAILED_DEP:
            self.failedNodes.add(qnode)
        elif state == JobState.TIMED_OUT:
            self.failedNodes.add(qnode)
            self.timedOutNodes.add(qnode)
        else:
            raise ValueError(f"Unexpected state value: {state}")

    def cleanup(self):
        """Do periodic cleanup for jobs that did not finish correctly.

        If a timed-out job was killed but took too long to stop, the regular
        cleanup did not work for it. Here we check all timed-out jobs
        periodically and do cleanup for those that have managed to die by
        this time.
        """
        for job in self.jobs:
            if job.state == JobState.TIMED_OUT and job.process is not None:
                job.cleanup()


class MPGraphExecutorError(Exception):
    """Exception class for errors raised by MPGraphExecutor."""

    pass


class MPTimeoutError(MPGraphExecutorError):
    """Exception raised when task execution times out."""

    pass


class MPGraphExecutor(QuantumGraphExecutor):
    """Implementation of QuantumGraphExecutor using same-host multiprocess
    execution of Quanta.

    Parameters
    ----------
    numProc : `int`
        Number of processes to use for executing tasks.
    timeout : `float`
        Time in seconds to wait for tasks to finish.
    quantumExecutor : `QuantumExecutor`
        Executor for single quantum. For multiprocess-style execution when
        ``numProc`` is greater than one this instance must support pickle.
    startMethod : `str`, optional
        Start method from `multiprocessing` module; `None` selects the best
        one for the current platform.
    failFast : `bool`, optional
        If set to ``True`` then stop processing on first error from any task.
    executionGraphFixup : `ExecutionGraphFixup`, optional
        Instance used for modification of execution graph.
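
    Examples
    --------
    A minimal usage sketch; ``graph`` and ``butler`` are assumed to be an
    existing `QuantumGraph` and `lsst.daf.butler.Butler`, and
    ``quantum_executor`` stands in for any concrete `QuantumExecutor`
    implementation::

        executor = MPGraphExecutor(
            numProc=4,         # run up to four quanta in parallel
            timeout=3600.0,    # kill any quantum running longer than 1 hour
            quantumExecutor=quantum_executor,
            failFast=True,     # raise on the first task failure
        )
        executor.execute(graph, butler)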

    """

    def __init__(
        self, numProc, timeout, quantumExecutor, *, startMethod=None, failFast=False, executionGraphFixup=None
    ):
        self.numProc = numProc
        self.timeout = timeout
        self.quantumExecutor = quantumExecutor
        self.failFast = failFast
        self.executionGraphFixup = executionGraphFixup

        # The default start method is "spawn" for macOS and "fork" for
        # Linux; None for all other platforms, which uses the
        # multiprocessing default.
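        # ("spawn" is also what CPython itself uses by default on macOS
        # since Python 3.8, where fork() can be unsafe once system
        # frameworks have started threads.)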

        if startMethod is None:
            methods = dict(linux="fork", darwin="spawn")
            startMethod = methods.get(sys.platform)
        self.startMethod = startMethod

    def execute(self, graph, butler):
        # Docstring inherited from QuantumGraphExecutor.execute
        graph = self._fixupQuanta(graph)
        if self.numProc > 1:
            self._executeQuantaMP(graph, butler)
        else:
            self._executeQuantaInProcess(graph, butler)

    def _fixupQuanta(self, graph: QuantumGraph):
        """Call fixup code to modify execution graph.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` to modify.

        Returns
        -------
        graph : `QuantumGraph`
            Modified `QuantumGraph`.

        Raises
        ------
        MPGraphExecutorError
            Raised if execution graph cannot be ordered after modification,
            i.e. it has dependency cycles.
        """
        if not self.executionGraphFixup:
            return graph

        _LOG.debug("Call execution graph fixup method")
        graph = self.executionGraphFixup.fixupQuanta(graph)

        # Detect if there is now a cycle created within the graph
        if graph.findCycle():
            raise MPGraphExecutorError("Updated execution graph has dependency cycle.")

        return graph
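    # A minimal sketch of a graph fixup (hypothetical example; a concrete
    # fixup subclasses ExecutionGraphFixup and may reorder quanta or add
    # dependencies before execution):
    #
    #     class NoopFixup(ExecutionGraphFixup):
    #         def fixupQuanta(self, graph):
    #             return graph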

    def _executeQuantaInProcess(self, graph, butler):
        """Execute all Quanta in the current process.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` that is to be executed.
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        """
        successCount, totalCount = 0, len(graph)
        failedNodes = set()
        for qnode in graph:

            # Any failed inputs mean that the quantum has to be skipped.
            inputNodes = graph.determineInputsToQuantumNode(qnode)
            if inputNodes & failedNodes:
                _LOG.error(
                    "Upstream job failed for task <%s dataId=%s>, skipping this task.",
                    qnode.taskDef,
                    qnode.quantum.dataId,
                )
                failedNodes.add(qnode)
                continue

            _LOG.debug("Executing %s", qnode)
            try:
                self.quantumExecutor.execute(qnode.taskDef, qnode.quantum, butler)
                successCount += 1
            except Exception as exc:
                failedNodes.add(qnode)
                if self.failFast:
                    raise MPGraphExecutorError(
                        f"Task <{qnode.taskDef} dataId={qnode.quantum.dataId}> failed."
                    ) from exc
                else:
                    # Note that there could be exception safety issues, which
                    # we presently ignore.
                    _LOG.error(
                        "Task <%s dataId=%s> failed; processing will continue for remaining tasks.",
                        qnode.taskDef,
                        qnode.quantum.dataId,
                        exc_info=exc,
                    )
            finally:
                # sqlalchemy has some objects that can last until a garbage
                # collection cycle is run, which can happen at unpredictable
                # times; run a collection loop here explicitly.
                gc.collect()

        _LOG.info(
            "Executed %d quanta successfully, %d failed and %d remain out of total %d quanta.",
            successCount,
            len(failedNodes),
            totalCount - successCount - len(failedNodes),
            totalCount,
        )

        # Raise an exception if there were any failures.
        if failedNodes:
            raise MPGraphExecutorError("One or more tasks failed during execution.")

    def _executeQuantaMP(self, graph, butler):
        """Execute all Quanta in separate processes.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` that is to be executed.
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        """

        disableImplicitThreading()  # To prevent thread contention

        _LOG.debug("Using %r for multiprocessing start method", self.startMethod)

        # re-pack input quantum data into jobs list
        jobs = _JobList(graph)

        # check that all tasks can run in sub-process
        for job in jobs.jobs:
            taskDef = job.qnode.taskDef
            if not taskDef.taskClass.canMultiprocess:
                raise MPGraphExecutorError(
                    f"Task {taskDef.taskName} does not support multiprocessing; use single process"
                )

        finishedCount, failedCount = 0, 0
        while jobs.pending or jobs.running:
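            # Each pass below collects completed processes, fails pending
            # jobs whose inputs failed, starts new jobs whose inputs have
            # all finished, and then sleeps briefly before polling again.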

            _LOG.debug("#pendingJobs: %s", len(jobs.pending))
            _LOG.debug("#runningJobs: %s", len(jobs.running))

            # See if any jobs have finished. Iterate over a copy because
            # setJobState() removes finished jobs from jobs.running.
            for job in list(jobs.running):
                if not job.process.is_alive():
                    _LOG.debug("finished: %s", job)
                    # finished
                    exitcode = job.process.exitcode
                    if exitcode == 0:
                        jobs.setJobState(job, JobState.FINISHED)
                        job.cleanup()
                        _LOG.debug("success: %s took %.3f seconds", job, time.time() - job.started)
                    else:
                        # failMessage() has to be called before cleanup()
                        message = job.failMessage()
                        jobs.setJobState(job, JobState.FAILED)
                        job.cleanup()
                        _LOG.debug("failed: %s", job)
                        if self.failFast or exitcode == InvalidQuantumError.EXIT_CODE:
                            # stop all running jobs
                            for stopJob in jobs.running:
                                if stopJob is not job:
                                    stopJob.stop()
                            raise MPGraphExecutorError(message)
                        else:
                            _LOG.error("%s; processing will continue for remaining tasks.", message)
                else:
                    # check for timeout
                    now = time.time()
                    if now - job.started > self.timeout:
                        jobs.setJobState(job, JobState.TIMED_OUT)
                        _LOG.debug("Terminating job %s due to timeout", job)
                        job.stop()
                        job.cleanup()
                        if self.failFast:
                            raise MPTimeoutError(f"Timeout ({self.timeout} sec) for task {job}.")
                        else:
                            _LOG.error(
                                "Timeout (%s sec) for task %s; task is killed, processing continues "
                                "for remaining tasks.",
                                self.timeout,
                                job,
                            )

            # Fail jobs whose inputs failed. This may need several iterations
            # if the jobs are not ordered correctly; anything missed here is
            # picked up on the next pass of the outer loop. Iterate over a
            # copy because setJobState() removes jobs from jobs.pending.
            if jobs.failedNodes:
                for job in list(jobs.pending):
                    jobInputNodes = graph.determineInputsToQuantumNode(job.qnode)
                    if jobInputNodes & jobs.failedNodes:
                        jobs.setJobState(job, JobState.FAILED_DEP)
                        _LOG.error("Upstream job failed for task %s, skipping this task.", job)

            # See if we can start more jobs; iterate over a copy because
            # submit() removes jobs from jobs.pending.
            if len(jobs.running) < self.numProc:
                for job in list(jobs.pending):
                    jobInputNodes = graph.determineInputsToQuantumNode(job.qnode)
                    if jobInputNodes <= jobs.finishedNodes:
                        # all dependencies have completed, can start new job
                        if len(jobs.running) < self.numProc:
                            _LOG.debug("Submitting %s", job)
                            jobs.submit(job, butler, self.quantumExecutor, self.startMethod)
                        if len(jobs.running) >= self.numProc:
                            # Cannot start any more jobs, wait until something
                            # finishes.
                            break

            # Do cleanup for timed out jobs if necessary.
            jobs.cleanup()

            # Print progress message if something changed.
            newFinished, newFailed = len(jobs.finishedNodes), len(jobs.failedNodes)
            if (finishedCount, failedCount) != (newFinished, newFailed):
                finishedCount, failedCount = newFinished, newFailed
                totalCount = len(jobs.jobs)
                _LOG.info(
                    "Executed %d quanta successfully, %d failed and %d remain out of total %d quanta.",
                    finishedCount,
                    failedCount,
                    totalCount - finishedCount - failedCount,
                    totalCount,
                )

            # Here we want to wait until one of the running jobs completes,
            # but multiprocessing does not provide an API for that; for now
            # just sleep a little bit and go back to the loop.
            if jobs.running:
                time.sleep(0.1)

        if jobs.failedNodes:
            # print list of failed jobs
            _LOG.error("Failed jobs:")
            for job in jobs.jobs:
                if job.state != JobState.FINISHED:
                    _LOG.error(" - %s: %s", job.state.name, job)

            # if any job failed raise an exception
            if jobs.failedNodes == jobs.timedOutNodes:
                raise MPTimeoutError("One or more tasks timed out during execution.")
            else:
                raise MPGraphExecutorError("One or more tasks failed or timed out during execution.")