Coverage for python/lsst/ctrl/mpexec/mpGraphExecutor.py: 14% (199 statements)

# This file is part of ctrl_mpexec.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

__all__ = ["MPGraphExecutor", "MPGraphExecutorError", "MPTimeoutError"]

# -------------------------------
# Imports of standard modules --
# -------------------------------
from enum import Enum
import gc
import logging
import multiprocessing
import pickle
import sys
import time

from lsst.pipe.base.graph.graph import QuantumGraph
from lsst.pipe.base import InvalidQuantumError

# -----------------------------
# Imports for other modules --
# -----------------------------
from .quantumGraphExecutor import QuantumGraphExecutor
from lsst.base import disableImplicitThreading
from lsst.daf.butler.cli.cliLog import CliLog

_LOG = logging.getLogger(__name__)

# Possible states for the executing task:
# - PENDING: job has not started yet
# - RUNNING: job is currently executing
# - FINISHED: job finished successfully
# - FAILED: job execution failed (process returned non-zero status)
# - TIMED_OUT: job was killed because it exceeded the execution timeout
# - FAILED_DEP: one of the dependencies of this job has failed/timed out
JobState = Enum("JobState", "PENDING RUNNING FINISHED FAILED TIMED_OUT FAILED_DEP")
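
# Typical state transitions, inferred from the scheduling loop below
# (an illustrative summary, not an exhaustive specification):
#
#   PENDING -> RUNNING     process started via _Job.start()
#   RUNNING -> FINISHED    process exited with code 0
#   RUNNING -> FAILED      process exited with a non-zero code
#   RUNNING -> TIMED_OUT   process killed after exceeding the timeout
#   PENDING -> FAILED_DEP  an upstream quantum failed or timed out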


class _Job:
    """Class representing a job running a single task.

    Parameters
    ----------
    qnode : `~lsst.pipe.base.QuantumNode`
        Quantum and some associated information.
    """
    def __init__(self, qnode):
        self.qnode = qnode
        self.process = None
        self._state = JobState.PENDING
        self.started = None

    @property
    def state(self):
        """Job processing state (`JobState`)"""
        return self._state

    def start(self, butler, quantumExecutor, startMethod=None):
        """Start process which runs the task.

        Parameters
        ----------
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        quantumExecutor : `QuantumExecutor`
            Executor for single quantum.
        startMethod : `str`, optional
            Start method from `multiprocessing` module.
        """
        # Unpickling of the quantum has to happen after the butler is
        # unpickled, which is why the quantum is pickled manually here.
        quantum_pickle = pickle.dumps(self.qnode.quantum)
        taskDef = self.qnode.taskDef
        logConfigState = CliLog.configState
        mp_ctx = multiprocessing.get_context(startMethod)
        self.process = mp_ctx.Process(
            target=_Job._executeJob,
            args=(quantumExecutor, taskDef, quantum_pickle, butler, logConfigState),
            name=f"task-{self.qnode.nodeId.number}"
        )
        self.process.start()
        self.started = time.time()
        self._state = JobState.RUNNING

    @staticmethod
    def _executeJob(quantumExecutor, taskDef, quantum_pickle, butler, logConfigState):
        """Execute a job with arguments.

        Parameters
        ----------
        quantumExecutor : `QuantumExecutor`
            Executor for single quantum.
        taskDef : `~lsst.pipe.base.TaskDef`
            Task definition structure.
        quantum_pickle : `bytes`
            Quantum for this task execution in pickled form.
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        logConfigState : `list`
            Logging configuration state captured in the parent process;
            replayed when running in a freshly spawned process.
        """
        if logConfigState and not CliLog.configState:
            # This means that we are in a new spawned Python process and we
            # have to re-initialize logging.
            CliLog.replayConfigState(logConfigState)

        # Have to reset connection pool to avoid sharing connections with
        # the parent process.
        if butler is not None:
            butler.registry.resetConnectionPool()

        quantum = pickle.loads(quantum_pickle)
        quantumExecutor.execute(taskDef, quantum, butler)

    def stop(self):
        """Stop the process.
        """
        self.process.terminate()
        # Give it one second to finish gracefully; if it is still alive
        # after that, kill it (the ``else`` clause runs only when the loop
        # completes without hitting ``break``).
        for i in range(10):
            time.sleep(0.1)
            if not self.process.is_alive():
                break
        else:
            _LOG.debug("Killing process %s", self.process.name)
            self.process.kill()

    def cleanup(self):
        """Release process resources; has to be called for each finished
        process.
        """
        if self.process and not self.process.is_alive():
            self.process.close()
            self.process = None

    def __str__(self):
        return f"<{self.qnode.taskDef} dataId={self.qnode.quantum.dataId}>"
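
# A minimal lifecycle sketch for _Job (illustrative only; in practice the
# scheduling is driven by MPGraphExecutor._executeQuantaMP below):
#
#     job = _Job(qnode)
#     job.start(butler, quantumExecutor)   # spawns the worker process
#     while job.process.is_alive():
#         time.sleep(0.1)                  # poll, or stop() on timeout
#     job.cleanup()                        # release process resources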


class _JobList:
    """Simple list of `_Job` instances with a few convenience methods.

    Parameters
    ----------
    iterable : iterable of `~lsst.pipe.base.QuantumNode`
        Sequence of quantum nodes to execute. This has to be ordered
        according to task dependencies.
    """
    def __init__(self, iterable):
        self.jobs = [_Job(qnode) for qnode in iterable]
        self.pending = self.jobs[:]
        self.running = []
        self.finishedNodes = set()
        self.failedNodes = set()
        self.timedOutNodes = set()

    def submit(self, job, butler, quantumExecutor, startMethod=None):
        """Submit one more job for execution.

        Parameters
        ----------
        job : `_Job`
            Job to submit.
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        quantumExecutor : `QuantumExecutor`
            Executor for single quantum.
        startMethod : `str`, optional
            Start method from `multiprocessing` module.
        """
        # This will raise if the job is not in the pending list.
        self.pending.remove(job)
        job.start(butler, quantumExecutor, startMethod)
        self.running.append(job)

    def setJobState(self, job, state):
        """Update job state.

        Parameters
        ----------
        job : `_Job`
            Job whose state is updated.
        state : `JobState`
            New job state; note that only FINISHED, FAILED, TIMED_OUT, or
            FAILED_DEP state is acceptable.
        """
        allowedStates = (
            JobState.FINISHED,
            JobState.FAILED,
            JobState.TIMED_OUT,
            JobState.FAILED_DEP
        )
        assert state in allowedStates, f"State {state} not allowed here"

        # Remove job from pending/running lists.
        if job.state == JobState.PENDING:
            self.pending.remove(job)
        elif job.state == JobState.RUNNING:
            self.running.remove(job)

        qnode = job.qnode
        # It should not be in any of these, but just in case.
        self.finishedNodes.discard(qnode)
        self.failedNodes.discard(qnode)
        self.timedOutNodes.discard(qnode)

        job._state = state
        if state == JobState.FINISHED:
            self.finishedNodes.add(qnode)
        elif state == JobState.FAILED:
            self.failedNodes.add(qnode)
        elif state == JobState.FAILED_DEP:
            self.failedNodes.add(qnode)
        elif state == JobState.TIMED_OUT:
            # Timed-out nodes count as failed but are also tracked
            # separately so that timeouts can be reported distinctly.
            self.failedNodes.add(qnode)
            self.timedOutNodes.add(qnode)
        else:
            raise ValueError(f"Unexpected state value: {state}")

    def cleanup(self):
        """Do periodic cleanup for jobs that did not finish correctly.

        If timed-out jobs are killed but take too long to stop then regular
        cleanup will not work for them. Here we check all timed-out jobs
        periodically and do cleanup if they managed to die by this time.
        """
        for job in self.jobs:
            if job.state == JobState.TIMED_OUT and job.process is not None:
                job.cleanup()


class MPGraphExecutorError(Exception):
    """Exception class for errors raised by MPGraphExecutor.
    """
    pass


class MPTimeoutError(MPGraphExecutorError):
    """Exception raised when task execution times out.
    """
    pass
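
# A sketch of how a caller might distinguish the two errors (illustrative;
# per _executeQuantaMP below, MPTimeoutError is raised only when every
# failure was a timeout):
#
#     try:
#         executor.execute(graph, butler)
#     except MPTimeoutError:
#         ...  # all failed quanta timed out
#     except MPGraphExecutorError:
#         ...  # at least one quantum failed for another reason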


class MPGraphExecutor(QuantumGraphExecutor):
    """Implementation of QuantumGraphExecutor using same-host multiprocess
    execution of Quanta.

    Parameters
    ----------
    numProc : `int`
        Number of processes to use for executing tasks.
    timeout : `float`
        Time in seconds to wait for tasks to finish.
    quantumExecutor : `QuantumExecutor`
        Executor for single quantum. For multiprocess-style execution when
        ``numProc`` is greater than one this instance must support pickle.
    startMethod : `str`, optional
        Start method from `multiprocessing` module, `None` selects the best
        one for the current platform.
    failFast : `bool`, optional
        If set to ``True`` then stop processing on first error from any task.
    executionGraphFixup : `ExecutionGraphFixup`, optional
        Instance used for modification of execution graph.
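
    Examples
    --------
    A minimal usage sketch, assuming a concrete `QuantumExecutor`
    implementation, a `QuantumGraph`, and a butler are already in hand
    (the names below are placeholders, not part of this API)::

        executor = MPGraphExecutor(numProc=4, timeout=3600,
                                   quantumExecutor=myQuantumExecutor)
        executor.execute(quantumGraph, butler)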

280 """ 

281 def __init__(self, numProc, timeout, quantumExecutor, *, 

282 startMethod=None, failFast=False, executionGraphFixup=None): 

283 self.numProc = numProc 

284 self.timeout = timeout 

285 self.quantumExecutor = quantumExecutor 

286 self.failFast = failFast 

287 self.executionGraphFixup = executionGraphFixup 

288 

289 # We set default start method as spawn for MacOS and fork for Linux; 

290 # None for all other platforms to use multiprocessing default. 

291 if startMethod is None: 

292 methods = dict(linux="fork", darwin="spawn") 

293 startMethod = methods.get(sys.platform) 

294 self.startMethod = startMethod 

    def execute(self, graph, butler):
        # Docstring inherited from QuantumGraphExecutor.execute
        graph = self._fixupQuanta(graph)
        if self.numProc > 1:
            self._executeQuantaMP(graph, butler)
        else:
            self._executeQuantaInProcess(graph, butler)

    def _fixupQuanta(self, graph: QuantumGraph):
        """Call fixup code to modify execution graph.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` to modify.

        Returns
        -------
        graph : `QuantumGraph`
            Modified `QuantumGraph`.

        Raises
        ------
        MPGraphExecutorError
            Raised if execution graph cannot be ordered after modification,
            i.e. it has dependency cycles.
        """
        if not self.executionGraphFixup:
            return graph

        _LOG.debug("Call execution graph fixup method")
        graph = self.executionGraphFixup.fixupQuanta(graph)

        # Detect if there is now a cycle created within the graph.
        if graph.findCycle():
            raise MPGraphExecutorError(
                "Updated execution graph has dependency cycle.")

        return graph

    def _executeQuantaInProcess(self, graph, butler):
        """Execute all Quanta in current process.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` that is to be executed.
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        """
        # Note that in the non-MP case any failed task will generate an
        # exception and kill the whole thing. In general we cannot guarantee
        # exception safety, so the easiest and safest thing is to let it die.
        count, totalCount = 0, len(graph)
        for qnode in graph:
            _LOG.debug("Executing %s", qnode)
            try:
                self.quantumExecutor.execute(qnode.taskDef, qnode.quantum, butler)
            finally:
                # SQLAlchemy has some objects that can last until a garbage
                # collection cycle is run, which can happen at unpredictable
                # times; run a collection loop here explicitly.
                gc.collect()
            count += 1
            _LOG.info("Executed %d quanta, %d remain out of total %d quanta.",
                      count, totalCount - count, totalCount)

    def _executeQuantaMP(self, graph, butler):
        """Execute all Quanta in separate processes.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` that is to be executed.
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        """
        disableImplicitThreading()  # To prevent thread contention

        _LOG.debug("Using %r for multiprocessing start method", self.startMethod)

        # Re-pack input quantum data into jobs list.
        jobs = _JobList(graph)

        # Check that all tasks can run in a sub-process.
        for job in jobs.jobs:
            taskDef = job.qnode.taskDef
            if not taskDef.taskClass.canMultiprocess:
                raise MPGraphExecutorError(f"Task {taskDef.taskName} does not support multiprocessing;"
                                           " use single process")

        finishedCount, failedCount = 0, 0
        while jobs.pending or jobs.running:

            _LOG.debug("#pendingJobs: %s", len(jobs.pending))
            _LOG.debug("#runningJobs: %s", len(jobs.running))

            # See if any jobs have finished.
            for job in jobs.running:
                if not job.process.is_alive():
                    _LOG.debug("finished: %s", job)
                    # finished
                    exitcode = job.process.exitcode
                    if exitcode == 0:
                        jobs.setJobState(job, JobState.FINISHED)
                        job.cleanup()
                        _LOG.debug("success: %s took %.3f seconds", job, time.time() - job.started)
                    else:
                        jobs.setJobState(job, JobState.FAILED)
                        job.cleanup()
                        _LOG.debug("failed: %s", job)
                        if self.failFast or exitcode == InvalidQuantumError.EXIT_CODE:
                            # Stop all other running jobs before raising.
                            for stopJob in jobs.running:
                                if stopJob is not job:
                                    stopJob.stop()
                            raise MPGraphExecutorError(
                                f"Task {job} failed, exit code={exitcode}."
                            )
                        else:
                            _LOG.error(
                                "Task %s failed; processing will continue for remaining tasks.", job
                            )
                else:
                    # Check for timeout.
                    now = time.time()
                    if now - job.started > self.timeout:
                        jobs.setJobState(job, JobState.TIMED_OUT)
                        _LOG.debug("Terminating job %s due to timeout", job)
                        job.stop()
                        job.cleanup()
                        if self.failFast:
                            raise MPTimeoutError(f"Timeout ({self.timeout} sec) for task {job}.")
                        else:
                            _LOG.error(
                                "Timeout (%s sec) for task %s; task is killed, processing continues "
                                "for remaining tasks.", self.timeout, job
                            )

            # Fail jobs whose inputs failed. This may need several iterations
            # if the ordering is not right; remaining cases are handled on
            # the next pass through the loop.
            if jobs.failedNodes:
                for job in jobs.pending:
                    jobInputNodes = graph.determineInputsToQuantumNode(job.qnode)
                    if jobInputNodes & jobs.failedNodes:
                        jobs.setJobState(job, JobState.FAILED_DEP)
                        _LOG.error("Upstream job failed for task %s, skipping this task.", job)

            # See if we can start more jobs.
            if len(jobs.running) < self.numProc:
                for job in jobs.pending:
                    jobInputNodes = graph.determineInputsToQuantumNode(job.qnode)
                    if jobInputNodes <= jobs.finishedNodes:
                        # All dependencies have completed, can start new job.
                        if len(jobs.running) < self.numProc:
                            _LOG.debug("Submitting %s", job)
                            jobs.submit(job, butler, self.quantumExecutor, self.startMethod)
                        if len(jobs.running) >= self.numProc:
                            # Cannot start any more jobs, wait until
                            # something finishes.
                            break

            # Do cleanup for timed-out jobs if necessary.
            jobs.cleanup()

            # Print progress message if something changed.
            newFinished, newFailed = len(jobs.finishedNodes), len(jobs.failedNodes)
            if (finishedCount, failedCount) != (newFinished, newFailed):
                finishedCount, failedCount = newFinished, newFailed
                totalCount = len(jobs.jobs)
                _LOG.info("Executed %d quanta successfully, %d failed and %d remain out of total %d quanta.",
                          finishedCount, failedCount, totalCount - finishedCount - failedCount, totalCount)

            # Here we want to wait until one of the running jobs completes
            # but multiprocessing does not provide an API for that; for now
            # just sleep a little bit and go back to the loop.
            if jobs.running:
                time.sleep(0.1)

        if jobs.failedNodes:
            # Print the list of failed jobs.
            _LOG.error("Failed jobs:")
            for job in jobs.jobs:
                if job.state != JobState.FINISHED:
                    _LOG.error(" - %s: %s", job.state.name, job)

            # If any job failed raise an exception; raise the more specific
            # MPTimeoutError only when every failure was a timeout.
            if jobs.failedNodes == jobs.timedOutNodes:
                raise MPTimeoutError("One or more tasks timed out during execution.")
            else:
                raise MPGraphExecutorError("One or more tasks failed or timed out during execution.")