# This file is part of ctrl_mpexec.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

__all__ = ["MPGraphExecutor", "MPGraphExecutorError", "MPTimeoutError"]

# -------------------------------
# Imports of standard modules --
# -------------------------------
from enum import Enum
import logging
import multiprocessing
import pickle
import sys
import time

from lsst.pipe.base.graph.graph import QuantumGraph

# -----------------------------
# Imports for other modules --
# -----------------------------
from .quantumGraphExecutor import QuantumGraphExecutor
from lsst.base import disableImplicitThreading
from lsst.daf.butler.cli.cliLog import CliLog

# Logger name drops everything up to and including the first "." of
# __name__ (the top-level package prefix).
_LOG = logging.getLogger(__name__.partition(".")[2])


# Possible states for the executing task:
# - PENDING: job has not started yet
# - RUNNING: job is currently executing
# - FINISHED: job finished successfully
# - FAILED: job execution failed (process returned non-zero status)
# - TIMED_OUT: job was killed because it ran for too long
# - FAILED_DEP: one of this job's dependencies failed or timed out
JobState = Enum("JobState", "PENDING RUNNING FINISHED FAILED TIMED_OUT FAILED_DEP")
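# Of these, only the terminal states (FINISHED, FAILED, TIMED_OUT, and
# FAILED_DEP) are assigned through _JobList.setJobState below; PENDING and
# RUNNING are managed by _Job itself.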


class _Job:
    """Class representing a job running a single task.

    Parameters
    ----------
    qnode : `~lsst.pipe.base.QuantumNode`
        Quantum and some associated information.
    """
    def __init__(self, qnode):
        self.qnode = qnode
        self.process = None
        self._state = JobState.PENDING
        self.started = None

    @property
    def state(self):
        """Job processing state (`JobState`)."""
        return self._state

    def start(self, butler, quantumExecutor, startMethod=None):
        """Start a process which runs the task.

        Parameters
        ----------
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        quantumExecutor : `QuantumExecutor`
            Executor for a single quantum.
        startMethod : `str`, optional
            Start method from the `multiprocessing` module.
        """
        # The butler can have live database connections, which is a problem
        # with fork-type activation. Pickle the butler here to pass it
        # across the fork. Unpickling of the quantum has to happen after the
        # butler, which is why it is pickled manually as well.
        butler_pickle = pickle.dumps(butler)
        quantum_pickle = pickle.dumps(self.qnode.quantum)
        taskDef = self.qnode.taskDef
        logConfigState = CliLog.configState
        mp_ctx = multiprocessing.get_context(startMethod)
        self.process = mp_ctx.Process(
            target=_Job._executeJob,
            args=(quantumExecutor, taskDef, quantum_pickle, butler_pickle, logConfigState),
            name=f"task-{self.qnode.nodeId.number}"
        )
        self.process.start()
        self.started = time.time()
        self._state = JobState.RUNNING

    @staticmethod
    def _executeJob(quantumExecutor, taskDef, quantum_pickle, butler_pickle, logConfigState):
        """Execute a job with arguments; this runs in the child process.

        Parameters
        ----------
        quantumExecutor : `QuantumExecutor`
            Executor for a single quantum.
        taskDef : `~lsst.pipe.base.TaskDef`
            Task definition structure.
        quantum_pickle : `bytes`
            Quantum for this task execution in pickled form.
        butler_pickle : `bytes`
            Data butler instance in pickled form.
        logConfigState : `list`
            Log configuration state captured in the parent process.
        """
        if logConfigState and not CliLog.configState:
            # this means we are in a new spawned Python process and have to
            # re-initialize logging
            CliLog.replayConfigState(logConfigState)

        butler = pickle.loads(butler_pickle)
        quantum = pickle.loads(quantum_pickle)
        quantumExecutor.execute(taskDef, quantum, butler)

    def stop(self):
        """Stop the process."""
        self.process.terminate()
        # Give the process up to one second to exit; the for-else runs only
        # if the loop never breaks, i.e. the process is still alive, in
        # which case it is killed outright.
        for _ in range(10):
            time.sleep(0.1)
            if not self.process.is_alive():
                break
        else:
            _LOG.debug("Killing process %s", self.process.name)
            self.process.kill()

    def cleanup(self):
        """Release process resources; has to be called for each finished
        process.
        """
        if self.process and not self.process.is_alive():
            self.process.close()
            self.process = None

    def __str__(self):
        return f"<{self.qnode.taskDef} dataId={self.qnode.quantum.dataId}>"


class _JobList:
    """Simple list of `_Job` instances with a few convenience methods.

    Parameters
    ----------
    iterable : iterable of `~lsst.pipe.base.QuantumNode`
        Sequence of Quanta to execute. This has to be ordered according to
        task dependencies.
    """
    def __init__(self, iterable):
        self.jobs = [_Job(qnode) for qnode in iterable]
        self.pending = self.jobs[:]
        self.running = []
        self.finishedNodes = set()
        self.failedNodes = set()
        self.timedOutNodes = set()
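        # These node sets drive scheduling in MPGraphExecutor: finishedNodes
        # gates submission of dependent jobs, failedNodes propagates
        # FAILED_DEP to descendants, and timedOutNodes lets the final error
        # report distinguish pure timeouts from other failures.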

    def submit(self, job, butler, quantumExecutor, startMethod=None):
        """Submit one more job for execution.

        Parameters
        ----------
        job : `_Job`
            Job to submit.
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        quantumExecutor : `QuantumExecutor`
            Executor for a single quantum.
        startMethod : `str`, optional
            Start method from the `multiprocessing` module.
        """
        # this will raise if the job is not in the pending list
        self.pending.remove(job)
        job.start(butler, quantumExecutor, startMethod)
        self.running.append(job)

    def setJobState(self, job, state):
        """Update job state.

        Parameters
        ----------
        job : `_Job`
            Job whose state is to be updated.
        state : `JobState`
            New job state; note that only FINISHED, FAILED, TIMED_OUT, or
            FAILED_DEP is acceptable.
        """
        allowedStates = (
            JobState.FINISHED,
            JobState.FAILED,
            JobState.TIMED_OUT,
            JobState.FAILED_DEP
        )
        assert state in allowedStates, f"State {state} not allowed here"

        # remove the job from the pending/running lists
        if job.state == JobState.PENDING:
            self.pending.remove(job)
        elif job.state == JobState.RUNNING:
            self.running.remove(job)

        qnode = job.qnode
        # it should not be in any of these, but just in case
        self.finishedNodes.discard(qnode)
        self.failedNodes.discard(qnode)
        self.timedOutNodes.discard(qnode)

        job._state = state
        if state == JobState.FINISHED:
            self.finishedNodes.add(qnode)
        elif state == JobState.FAILED:
            self.failedNodes.add(qnode)
        elif state == JobState.FAILED_DEP:
            self.failedNodes.add(qnode)
        elif state == JobState.TIMED_OUT:
            # timed-out jobs count as failed but are also tracked separately
            self.failedNodes.add(qnode)
            self.timedOutNodes.add(qnode)
        else:
            raise ValueError(f"Unexpected state value: {state}")

    def cleanup(self):
        """Do periodic cleanup for jobs that did not finish correctly.

        If timed-out jobs are killed but take too long to stop, the regular
        cleanup does not work for them. Here we check all timed-out jobs
        periodically and clean up those that have managed to die by this
        time.
        """
        for job in self.jobs:
            if job.state == JobState.TIMED_OUT and job.process is not None:
                job.cleanup()


class MPGraphExecutorError(Exception):
    """Exception class for errors raised by MPGraphExecutor."""
    pass


class MPTimeoutError(MPGraphExecutorError):
    """Exception raised when task execution times out."""
    pass


class MPGraphExecutor(QuantumGraphExecutor):
    """Implementation of QuantumGraphExecutor using same-host multiprocess
    execution of Quanta.

    Parameters
    ----------
    numProc : `int`
        Number of processes to use for executing tasks.
    timeout : `float`
        Time in seconds to wait for tasks to finish.
    quantumExecutor : `QuantumExecutor`
        Executor for a single quantum. For multiprocess-style execution when
        ``numProc`` is greater than one this instance must support pickle.
    startMethod : `str`, optional
        Start method from the `multiprocessing` module; `None` selects the
        best one for the current platform.
    failFast : `bool`, optional
        If set to ``True`` then stop processing on the first error from any
        task.
    executionGraphFixup : `ExecutionGraphFixup`, optional
        Instance used for modification of the execution graph.
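
    Examples
    --------
    A minimal usage sketch, assuming a pre-built `QuantumGraph`, a data
    butler, and a pickleable `QuantumExecutor` instance (the variable names
    here are illustrative)::

        executor = MPGraphExecutor(numProc=4, timeout=3600.0,
                                   quantumExecutor=quantumExecutor)
        executor.execute(graph, butler)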

278 """ 

279 def __init__(self, numProc, timeout, quantumExecutor, *, 

280 startMethod=None, failFast=False, executionGraphFixup=None): 

281 self.numProc = numProc 

282 self.timeout = timeout 

283 self.quantumExecutor = quantumExecutor 

284 self.failFast = failFast 

285 self.executionGraphFixup = executionGraphFixup 

286 

287 # We set default start method as spawn for MacOS and fork for Linux; 

288 # None for all other platforms to use multiprocessing default. 

289 if startMethod is None: 

290 methods = dict(linux="fork", darwin="spawn") 

291 startMethod = methods.get(sys.platform) 

292 self.startMethod = startMethod 
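        # A rough trade-off, based on the notes elsewhere in this module:
        # "fork" inherits parent state cheaply but can be unsafe with live
        # resources (hence the butler pickling in _Job.start), while "spawn"
        # starts a fresh interpreter (hence the CliLog replay in
        # _Job._executeJob).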

    def execute(self, graph, butler):
        # Docstring inherited from QuantumGraphExecutor.execute
        graph = self._fixupQuanta(graph)
        if self.numProc > 1:
            self._executeQuantaMP(graph, butler)
        else:
            self._executeQuantaInProcess(graph, butler)

    def _fixupQuanta(self, graph: QuantumGraph):
        """Call fixup code to modify the execution graph.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` to modify.

        Returns
        -------
        graph : `QuantumGraph`
            Modified `QuantumGraph`.

        Raises
        ------
        MPGraphExecutorError
            Raised if the execution graph cannot be ordered after
            modification, i.e. it has dependency cycles.
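
        Notes
        -----
        A fixup object only needs to provide a ``fixupQuanta(graph)`` method
        that returns a (possibly modified) `QuantumGraph`; a minimal no-op
        sketch, assuming nothing beyond that interface::

            class NoOpFixup:
                def fixupQuanta(self, graph):
                    return graph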

320 """ 

321 if not self.executionGraphFixup: 

322 return graph 

323 

324 _LOG.debug("Call execution graph fixup method") 

325 graph = self.executionGraphFixup.fixupQuanta(graph) 

326 

327 # Detect if there is now a cycle created within the graph 

328 if graph.findCycle(): 

329 raise MPGraphExecutorError( 

330 "Updated execution graph has dependency cycle.") 

331 

332 return graph 

    def _executeQuantaInProcess(self, graph, butler):
        """Execute all Quanta in the current process.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` that is to be executed.
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        """
        # Note that in the non-MP case any failed task raises an exception
        # and kills the whole thing. In general we cannot guarantee
        # exception safety, so the easiest and safest option is to let it
        # die.
        count, totalCount = 0, len(graph)
        for qnode in graph:
            _LOG.debug("Executing %s", qnode)
            self.quantumExecutor.execute(qnode.taskDef, qnode.quantum, butler)
            count += 1
            _LOG.info("Executed %d quanta, %d remain out of total %d quanta.",
                      count, totalCount - count, totalCount)

    def _executeQuantaMP(self, graph, butler):
        """Execute all Quanta in separate processes.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` that is to be executed.
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        """

        disableImplicitThreading()  # To prevent thread contention

        _LOG.debug("Using %r for multiprocessing start method", self.startMethod)

        # re-pack input quantum data into a jobs list
        jobs = _JobList(graph)

        # check that all tasks can run in a sub-process
        for job in jobs.jobs:
            taskDef = job.qnode.taskDef
            if not taskDef.taskClass.canMultiprocess:
                raise MPGraphExecutorError(f"Task {taskDef.taskName} does not support multiprocessing;"
                                           " use single process")

        finishedCount, failedCount = 0, 0
        while jobs.pending or jobs.running:

            _LOG.debug("#pendingJobs: %s", len(jobs.pending))
            _LOG.debug("#runningJobs: %s", len(jobs.running))

            # See if any jobs have finished; iterate over a copy because
            # setJobState removes entries from jobs.running.
            for job in list(jobs.running):
                if not job.process.is_alive():
                    _LOG.debug("finished: %s", job)
                    # finished
                    exitcode = job.process.exitcode
                    if exitcode == 0:
                        jobs.setJobState(job, JobState.FINISHED)
                        job.cleanup()
                        _LOG.debug("success: %s took %.3f seconds", job, time.time() - job.started)
                    else:
                        jobs.setJobState(job, JobState.FAILED)
                        job.cleanup()
                        _LOG.debug("failed: %s", job)
                        if self.failFast:
                            for stopJob in jobs.running:
                                if stopJob is not job:
                                    stopJob.stop()
                            raise MPGraphExecutorError(
                                f"Task {job} failed, exit code={exitcode}."
                            )
                        else:
                            _LOG.error(
                                "Task %s failed; processing will continue for remaining tasks.", job
                            )
                else:
                    # check for timeout
                    now = time.time()
                    if now - job.started > self.timeout:
                        jobs.setJobState(job, JobState.TIMED_OUT)
                        _LOG.debug("Terminating job %s due to timeout", job)
                        job.stop()
                        job.cleanup()
                        if self.failFast:
                            raise MPTimeoutError(f"Timeout ({self.timeout} sec) for task {job}.")
                        else:
                            _LOG.error(
                                "Timeout (%s sec) for task %s; task is killed, processing continues "
                                "for remaining tasks.", self.timeout, job
                            )

            # Fail jobs whose inputs failed; this may need several passes if
            # the ordering is not right, which is handled by later iterations
            # of the enclosing loop. Iterate over a copy because setJobState
            # removes entries from jobs.pending.
            if jobs.failedNodes:
                for job in list(jobs.pending):
                    jobInputNodes = graph.determineInputsToQuantumNode(job.qnode)
                    if jobInputNodes & jobs.failedNodes:
                        jobs.setJobState(job, JobState.FAILED_DEP)
                        _LOG.error("Upstream job failed for task %s, skipping this task.", job)

            # see if we can start more jobs; iterate over a copy because
            # submit removes entries from jobs.pending
            if len(jobs.running) < self.numProc:
                for job in list(jobs.pending):
                    jobInputNodes = graph.determineInputsToQuantumNode(job.qnode)
                    if jobInputNodes <= jobs.finishedNodes:
                        # all dependencies have completed, can start new job
                        if len(jobs.running) < self.numProc:
                            _LOG.debug("Submitting %s", job)
                            jobs.submit(job, butler, self.quantumExecutor, self.startMethod)
                        if len(jobs.running) >= self.numProc:
                            # cannot start any more jobs, wait until something finishes
                            break

            # Do cleanup for timed out jobs if necessary.
            jobs.cleanup()

            # Print a progress message if something changed.
            newFinished, newFailed = len(jobs.finishedNodes), len(jobs.failedNodes)
            if (finishedCount, failedCount) != (newFinished, newFailed):
                finishedCount, failedCount = newFinished, newFailed
                totalCount = len(jobs.jobs)
                _LOG.info("Executed %d quanta successfully, %d failed and %d remain out of total %d quanta.",
                          finishedCount, failedCount, totalCount - finishedCount - failedCount, totalCount)

            # Here we want to wait until one of the running jobs completes,
            # but multiprocessing does not provide an API for that; for now
            # just sleep a little bit and go back around the loop.
            if jobs.running:
                time.sleep(0.1)
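            # (An alternative, not used here, would be to block on the
            # children's Process.sentinel handles via
            # multiprocessing.connection.wait, which returns as soon as any
            # child exits.)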

        if jobs.failedNodes:
            # print the list of failed jobs
            _LOG.error("Failed jobs:")
            for job in jobs.jobs:
                if job.state != JobState.FINISHED:
                    _LOG.error(" - %s: %s", job.state.name, job)

            # if any job failed raise an exception
            if jobs.failedNodes == jobs.timedOutNodes:
                raise MPTimeoutError("One or more tasks timed out during execution.")
            else:
                raise MPGraphExecutorError("One or more tasks failed or timed out during execution.")