# This file is part of ctrl_mpexec.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

__all__ = ["MPGraphExecutor", "MPGraphExecutorError", "MPTimeoutError"]

# -------------------------------
# Imports of standard modules --
# -------------------------------
import gc
import logging
import multiprocessing
import pickle
import sys
import time
from enum import Enum

# -----------------------------
# Imports for other modules --
# -----------------------------
from lsst.base import disableImplicitThreading
from lsst.daf.butler.cli.cliLog import CliLog
from lsst.pipe.base import InvalidQuantumError
from lsst.pipe.base.graph.graph import QuantumGraph

from .quantumGraphExecutor import QuantumGraphExecutor

_LOG = logging.getLogger(__name__)


# Possible states for the executing task:
# - PENDING: job has not started yet
# - RUNNING: job is currently executing
# - FINISHED: job finished successfully
# - FAILED: job execution failed (process returned non-zero status)
# - TIMED_OUT: job was killed because its execution took too long
# - FAILED_DEP: one of the dependencies of this job has failed/timed out
JobState = Enum("JobState", "PENDING RUNNING FINISHED FAILED TIMED_OUT FAILED_DEP")
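# Typical lifecycle, as implemented by _Job and _JobList below:
# PENDING -> RUNNING -> FINISHED | FAILED | TIMED_OUT, or
# PENDING -> FAILED_DEP when an upstream quantum fails first.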


class _Job:
    """Class representing a job running a single task.

    Parameters
    ----------
    qnode : `~lsst.pipe.base.QuantumNode`
        Quantum and some associated information.
    """

    def __init__(self, qnode):
        self.qnode = qnode
        self.process = None
        self._state = JobState.PENDING
        self.started = None

    @property
    def state(self):
        """Job processing state (`JobState`)."""
        return self._state

    def start(self, butler, quantumExecutor, startMethod=None):
        """Start process which runs the task.

        Parameters
        ----------
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        quantumExecutor : `QuantumExecutor`
            Executor for single quantum.
        startMethod : `str`, optional
            Start method from `multiprocessing` module.
        """
        # Unpickling of the quantum has to happen after the butler is
        # unpickled, which is why the quantum is pickled manually here.
        quantum_pickle = pickle.dumps(self.qnode.quantum)
        taskDef = self.qnode.taskDef
        logConfigState = CliLog.configState
        mp_ctx = multiprocessing.get_context(startMethod)
        self.process = mp_ctx.Process(
            target=_Job._executeJob,
            args=(quantumExecutor, taskDef, quantum_pickle, butler, logConfigState),
            name=f"task-{self.qnode.quantum.dataId}",
        )
        self.process.start()
        self.started = time.time()
        self._state = JobState.RUNNING

    @staticmethod
    def _executeJob(quantumExecutor, taskDef, quantum_pickle, butler, logConfigState):
        """Execute a job with arguments.

        Parameters
        ----------
        quantumExecutor : `QuantumExecutor`
            Executor for single quantum.
        taskDef : `~lsst.pipe.base.TaskDef`
            Task definition structure.
        quantum_pickle : `bytes`
            Quantum for this task execution in pickled form.
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        logConfigState : `list`
            Logging configuration state to replay in the new process.
        """
        if logConfigState and not CliLog.configState:
            # This means we are in a new spawned Python process and we have
            # to re-initialize logging.
            CliLog.replayConfigState(logConfigState)

        # Have to reset connection pool to avoid sharing connections.
        if butler is not None:
            butler.registry.resetConnectionPool()

        quantum = pickle.loads(quantum_pickle)
        quantumExecutor.execute(taskDef, quantum, butler)

    def stop(self):
        """Stop the process."""
        self.process.terminate()
        # Give it one second to finish; if it is still alive after that,
        # kill it.
        for _ in range(10):
            time.sleep(0.1)
            if not self.process.is_alive():
                break
        else:
            _LOG.debug("Killing process %s", self.process.name)
            self.process.kill()

    def cleanup(self):
        """Release process resources; this has to be called for each
        finished process.
        """
        if self.process and not self.process.is_alive():
            self.process.close()
            self.process = None

    def __str__(self):
        return f"<{self.qnode.taskDef} dataId={self.qnode.quantum.dataId}>"


class _JobList:
    """Simple list of `_Job` instances with a few convenience methods.

    Parameters
    ----------
    iterable : iterable of `~lsst.pipe.base.QuantumNode`
        Sequence of Quanta to execute. This has to be ordered according to
        task dependencies.
    """

    def __init__(self, iterable):
        self.jobs = [_Job(qnode) for qnode in iterable]
        self.pending = self.jobs[:]
        self.running = []
        self.finishedNodes = set()
        self.failedNodes = set()
        self.timedOutNodes = set()
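        # Invariants maintained by submit() and setJobState(): every job
        # sits in exactly one of ``pending`` or ``running`` until it reaches
        # a terminal state; after that its qnode is recorded in
        # ``finishedNodes`` or ``failedNodes`` (``timedOutNodes`` is a
        # subset of ``failedNodes``).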

    def submit(self, job, butler, quantumExecutor, startMethod=None):
        """Submit one more job for execution.

        Parameters
        ----------
        job : `_Job`
            Job to submit.
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        quantumExecutor : `QuantumExecutor`
            Executor for single quantum.
        startMethod : `str`, optional
            Start method from `multiprocessing` module.
        """
        # this will raise if job is not in pending list
        self.pending.remove(job)
        job.start(butler, quantumExecutor, startMethod)
        self.running.append(job)

    def setJobState(self, job, state):
        """Update job state.

        Parameters
        ----------
        job : `_Job`
            Job to update.
        state : `JobState`
            New job state; note that only FINISHED, FAILED, TIMED_OUT, or
            FAILED_DEP states are acceptable.
        """
        allowedStates = (JobState.FINISHED, JobState.FAILED, JobState.TIMED_OUT, JobState.FAILED_DEP)
        assert state in allowedStates, f"State {state} not allowed here"

        # remove job from pending/running lists
        if job.state == JobState.PENDING:
            self.pending.remove(job)
        elif job.state == JobState.RUNNING:
            self.running.remove(job)

        qnode = job.qnode
        # it should not be in any of these, but just in case
        self.finishedNodes.discard(qnode)
        self.failedNodes.discard(qnode)
        self.timedOutNodes.discard(qnode)

        job._state = state
        if state == JobState.FINISHED:
            self.finishedNodes.add(qnode)
        elif state == JobState.FAILED:
            self.failedNodes.add(qnode)
        elif state == JobState.FAILED_DEP:
            self.failedNodes.add(qnode)
        elif state == JobState.TIMED_OUT:
            self.failedNodes.add(qnode)
            self.timedOutNodes.add(qnode)
        else:
            raise ValueError(f"Unexpected state value: {state}")

    def cleanup(self):
        """Do periodic cleanup for jobs that did not finish correctly.

        If a timed-out job was killed but takes too long to stop, the
        regular cleanup does not work for it. Here we check all timed-out
        jobs periodically and clean up those that have managed to die by
        this time.
        """
        for job in self.jobs:
            if job.state == JobState.TIMED_OUT and job.process is not None:
                job.cleanup()


class MPGraphExecutorError(Exception):
    """Exception class for errors raised by MPGraphExecutor."""

    pass


class MPTimeoutError(MPGraphExecutorError):
    """Exception raised when task execution times out."""

    pass


class MPGraphExecutor(QuantumGraphExecutor):
    """Implementation of QuantumGraphExecutor using same-host multiprocess
    execution of Quanta.

    Parameters
    ----------
    numProc : `int`
        Number of processes to use for executing tasks.
    timeout : `float`
        Time in seconds to wait for tasks to finish.
    quantumExecutor : `QuantumExecutor`
        Executor for single quantum. For multiprocess-style execution when
        ``numProc`` is greater than one this instance must support pickle.
    startMethod : `str`, optional
        Start method from `multiprocessing` module, `None` selects the best
        one for the current platform.
    failFast : `bool`, optional
        If set to ``True`` then stop processing on first error from any task.
    executionGraphFixup : `ExecutionGraphFixup`, optional
        Instance used for modification of execution graph.
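
    Examples
    --------
    A minimal usage sketch; the quantum executor, quantum graph, and butler
    are assumed to be created elsewhere (``SingleQuantumExecutor`` from this
    package is one possible executor)::

        executor = MPGraphExecutor(
            numProc=4,
            timeout=3600.0,
            quantumExecutor=quantum_executor,
        )
        executor.execute(quantum_graph, butler)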

    """

    def __init__(
        self, numProc, timeout, quantumExecutor, *, startMethod=None, failFast=False, executionGraphFixup=None
    ):
        self.numProc = numProc
        self.timeout = timeout
        self.quantumExecutor = quantumExecutor
        self.failFast = failFast
        self.executionGraphFixup = executionGraphFixup

        # Default the start method to "spawn" on macOS and "fork" on Linux;
        # leave it as None on other platforms so multiprocessing picks its
        # own default.
        if startMethod is None:
            methods = dict(linux="fork", darwin="spawn")
            startMethod = methods.get(sys.platform)
        self.startMethod = startMethod
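        # Note on the trade-off (general multiprocessing behaviour): with
        # "spawn" the child starts a fresh interpreter and all Process
        # arguments must be picklable; with "fork" the child inherits the
        # parent's memory, which is cheaper but unsafe once threads exist.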

    def execute(self, graph, butler):
        # Docstring inherited from QuantumGraphExecutor.execute
        graph = self._fixupQuanta(graph)
        if self.numProc > 1:
            self._executeQuantaMP(graph, butler)
        else:
            self._executeQuantaInProcess(graph, butler)

    def _fixupQuanta(self, graph: QuantumGraph):
        """Call fixup code to modify execution graph.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` to modify.

        Returns
        -------
        graph : `QuantumGraph`
            Modified `QuantumGraph`.

        Raises
        ------
        MPGraphExecutorError
            Raised if the execution graph cannot be ordered after
            modification, i.e. it has dependency cycles.
        """
        if not self.executionGraphFixup:
            return graph

        _LOG.debug("Call execution graph fixup method")
        graph = self.executionGraphFixup.fixupQuanta(graph)

        # Detect if there is now a cycle created within the graph
        if graph.findCycle():
            raise MPGraphExecutorError("Updated execution graph has dependency cycle.")

        return graph
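
    # For illustration only, a fixup object might look like this hedged
    # sketch (``MyFixup`` is hypothetical; the only interface relied on
    # here is the ``fixupQuanta`` method called above):
    #
    #     class MyFixup(ExecutionGraphFixup):
    #         def fixupQuanta(self, graph: QuantumGraph) -> QuantumGraph:
    #             # add extra ordering constraints between quanta
    #             return graph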

    def _executeQuantaInProcess(self, graph, butler):
        """Execute all Quanta in current process.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` that is to be executed.
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        """
        # Note that in the non-MP case any failed task will generate an
        # exception and kill the whole thing. In general we cannot guarantee
        # exception safety, so the easiest and safest thing is to let it die.
        count, totalCount = 0, len(graph)
        for qnode in graph:
            _LOG.debug("Executing %s", qnode)
            try:
                self.quantumExecutor.execute(qnode.taskDef, qnode.quantum, butler)
            finally:
                # sqlalchemy keeps some objects alive until a garbage
                # collection cycle runs, which can happen at unpredictable
                # times, so run a collection explicitly here.
                gc.collect()
            count += 1
            _LOG.info(
                "Executed %d quanta, %d remain out of total %d quanta.", count, totalCount - count, totalCount
            )

    def _executeQuantaMP(self, graph, butler):
        """Execute all Quanta in separate processes.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` that is to be executed.
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        """

        disableImplicitThreading()  # To prevent thread contention

        _LOG.debug("Using %r for multiprocessing start method", self.startMethod)

        # re-pack input quantum data into jobs list
        jobs = _JobList(graph)

        # check that all tasks can run in sub-process
        for job in jobs.jobs:
            taskDef = job.qnode.taskDef
            if not taskDef.taskClass.canMultiprocess:
                raise MPGraphExecutorError(
                    f"Task {taskDef.taskName} does not support multiprocessing; use single process"
                )
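
        # The scheduling below is a simple polling loop: each pass collects
        # finished processes and checks running ones for timeout, fails
        # pending jobs whose upstream quanta failed, submits ready jobs up
        # to the ``numProc`` limit, then sleeps briefly and repeats.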

        finishedCount, failedCount = 0, 0
        while jobs.pending or jobs.running:

            _LOG.debug("#pendingJobs: %s", len(jobs.pending))
            _LOG.debug("#runningJobs: %s", len(jobs.running))

            # See if any jobs have finished
            for job in jobs.running:
                if not job.process.is_alive():
                    _LOG.debug("finished: %s", job)
                    # finished
                    exitcode = job.process.exitcode
                    if exitcode == 0:
                        jobs.setJobState(job, JobState.FINISHED)
                        job.cleanup()
                        _LOG.debug("success: %s took %.3f seconds", job, time.time() - job.started)
                    else:
                        jobs.setJobState(job, JobState.FAILED)
                        job.cleanup()
                        _LOG.debug("failed: %s", job)
                        if self.failFast or exitcode == InvalidQuantumError.EXIT_CODE:
                            for stopJob in jobs.running:
                                if stopJob is not job:
                                    stopJob.stop()
                            raise MPGraphExecutorError(f"Task {job} failed, exit code={exitcode}.")
                        else:
                            _LOG.error("Task %s failed; processing will continue for remaining tasks.", job)
                else:
                    # check for timeout
                    now = time.time()
                    if now - job.started > self.timeout:
                        jobs.setJobState(job, JobState.TIMED_OUT)
                        _LOG.debug("Terminating job %s due to timeout", job)
                        job.stop()
                        job.cleanup()
                        if self.failFast:
                            raise MPTimeoutError(f"Timeout ({self.timeout} sec) for task {job}.")
                        else:
                            _LOG.error(
                                "Timeout (%s sec) for task %s; task is killed, processing continues "
                                "for remaining tasks.",
                                self.timeout,
                                job,
                            )

            # Fail pending jobs whose inputs failed. This may need several
            # iterations if the ordering is not right; the remaining cases
            # are handled on the next pass through the loop.
            if jobs.failedNodes:
                for job in jobs.pending:
                    jobInputNodes = graph.determineInputsToQuantumNode(job.qnode)
                    if jobInputNodes & jobs.failedNodes:
                        jobs.setJobState(job, JobState.FAILED_DEP)
                        _LOG.error("Upstream job failed for task %s, skipping this task.", job)

            # see if we can start more jobs
            if len(jobs.running) < self.numProc:
                for job in jobs.pending:
                    jobInputNodes = graph.determineInputsToQuantumNode(job.qnode)
                    if jobInputNodes <= jobs.finishedNodes:
                        # all dependencies have completed, can start new job
                        if len(jobs.running) < self.numProc:
                            _LOG.debug("Submitting %s", job)
                            jobs.submit(job, butler, self.quantumExecutor, self.startMethod)
                        if len(jobs.running) >= self.numProc:
                            # Cannot start any more jobs, wait until something
                            # finishes.
                            break

            # Do cleanup for timed out jobs if necessary.
            jobs.cleanup()

            # Print progress message if something changed.
            newFinished, newFailed = len(jobs.finishedNodes), len(jobs.failedNodes)
            if (finishedCount, failedCount) != (newFinished, newFailed):
                finishedCount, failedCount = newFinished, newFailed
                totalCount = len(jobs.jobs)
                _LOG.info(
                    "Executed %d quanta successfully, %d failed and %d remain out of total %d quanta.",
                    finishedCount,
                    failedCount,
                    totalCount - finishedCount - failedCount,
                    totalCount,
                )

            # Here we want to wait until one of the running jobs completes,
            # but multiprocessing does not provide a direct API for that;
            # for now just sleep a little bit and go back to the loop.
            if jobs.running:
                time.sleep(0.1)
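                # A possible alternative to polling (not used here) would be
                # to block on process sentinels, e.g.:
                #     multiprocessing.connection.wait(
                #         [j.process.sentinel for j in jobs.running], timeout=0.1
                #     )
                # which wakes up as soon as any child process exits.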

        if jobs.failedNodes:
            # print list of failed jobs
            _LOG.error("Failed jobs:")
            for job in jobs.jobs:
                if job.state != JobState.FINISHED:
                    _LOG.error(" - %s: %s", job.state.name, job)

            # if any job failed raise an exception
            if jobs.failedNodes == jobs.timedOutNodes:
                # every failed node also timed out
                raise MPTimeoutError("One or more tasks timed out during execution.")
            else:
                raise MPGraphExecutorError("One or more tasks failed or timed out during execution.")