Coverage for python/lsst/ctrl/mpexec/mpGraphExecutor.py: 16% (212 statements)

# This file is part of ctrl_mpexec.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

__all__ = ["MPGraphExecutor", "MPGraphExecutorError", "MPTimeoutError"]

import gc
import logging
import multiprocessing
import pickle
import sys
import time

# -------------------------------
# Imports of standard modules --
# -------------------------------
from enum import Enum

from lsst.base import disableImplicitThreading
from lsst.daf.butler.cli.cliLog import CliLog
from lsst.pipe.base import InvalidQuantumError
from lsst.pipe.base.graph.graph import QuantumGraph

# -----------------------------
# Imports for other modules --
# -----------------------------
from .quantumGraphExecutor import QuantumGraphExecutor

_LOG = logging.getLogger(__name__)


# Possible states for the executing task:
# - PENDING: job has not started yet
# - RUNNING: job is currently executing
# - FINISHED: job finished successfully
# - FAILED: job execution failed (process returned non-zero status)
# - TIMED_OUT: job is killed due to too long execution time
# - FAILED_DEP: one of the dependencies of this job has failed/timed out
JobState = Enum("JobState", "PENDING RUNNING FINISHED FAILED TIMED_OUT FAILED_DEP")


class _Job:
    """Class representing a job running a single task.

    Parameters
    ----------
    qnode : `~lsst.pipe.base.QuantumNode`
        Quantum and some associated information.
    """

    def __init__(self, qnode):
        self.qnode = qnode
        self.process = None
        self._state = JobState.PENDING
        self.started = None

    @property
    def state(self):
        """Job processing state (JobState)."""
        return self._state

    def start(self, butler, quantumExecutor, startMethod=None):
        """Start process which runs the task.

        Parameters
        ----------
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        quantumExecutor : `QuantumExecutor`
            Executor for single quantum.
        startMethod : `str`, optional
            Start method from `multiprocessing` module.
        """
        # Unpickling of the quantum has to happen after the butler, which is
        # why it is pickled manually here.
        quantum_pickle = pickle.dumps(self.qnode.quantum)
        taskDef = self.qnode.taskDef
        logConfigState = CliLog.configState
        mp_ctx = multiprocessing.get_context(startMethod)
        self.process = mp_ctx.Process(
            target=_Job._executeJob,
            args=(quantumExecutor, taskDef, quantum_pickle, butler, logConfigState),
            name=f"task-{self.qnode.quantum.dataId}",
        )
        self.process.start()
        self.started = time.time()
        self._state = JobState.RUNNING

    @staticmethod
    def _executeJob(quantumExecutor, taskDef, quantum_pickle, butler, logConfigState):
        """Execute a job with arguments.

        Parameters
        ----------
        quantumExecutor : `QuantumExecutor`
            Executor for single quantum.
        taskDef : `~lsst.pipe.base.TaskDef`
            Task definition structure.
        quantum_pickle : `bytes`
            Quantum for this task execution in pickled form.
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        logConfigState : `list`
            Logging configuration state captured from `CliLog.configState`
            in the parent process.
        """
        if logConfigState and not CliLog.configState:
            # means that we are in a new spawned Python process and we have to
            # re-initialize logging
            CliLog.replayConfigState(logConfigState)

        # have to reset connection pool to avoid sharing connections
        if butler is not None:
            butler.registry.resetConnectionPool()

        quantum = pickle.loads(quantum_pickle)
        quantumExecutor.execute(taskDef, quantum, butler)

    def stop(self):
        """Stop the process."""
        self.process.terminate()
        # give it 1 second to finish or KILL
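        # (note the for/else below: the else branch runs only if the loop is
        # not broken out of, i.e. the process is still alive after ~1 second
        # of polling, so it is killed forcefully)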

        for i in range(10):
            time.sleep(0.1)
            if not self.process.is_alive():
                break
        else:
            _LOG.debug("Killing process %s", self.process.name)
            self.process.kill()

    def cleanup(self):
        """Release process resources; has to be called for each finished
        process.
        """
        if self.process and not self.process.is_alive():
            self.process.close()
            self.process = None

    def __str__(self):
        return f"<{self.qnode.taskDef} dataId={self.qnode.quantum.dataId}>"


class _JobList:
    """Simple list of _Job instances with a few convenience methods.

    Parameters
    ----------
    iterable : iterable of `~lsst.pipe.base.QuantumNode`
        Sequence of Quanta to execute. This has to be ordered according to
        task dependencies.
    """

    def __init__(self, iterable):
        self.jobs = [_Job(qnode) for qnode in iterable]
        self.pending = self.jobs[:]
        self.running = []
        self.finishedNodes = set()
        self.failedNodes = set()
        self.timedOutNodes = set()

    def submit(self, job, butler, quantumExecutor, startMethod=None):
        """Submit one more job for execution.

        Parameters
        ----------
        job : `_Job`
            Job to submit.
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        quantumExecutor : `QuantumExecutor`
            Executor for single quantum.
        startMethod : `str`, optional
            Start method from `multiprocessing` module.
        """
        # this will raise if job is not in pending list
        self.pending.remove(job)
        job.start(butler, quantumExecutor, startMethod)
        self.running.append(job)

    def setJobState(self, job, state):
        """Update job state.

        Parameters
        ----------
        job : `_Job`
            Job to update.
        state : `JobState`
            New job state; note that only FINISHED, FAILED, TIMED_OUT, or
            FAILED_DEP states are acceptable.
        """
        allowedStates = (JobState.FINISHED, JobState.FAILED, JobState.TIMED_OUT, JobState.FAILED_DEP)
        assert state in allowedStates, f"State {state} not allowed here"

        # remove job from pending/running lists
        if job.state == JobState.PENDING:
            self.pending.remove(job)
        elif job.state == JobState.RUNNING:
            self.running.remove(job)

        qnode = job.qnode
        # it should not be in any of these, but just in case
        self.finishedNodes.discard(qnode)
        self.failedNodes.discard(qnode)
        self.timedOutNodes.discard(qnode)

        job._state = state
        if state == JobState.FINISHED:
            self.finishedNodes.add(qnode)
        elif state == JobState.FAILED:
            self.failedNodes.add(qnode)
        elif state == JobState.FAILED_DEP:
            self.failedNodes.add(qnode)
        elif state == JobState.TIMED_OUT:
            self.failedNodes.add(qnode)
            self.timedOutNodes.add(qnode)
        else:
            raise ValueError(f"Unexpected state value: {state}")

    def cleanup(self):
        """Do periodic cleanup for jobs that did not finish correctly.

        If timed-out jobs were killed but took too long to stop, the regular
        cleanup will not have worked for them. Here we check all timed-out
        jobs periodically and clean them up once they have died.
        """
        for job in self.jobs:
            if job.state == JobState.TIMED_OUT and job.process is not None:
                job.cleanup()


class MPGraphExecutorError(Exception):
    """Exception class for errors raised by MPGraphExecutor."""

    pass


class MPTimeoutError(MPGraphExecutorError):
    """Exception raised when task execution times out."""

    pass


class MPGraphExecutor(QuantumGraphExecutor):
    """Implementation of QuantumGraphExecutor using same-host multiprocess
    execution of Quanta.

    Parameters
    ----------
    numProc : `int`
        Number of processes to use for executing tasks.
    timeout : `float`
        Time in seconds to wait for tasks to finish.
    quantumExecutor : `QuantumExecutor`
        Executor for single quantum. For multiprocess-style execution when
        ``numProc`` is greater than one this instance must support pickle.
    startMethod : `str`, optional
        Start method from `multiprocessing` module; `None` selects the best
        one for the current platform.
    failFast : `bool`, optional
        If set to ``True`` then stop processing on first error from any task.
    executionGraphFixup : `ExecutionGraphFixup`, optional
        Instance used for modification of execution graph.
    """
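    # A minimal usage sketch (hypothetical values; ``quantumExecutor`` would
    # typically be a ``SingleQuantumExecutor`` built by the calling framework
    # code, and ``graph``/``butler`` come from the pipeline being run):
    #
    #     executor = MPGraphExecutor(numProc=4, timeout=86400,
    #                                quantumExecutor=quantumExecutor)
    #     executor.execute(graph, butler)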

    def __init__(
        self, numProc, timeout, quantumExecutor, *, startMethod=None, failFast=False, executionGraphFixup=None
    ):
        self.numProc = numProc
        self.timeout = timeout
        self.quantumExecutor = quantumExecutor
        self.failFast = failFast
        self.executionGraphFixup = executionGraphFixup

        # We set default start method as spawn for MacOS and fork for Linux;
        # None for all other platforms to use multiprocessing default.
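        # (fork is cheap and inherits the parent's already-configured state;
        # spawn is used on macOS presumably because forking is unsafe with
        # some macOS system libraries, which is also why CPython itself
        # defaults to spawn on that platform.)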

        if startMethod is None:
            methods = dict(linux="fork", darwin="spawn")
            startMethod = methods.get(sys.platform)
        self.startMethod = startMethod

    def execute(self, graph, butler):
        # Docstring inherited from QuantumGraphExecutor.execute
        graph = self._fixupQuanta(graph)
        if self.numProc > 1:
            self._executeQuantaMP(graph, butler)
        else:
            self._executeQuantaInProcess(graph, butler)

    def _fixupQuanta(self, graph: QuantumGraph):
        """Call fixup code to modify execution graph.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` to modify.

        Returns
        -------
        graph : `QuantumGraph`
            Modified `QuantumGraph`.

        Raises
        ------
        MPGraphExecutorError
            Raised if execution graph cannot be ordered after modification,
            i.e. it has dependency cycles.
        """
        if not self.executionGraphFixup:
            return graph

        _LOG.debug("Call execution graph fixup method")
        graph = self.executionGraphFixup.fixupQuanta(graph)

        # Detect if there is now a cycle created within the graph
        if graph.findCycle():
            raise MPGraphExecutorError("Updated execution graph has dependency cycle.")

        return graph

    def _executeQuantaInProcess(self, graph, butler):
        """Execute all Quanta in the current process.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` that is to be executed.
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        """
        successCount, totalCount = 0, len(graph)
        failedNodes = set()
        for qnode in graph:

            # Any failed inputs mean that the quantum has to be skipped.
            inputNodes = graph.determineInputsToQuantumNode(qnode)
            if inputNodes & failedNodes:
                _LOG.error(
                    "Upstream job failed for task <%s dataId=%s>, skipping this task.",
                    qnode.taskDef,
                    qnode.quantum.dataId,
                )
                failedNodes.add(qnode)
                continue

            _LOG.debug("Executing %s", qnode)
            try:
                self.quantumExecutor.execute(qnode.taskDef, qnode.quantum, butler)
                successCount += 1
            except Exception as exc:
                failedNodes.add(qnode)
                if self.failFast:
                    raise MPGraphExecutorError(
                        f"Task <{qnode.taskDef} dataId={qnode.quantum.dataId}> failed."
                    ) from exc
                else:
                    # Note that there could be exception safety issues, which
                    # we presently ignore.
                    _LOG.error(
                        "Task <%s dataId=%s> failed; processing will continue for remaining tasks.",
                        qnode.taskDef,
                        qnode.quantum.dataId,
                        exc_info=exc,
                    )
            finally:
                # sqlalchemy has some objects that can last until a garbage
                # collection cycle is run, which can happen at unpredictable
                # times; run a collection loop here explicitly.
                gc.collect()

        _LOG.info(
            "Executed %d quanta successfully, %d failed and %d remain out of total %d quanta.",
            successCount,
            len(failedNodes),
            totalCount - successCount - len(failedNodes),
            totalCount,
        )

        # Raise an exception if there were any failures.
        if failedNodes:
            raise MPGraphExecutorError("One or more tasks failed during execution.")

    def _executeQuantaMP(self, graph, butler):
        """Execute all Quanta in separate processes.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` that is to be executed.
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        """

        disableImplicitThreading()  # To prevent thread contention

        _LOG.debug("Using %r for multiprocessing start method", self.startMethod)

        # re-pack input quantum data into jobs list
        jobs = _JobList(graph)

        # check that all tasks can run in sub-process
        for job in jobs.jobs:
            taskDef = job.qnode.taskDef
            if not taskDef.taskClass.canMultiprocess:
                raise MPGraphExecutorError(
                    f"Task {taskDef.taskName} does not support multiprocessing; use single process"
                )

        finishedCount, failedCount = 0, 0
        while jobs.pending or jobs.running:

            _LOG.debug("#pendingJobs: %s", len(jobs.pending))
            _LOG.debug("#runningJobs: %s", len(jobs.running))

            # See if any jobs have finished
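            # (setJobState() below removes a finished job from jobs.running;
            # mutating the list being iterated over may skip the next entry,
            # but that entry is simply examined again on the next pass of the
            # enclosing while loop)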

            for job in jobs.running:
                if not job.process.is_alive():
                    _LOG.debug("finished: %s", job)
                    # finished
                    exitcode = job.process.exitcode
                    if exitcode == 0:
                        jobs.setJobState(job, JobState.FINISHED)
                        job.cleanup()
                        _LOG.debug("success: %s took %.3f seconds", job, time.time() - job.started)
                    else:
                        jobs.setJobState(job, JobState.FAILED)
                        job.cleanup()
                        _LOG.debug("failed: %s", job)
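                        # an exit code equal to InvalidQuantumError.EXIT_CODE
                        # aborts the whole run even when failFast is not set,
                        # presumably because such errors are permanent and
                        # would recur for other quanta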

                        if self.failFast or exitcode == InvalidQuantumError.EXIT_CODE:
                            for stopJob in jobs.running:
                                if stopJob is not job:
                                    stopJob.stop()
                            raise MPGraphExecutorError(f"Task {job} failed, exit code={exitcode}.")
                        else:
                            _LOG.error("Task %s failed; processing will continue for remaining tasks.", job)
                else:
                    # check for timeout
                    now = time.time()
                    if now - job.started > self.timeout:
                        jobs.setJobState(job, JobState.TIMED_OUT)
                        _LOG.debug("Terminating job %s due to timeout", job)
                        job.stop()
                        job.cleanup()
                        if self.failFast:
                            raise MPTimeoutError(f"Timeout ({self.timeout} sec) for task {job}.")
                        else:
                            _LOG.error(
                                "Timeout (%s sec) for task %s; task is killed, processing continues "
                                "for remaining tasks.",
                                self.timeout,
                                job,
                            )

            # Fail jobs whose inputs failed; this may need several iterations
            # if the ordering is not right, and anything missed here is
            # handled on the next pass of the loop.
            if jobs.failedNodes:
                for job in jobs.pending:
                    jobInputNodes = graph.determineInputsToQuantumNode(job.qnode)
                    if jobInputNodes & jobs.failedNodes:
                        jobs.setJobState(job, JobState.FAILED_DEP)
                        _LOG.error("Upstream job failed for task %s, skipping this task.", job)

            # see if we can start more jobs
            if len(jobs.running) < self.numProc:
                for job in jobs.pending:
                    jobInputNodes = graph.determineInputsToQuantumNode(job.qnode)
                    if jobInputNodes <= jobs.finishedNodes:
                        # all dependencies have completed, can start a new job
                        if len(jobs.running) < self.numProc:
                            _LOG.debug("Submitting %s", job)
                            jobs.submit(job, butler, self.quantumExecutor, self.startMethod)
                        if len(jobs.running) >= self.numProc:
                            # Cannot start any more jobs, wait until something
                            # finishes.
                            break

            # Do cleanup for timed out jobs if necessary.
            jobs.cleanup()

            # Print progress message if something changed.
            newFinished, newFailed = len(jobs.finishedNodes), len(jobs.failedNodes)
            if (finishedCount, failedCount) != (newFinished, newFailed):
                finishedCount, failedCount = newFinished, newFailed
                totalCount = len(jobs.jobs)
                _LOG.info(
                    "Executed %d quanta successfully, %d failed and %d remain out of total %d quanta.",
                    finishedCount,
                    failedCount,
                    totalCount - finishedCount - failedCount,
                    totalCount,
                )

            # Here we want to wait until one of the running jobs completes,
            # but multiprocessing does not provide an API for that; for now
            # just sleep a little bit and go back to the loop.
            if jobs.running:
                time.sleep(0.1)

        if jobs.failedNodes:
            # print list of failed jobs
            _LOG.error("Failed jobs:")
            for job in jobs.jobs:
                if job.state != JobState.FINISHED:
                    _LOG.error(" - %s: %s", job.state.name, job)

            # if any job failed, raise an exception
            if jobs.failedNodes == jobs.timedOutNodes:
                raise MPTimeoutError("One or more tasks timed out during execution.")
            else:
                raise MPGraphExecutorError("One or more tasks failed or timed out during execution.")