# This file is part of ctrl_mpexec.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

__all__ = ["MPGraphExecutor", "MPGraphExecutorError", "MPTimeoutError"]

# -------------------------------
# Imports of standard modules --
# -------------------------------
from enum import Enum
import logging
import multiprocessing
import pickle
import sys
import time

from lsst.pipe.base.graph.graph import QuantumGraph
from lsst.pipe.base import InvalidQuantumError

# -----------------------------
# Imports for other modules --
# -----------------------------
from .quantumGraphExecutor import QuantumGraphExecutor
from lsst.base import disableImplicitThreading
from lsst.daf.butler.cli.cliLog import CliLog

_LOG = logging.getLogger(__name__.partition(".")[2])
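# The partition() call above drops everything up to and including the first
# "." from the module name; e.g. a module path of
# "lsst.ctrl.mpexec.mpGraphExecutor" (a hypothetical example) would yield a
# logger named "ctrl.mpexec.mpGraphExecutor".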


# Possible states for the executing task:
# - PENDING: job has not started yet
# - RUNNING: job is currently executing
# - FINISHED: job finished successfully
# - FAILED: job execution failed (process returned non-zero status)
# - TIMED_OUT: job was killed because it exceeded the execution time limit
# - FAILED_DEP: one of the dependencies of this job has failed/timed out
JobState = Enum("JobState", "PENDING RUNNING FINISHED FAILED TIMED_OUT FAILED_DEP")
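# Note: the functional Enum API above creates one member per space-separated
# name, with integer values auto-assigned starting from 1.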


class _Job:
    """Class representing a job running a single task.

    Parameters
    ----------
    qnode : `~lsst.pipe.base.QuantumNode`
        Quantum and some associated information.
    """
    def __init__(self, qnode):
        self.qnode = qnode
        self.process = None
        self._state = JobState.PENDING
        self.started = None

    @property
    def state(self):
        """Job processing state (`JobState`)."""
        return self._state

    def start(self, butler, quantumExecutor, startMethod=None):
        """Start process which runs the task.

        Parameters
        ----------
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        quantumExecutor : `QuantumExecutor`
            Executor for single quantum.
        startMethod : `str`, optional
            Start method from `multiprocessing` module.
        """
        # Unpickling of the quantum has to happen after the butler is
        # unpickled, which is why the quantum is pickled manually here.
        quantum_pickle = pickle.dumps(self.qnode.quantum)
        taskDef = self.qnode.taskDef
        logConfigState = CliLog.configState
        mp_ctx = multiprocessing.get_context(startMethod)
        self.process = mp_ctx.Process(
            target=_Job._executeJob,
            args=(quantumExecutor, taskDef, quantum_pickle, butler, logConfigState),
            name=f"task-{self.qnode.nodeId.number}"
        )
        self.process.start()
        self.started = time.time()
        self._state = JobState.RUNNING

    @staticmethod
    def _executeJob(quantumExecutor, taskDef, quantum_pickle, butler, logConfigState):
        """Execute a job with arguments.

        Parameters
        ----------
        quantumExecutor : `QuantumExecutor`
            Executor for single quantum.
        taskDef : `~lsst.pipe.base.TaskDef`
            Task definition structure.
        quantum_pickle : `bytes`
            Quantum for this task execution in pickled form.
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        logConfigState : `list`
            Logging configuration state from the parent process, replayed to
            re-initialize logging in a newly spawned process.
        """
        if logConfigState and not CliLog.configState:
            # means that we are in a new spawned Python process and we have to
            # re-initialize logging
            CliLog.replayConfigState(logConfigState)

        # have to reset connection pool to avoid sharing connections
        if butler is not None:
            butler.registry.resetConnectionPool()

        quantum = pickle.loads(quantum_pickle)
        quantumExecutor.execute(taskDef, quantum, butler)

    def stop(self):
        """Stop the process."""
        self.process.terminate()
        # Give the process one second to finish on its own, otherwise kill it.
        for i in range(10):
            time.sleep(0.1)
            if not self.process.is_alive():
                break
        else:
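            # This else branch of the for loop runs only when the loop was not
            # interrupted by ``break``, i.e. the process is still alive after a
            # full second and has to be killed forcibly.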

            _LOG.debug("Killing process %s", self.process.name)
            self.process.kill()

    def cleanup(self):
        """Release process resources; has to be called for each finished
        process.
        """
        if self.process and not self.process.is_alive():
            self.process.close()
            self.process = None

    def __str__(self):
        return f"<{self.qnode.taskDef} dataId={self.qnode.quantum.dataId}>"


class _JobList:
    """Simple list of _Job instances with a few convenience methods.

    Parameters
    ----------
    iterable : iterable of `~lsst.pipe.base.QuantumNode`
        Sequence of Quanta to execute. This has to be ordered according to
        task dependencies.
    """
    def __init__(self, iterable):
        self.jobs = [_Job(qnode) for qnode in iterable]
        self.pending = self.jobs[:]
        self.running = []
        self.finishedNodes = set()
        self.failedNodes = set()
        self.timedOutNodes = set()

    def submit(self, job, butler, quantumExecutor, startMethod=None):
        """Submit one more job for execution.

        Parameters
        ----------
        job : `_Job`
            Job to submit.
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        quantumExecutor : `QuantumExecutor`
            Executor for single quantum.
        startMethod : `str`, optional
            Start method from `multiprocessing` module.
        """
        # this will raise if job is not in pending list
        self.pending.remove(job)
        job.start(butler, quantumExecutor, startMethod)
        self.running.append(job)

    def setJobState(self, job, state):
        """Update job state.

        Parameters
        ----------
        job : `_Job`
            Job whose state is to be updated.
        state : `JobState`
            New job state; note that only FINISHED, FAILED, TIMED_OUT, or
            FAILED_DEP states are acceptable here.
        """
        allowedStates = (
            JobState.FINISHED,
            JobState.FAILED,
            JobState.TIMED_OUT,
            JobState.FAILED_DEP
        )
        assert state in allowedStates, f"State {state} not allowed here"

        # remove job from pending/running lists
        if job.state == JobState.PENDING:
            self.pending.remove(job)
        elif job.state == JobState.RUNNING:
            self.running.remove(job)

        qnode = job.qnode
        # it should not be in any of these, but just in case
        self.finishedNodes.discard(qnode)
        self.failedNodes.discard(qnode)
        self.timedOutNodes.discard(qnode)

        job._state = state
        if state == JobState.FINISHED:
            self.finishedNodes.add(qnode)
        elif state == JobState.FAILED:
            self.failedNodes.add(qnode)
        elif state == JobState.FAILED_DEP:
            self.failedNodes.add(qnode)
        elif state == JobState.TIMED_OUT:
            self.failedNodes.add(qnode)
            self.timedOutNodes.add(qnode)
        else:
            raise ValueError(f"Unexpected state value: {state}")

    def cleanup(self):
        """Do periodic cleanup for jobs that did not finish correctly.

        If timed out jobs are killed but take too long to stop then the
        regular cleanup will not work for them. Here we check all timed out
        jobs periodically and do the cleanup if they managed to die by this
        time.
        """
        for job in self.jobs:
            if job.state == JobState.TIMED_OUT and job.process is not None:
                job.cleanup()


class MPGraphExecutorError(Exception):
    """Exception class for errors raised by MPGraphExecutor.
    """
    pass


class MPTimeoutError(MPGraphExecutorError):
    """Exception raised when task execution times out.
    """
    pass


class MPGraphExecutor(QuantumGraphExecutor):
    """Implementation of QuantumGraphExecutor using same-host multiprocess
    execution of Quanta.

    Parameters
    ----------
    numProc : `int`
        Number of processes to use for executing tasks.
    timeout : `float`
        Time in seconds to wait for tasks to finish.
    quantumExecutor : `QuantumExecutor`
        Executor for single quantum. For multiprocess-style execution when
        ``numProc`` is greater than one this instance must support pickle.
    startMethod : `str`, optional
        Start method from `multiprocessing` module; `None` selects the best
        one for the current platform.
    failFast : `bool`, optional
        If set to ``True`` then stop processing on the first error from any
        task.
    executionGraphFixup : `ExecutionGraphFixup`, optional
        Instance used for modification of the execution graph.
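
    Examples
    --------
    A minimal usage sketch; the quantum graph, butler, and the concrete
    ``QuantumExecutor`` instance (``quantumExecutor`` below) are assumed to
    have been built elsewhere::

        executor = MPGraphExecutor(numProc=4, timeout=3600.0,
                                   quantumExecutor=quantumExecutor)
        executor.execute(graph, butler)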

    """
    def __init__(self, numProc, timeout, quantumExecutor, *,
                 startMethod=None, failFast=False, executionGraphFixup=None):
        self.numProc = numProc
        self.timeout = timeout
        self.quantumExecutor = quantumExecutor
        self.failFast = failFast
        self.executionGraphFixup = executionGraphFixup

        # Default start method is "spawn" for macOS and "fork" for Linux;
        # None on all other platforms falls back to the multiprocessing
        # default.
        if startMethod is None:
            methods = dict(linux="fork", darwin="spawn")
            startMethod = methods.get(sys.platform)
        self.startMethod = startMethod

    def execute(self, graph, butler):
        # Docstring inherited from QuantumGraphExecutor.execute
        graph = self._fixupQuanta(graph)
        if self.numProc > 1:
            self._executeQuantaMP(graph, butler)
        else:
            self._executeQuantaInProcess(graph, butler)

    def _fixupQuanta(self, graph: QuantumGraph):
        """Call fixup code to modify execution graph.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` to modify.

        Returns
        -------
        graph : `QuantumGraph`
            Modified `QuantumGraph`.

        Raises
        ------
        MPGraphExecutorError
            Raised if the execution graph cannot be ordered after
            modification, i.e. it has dependency cycles.
        """
        if not self.executionGraphFixup:
            return graph

        _LOG.debug("Call execution graph fixup method")
        graph = self.executionGraphFixup.fixupQuanta(graph)

        # Detect if there is now a cycle created within the graph
        if graph.findCycle():
            raise MPGraphExecutorError(
                "Updated execution graph has dependency cycle.")

        return graph

    def _executeQuantaInProcess(self, graph, butler):
        """Execute all Quanta in the current process.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` that is to be executed.
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        """
        # Note that in the non-MP case any failed task will generate an
        # exception and kill the whole thing. In general we cannot guarantee
        # exception safety, so the easiest and safest thing is to let it die.
        count, totalCount = 0, len(graph)
        for qnode in graph:
            _LOG.debug("Executing %s", qnode)
            self.quantumExecutor.execute(qnode.taskDef, qnode.quantum, butler)
            count += 1
            _LOG.info("Executed %d quanta, %d remain out of total %d quanta.",
                      count, totalCount - count, totalCount)

    def _executeQuantaMP(self, graph, butler):
        """Execute all Quanta in separate processes.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` that is to be executed.
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        """

        disableImplicitThreading()  # To prevent thread contention

        _LOG.debug("Using %r for multiprocessing start method", self.startMethod)

        # re-pack input quantum data into jobs list
        jobs = _JobList(graph)

        # check that all tasks can run in sub-process
        for job in jobs.jobs:
            taskDef = job.qnode.taskDef
            if not taskDef.taskClass.canMultiprocess:
                raise MPGraphExecutorError(f"Task {taskDef.taskName} does not support multiprocessing;"
                                           " use single process")

        finishedCount, failedCount = 0, 0
        while jobs.pending or jobs.running:

            _LOG.debug("#pendingJobs: %s", len(jobs.pending))
            _LOG.debug("#runningJobs: %s", len(jobs.running))

            # See if any jobs have finished
            for job in jobs.running:
                if not job.process.is_alive():
                    _LOG.debug("finished: %s", job)
                    # finished
                    exitcode = job.process.exitcode
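                    # A negative exitcode means the child process was
                    # terminated by a signal (multiprocessing.Process
                    # semantics).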

                    if exitcode == 0:
                        jobs.setJobState(job, JobState.FINISHED)
                        job.cleanup()
                        _LOG.debug("success: %s took %.3f seconds", job, time.time() - job.started)
                    else:
                        jobs.setJobState(job, JobState.FAILED)
                        job.cleanup()
                        _LOG.debug("failed: %s", job)
                        if self.failFast or exitcode == InvalidQuantumError.EXIT_CODE:
                            for stopJob in jobs.running:
                                if stopJob is not job:
                                    stopJob.stop()
                            raise MPGraphExecutorError(
                                f"Task {job} failed, exit code={exitcode}."
                            )
                        else:
                            _LOG.error(
                                "Task %s failed; processing will continue for remaining tasks.", job
                            )
                else:
                    # check for timeout
                    now = time.time()
                    if now - job.started > self.timeout:
                        jobs.setJobState(job, JobState.TIMED_OUT)
                        _LOG.debug("Terminating job %s due to timeout", job)
                        job.stop()
                        job.cleanup()
                        if self.failFast:
                            raise MPTimeoutError(f"Timeout ({self.timeout} sec) for task {job}.")
                        else:
                            _LOG.error(
                                "Timeout (%s sec) for task %s; task is killed, processing continues "
                                "for remaining tasks.", self.timeout, job
                            )

            # Fail jobs whose inputs failed; this may need several iterations
            # if the ordering is not right, which will happen on subsequent
            # passes of the outer loop.
            if jobs.failedNodes:
                for job in jobs.pending:
                    jobInputNodes = graph.determineInputsToQuantumNode(job.qnode)
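                    # A non-empty set intersection means at least one direct
                    # input of this quantum has failed or timed out.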

                    if jobInputNodes & jobs.failedNodes:
                        jobs.setJobState(job, JobState.FAILED_DEP)
                        _LOG.error("Upstream job failed for task %s, skipping this task.", job)

            # see if we can start more jobs
            if len(jobs.running) < self.numProc:
                for job in jobs.pending:
                    jobInputNodes = graph.determineInputsToQuantumNode(job.qnode)
                    if jobInputNodes <= jobs.finishedNodes:
                        # all dependencies have completed, can start new job
                        if len(jobs.running) < self.numProc:
                            _LOG.debug("Submitting %s", job)
                            jobs.submit(job, butler, self.quantumExecutor, self.startMethod)
                        if len(jobs.running) >= self.numProc:
                            # Cannot start any more jobs, wait until something
                            # finishes.
                            break

            # Do cleanup for timed out jobs if necessary.
            jobs.cleanup()

            # Print progress message if something changed.
            newFinished, newFailed = len(jobs.finishedNodes), len(jobs.failedNodes)
            if (finishedCount, failedCount) != (newFinished, newFailed):
                finishedCount, failedCount = newFinished, newFailed
                totalCount = len(jobs.jobs)
                _LOG.info("Executed %d quanta successfully, %d failed and %d remain out of total %d quanta.",
                          finishedCount, failedCount, totalCount - finishedCount - failedCount, totalCount)

            # Here we want to wait until one of the running jobs completes,
            # but multiprocessing does not provide an API for that; for now
            # just sleep a little bit and go back to the loop.
            if jobs.running:
                time.sleep(0.1)

        if jobs.failedNodes:
            # print list of failed jobs
            _LOG.error("Failed jobs:")
            for job in jobs.jobs:
                if job.state != JobState.FINISHED:
                    _LOG.error(" - %s: %s", job.state.name, job)

            # if any job failed raise an exception
            if jobs.failedNodes == jobs.timedOutNodes:
                raise MPTimeoutError("One or more tasks timed out during execution.")
            else:
                raise MPGraphExecutorError("One or more tasks failed or timed out during execution.")