Coverage for python/lsst/ctrl/mpexec/mpGraphExecutor.py: 14%
287 statements
coverage.py v6.4, created at 2022-05-26 09:47 +0000
# This file is part of ctrl_mpexec.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

__all__ = ["MPGraphExecutor", "MPGraphExecutorError", "MPTimeoutError"]

import gc
import importlib
import logging
import multiprocessing
import pickle
import signal
import sys
import time
from enum import Enum
from typing import Optional

from lsst.daf.butler.cli.cliLog import CliLog
from lsst.pipe.base import InvalidQuantumError
from lsst.pipe.base.graph.graph import QuantumGraph
from lsst.utils.threads import disable_implicit_threading

# -----------------------------
# Imports for other modules --
# -----------------------------
from .quantumGraphExecutor import QuantumGraphExecutor
from .reports import ExecutionStatus, QuantumReport, Report

_LOG = logging.getLogger(__name__)


# Possible states for the executing task:
# - PENDING: job has not started yet
# - RUNNING: job is currently executing
# - FINISHED: job finished successfully
# - FAILED: job execution failed (process returned non-zero status)
# - TIMED_OUT: job is killed due to too long execution time
# - FAILED_DEP: one of the dependencies of this job has failed/timed out
JobState = Enum("JobState", "PENDING RUNNING FINISHED FAILED TIMED_OUT FAILED_DEP")
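
# Illustrative note (not from the original module): the functional `Enum` call
# above creates the members listed in the comment, which later code compares
# by identity/equality and prints by name, e.g.:
#
#     state = JobState.PENDING
#     assert state is JobState.PENDING and state.name == "PENDING"
#     terminal_states = {JobState.FINISHED, JobState.FAILED, JobState.TIMED_OUT, JobState.FAILED_DEP}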


class _Job:
    """Class representing a job running a single task.

    Parameters
    ----------
    qnode : `~lsst.pipe.base.QuantumNode`
        Quantum and some associated information.
    """

    def __init__(self, qnode):
        self.qnode = qnode
        self.process = None
        self._state = JobState.PENDING
        self.started = None
        self._rcv_conn = None
        self._terminated = False

    @property
    def state(self):
        """Job processing state (`JobState`)."""
        return self._state

    @property
    def terminated(self):
        """Return `True` if the job was killed by the `stop()` method and a
        negative exit code was returned from the child process (`bool`)."""
        return self._terminated and self.process.exitcode < 0

    def start(self, butler, quantumExecutor, startMethod=None):
        """Start the process which runs the task.

        Parameters
        ----------
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        quantumExecutor : `QuantumExecutor`
            Executor for single quantum.
        startMethod : `str`, optional
            Start method from `multiprocessing` module.
        """
        # Unpickling of the quantum has to happen after the butler, which is
        # why it is pickled manually here.
        quantum_pickle = pickle.dumps(self.qnode.quantum)
        taskDef = self.qnode.taskDef
        self._rcv_conn, snd_conn = multiprocessing.Pipe(False)
        logConfigState = CliLog.configState
        mp_ctx = multiprocessing.get_context(startMethod)
        self.process = mp_ctx.Process(
            target=_Job._executeJob,
            args=(quantumExecutor, taskDef, quantum_pickle, butler, logConfigState, snd_conn),
            name=f"task-{self.qnode.quantum.dataId}",
        )
        self.process.start()
        self.started = time.time()
        self._state = JobState.RUNNING

    @staticmethod
    def _executeJob(quantumExecutor, taskDef, quantum_pickle, butler, logConfigState, snd_conn):
        """Execute a job with arguments.

        Parameters
        ----------
        quantumExecutor : `QuantumExecutor`
            Executor for single quantum.
        taskDef : `~lsst.pipe.base.TaskDef`
            Task definition structure.
        quantum_pickle : `bytes`
            Quantum for this task execution in pickled form.
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        logConfigState : `list`
            Logging configuration state of the parent process
            (``CliLog.configState``), used to re-initialize logging in a
            spawned subprocess.
        snd_conn : `multiprocessing.Connection`
            Connection used to send the job report to the parent process.
        """
        if logConfigState and not CliLog.configState:
            # This means we are in a newly spawned Python process and have to
            # re-initialize logging.
            CliLog.replayConfigState(logConfigState)

        # Have to reset the connection pool to avoid sharing connections.
        if butler is not None:
            butler.registry.resetConnectionPool()

        quantum = pickle.loads(quantum_pickle)
        try:
            quantumExecutor.execute(taskDef, quantum, butler)
        finally:
            # If sending fails we do not want this new exception to be exposed.
            try:
                report = quantumExecutor.getReport()
                snd_conn.send(report)
            except Exception:
                pass

    def stop(self):
        """Stop the process."""
        self.process.terminate()
        # Give it one second to finish, otherwise kill it.
        for i in range(10):
            time.sleep(0.1)
            if not self.process.is_alive():
                break
        else:
            _LOG.debug("Killing process %s", self.process.name)
            self.process.kill()
        self._terminated = True

    def cleanup(self):
        """Release process resources; has to be called for each finished
        process.
        """
        if self.process and not self.process.is_alive():
            self.process.close()
            self.process = None
            self._rcv_conn = None

    def report(self) -> QuantumReport:
        """Return the task report; should be called after the process finishes
        and before cleanup().
        """
        try:
            report = self._rcv_conn.recv()
            report.exitCode = self.process.exitcode
        except Exception:
            # Likely because the process was killed, but there may be other
            # reasons.
            report = QuantumReport.from_exit_code(
                exitCode=self.process.exitcode,
                dataId=self.qnode.quantum.dataId,
                taskLabel=self.qnode.taskDef.label,
            )
        if self.terminated:
            # Means it was killed; assume it is due to a timeout.
            report.status = ExecutionStatus.TIMEOUT
        return report

    def failMessage(self):
        """Return a message describing the task failure."""
        exitcode = self.process.exitcode
        if exitcode < 0:
            # A negative exit code means it was killed by a signal.
            signum = -exitcode
            msg = f"Task {self} failed, killed by signal {signum}"
            # Just in case this is some very odd signal, expect ValueError.
            try:
                strsignal = signal.strsignal(signum)
                msg = f"{msg} ({strsignal})"
            except ValueError:
                pass
        elif exitcode > 0:
            msg = f"Task {self} failed, exit code={exitcode}"
        else:
            msg = None
        return msg

    def __str__(self):
        return f"<{self.qnode.taskDef} dataId={self.qnode.quantum.dataId}>"
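

# Illustrative sketch (not part of the original module) of the parent/child
# reporting pattern that ``_Job`` relies on: ``Pipe(False)`` returns a one-way
# (receive, send) connection pair, the child process sends its result over the
# send end, and the parent reads it after the process exits. The names
# ``_child`` and ``payload`` are hypothetical and used only for illustration.
#
#     import multiprocessing
#
#     def _child(snd_conn):
#         payload = {"status": "ok"}  # stand-in for a QuantumReport
#         snd_conn.send(payload)
#
#     if __name__ == "__main__":
#         ctx = multiprocessing.get_context("spawn")
#         rcv_conn, snd_conn = ctx.Pipe(False)
#         proc = ctx.Process(target=_child, args=(snd_conn,))
#         proc.start()
#         proc.join()
#         print(rcv_conn.recv(), proc.exitcode)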


class _JobList:
    """Simple list of `_Job` instances with a few convenience methods.

    Parameters
    ----------
    iterable : iterable of `~lsst.pipe.base.QuantumNode`
        Sequence of quanta to execute. This has to be ordered according to
        task dependencies.
    """

    def __init__(self, iterable):
        self.jobs = [_Job(qnode) for qnode in iterable]
        self.pending = self.jobs[:]
        self.running = []
        self.finishedNodes = set()
        self.failedNodes = set()
        self.timedOutNodes = set()

    def submit(self, job, butler, quantumExecutor, startMethod=None):
        """Submit one more job for execution.

        Parameters
        ----------
        job : `_Job`
            Job to submit.
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        quantumExecutor : `QuantumExecutor`
            Executor for single quantum.
        startMethod : `str`, optional
            Start method from `multiprocessing` module.
        """
        # This will raise if the job is not in the pending list.
        self.pending.remove(job)
        job.start(butler, quantumExecutor, startMethod)
        self.running.append(job)

    def setJobState(self, job, state):
        """Update job state.

        Parameters
        ----------
        job : `_Job`
            Job whose state is to be updated.
        state : `JobState`
            New job state; note that only FINISHED, FAILED, TIMED_OUT, or
            FAILED_DEP states are acceptable.
        """
        allowedStates = (JobState.FINISHED, JobState.FAILED, JobState.TIMED_OUT, JobState.FAILED_DEP)
        assert state in allowedStates, f"State {state} not allowed here"

        # Remove the job from the pending/running lists.
        if job.state == JobState.PENDING:
            self.pending.remove(job)
        elif job.state == JobState.RUNNING:
            self.running.remove(job)

        qnode = job.qnode
        # It should not be in any of these sets, but discard just in case.
        self.finishedNodes.discard(qnode)
        self.failedNodes.discard(qnode)
        self.timedOutNodes.discard(qnode)

        job._state = state
        if state == JobState.FINISHED:
            self.finishedNodes.add(qnode)
        elif state == JobState.FAILED:
            self.failedNodes.add(qnode)
        elif state == JobState.FAILED_DEP:
            self.failedNodes.add(qnode)
        elif state == JobState.TIMED_OUT:
            self.failedNodes.add(qnode)
            self.timedOutNodes.add(qnode)
        else:
            raise ValueError(f"Unexpected state value: {state}")

    def cleanup(self):
        """Do periodic cleanup for jobs that did not finish correctly.

        If timed-out jobs are killed but take too long to stop, regular
        cleanup will not work for them. Here we check all timed-out jobs
        periodically and clean them up if they have managed to die by this
        time.
        """
        for job in self.jobs:
            if job.state == JobState.TIMED_OUT and job.process is not None:
                job.cleanup()


class MPGraphExecutorError(Exception):
    """Exception class for errors raised by MPGraphExecutor."""

    pass


class MPTimeoutError(MPGraphExecutorError):
    """Exception raised when task execution times out."""

    pass


class MPGraphExecutor(QuantumGraphExecutor):
    """Implementation of QuantumGraphExecutor using same-host multiprocess
    execution of Quanta.

    Parameters
    ----------
    numProc : `int`
        Number of processes to use for executing tasks.
    timeout : `float`
        Time in seconds to wait for tasks to finish.
    quantumExecutor : `QuantumExecutor`
        Executor for single quantum. For multiprocess-style execution when
        ``numProc`` is greater than one this instance must support pickle.
    startMethod : `str`, optional
        Start method from `multiprocessing` module; `None` selects the best
        one for the current platform.
    failFast : `bool`, optional
        If set to ``True`` then stop processing on the first error from any
        task.
    pdb : `str`, optional
        Debugger to import and use (via the ``post_mortem`` function) in the
        event of an exception.
    executionGraphFixup : `ExecutionGraphFixup`, optional
        Instance used for modification of the execution graph.
    """

    def __init__(
        self,
        numProc,
        timeout,
        quantumExecutor,
        *,
        startMethod=None,
        failFast=False,
        pdb=None,
        executionGraphFixup=None,
    ):
        self.numProc = numProc
        self.timeout = timeout
        self.quantumExecutor = quantumExecutor
        self.failFast = failFast
        self.pdb = pdb
        self.executionGraphFixup = executionGraphFixup
        self.report: Optional[Report] = None

        # Default the start method to "spawn" for macOS and "fork" for Linux;
        # leave it as `None` on all other platforms so that multiprocessing
        # uses its own default.
        if startMethod is None:
            methods = dict(linux="fork", darwin="spawn")
            startMethod = methods.get(sys.platform)
        self.startMethod = startMethod

    def execute(self, graph, butler):
        # Docstring inherited from QuantumGraphExecutor.execute
        graph = self._fixupQuanta(graph)
        self.report = Report()
        try:
            if self.numProc > 1:
                self._executeQuantaMP(graph, butler)
            else:
                self._executeQuantaInProcess(graph, butler)
        except Exception as exc:
            self.report.set_exception(exc)
            raise

    def _fixupQuanta(self, graph: QuantumGraph):
        """Call fixup code to modify the execution graph.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` to modify.

        Returns
        -------
        graph : `QuantumGraph`
            Modified `QuantumGraph`.

        Raises
        ------
        MPGraphExecutorError
            Raised if the execution graph cannot be ordered after
            modification, i.e. it has dependency cycles.
        """
        if not self.executionGraphFixup:
            return graph

        _LOG.debug("Call execution graph fixup method")
        graph = self.executionGraphFixup.fixupQuanta(graph)

        # Detect whether a cycle has now been created within the graph.
        if graph.findCycle():
            raise MPGraphExecutorError("Updated execution graph has dependency cycle.")

        return graph

    def _executeQuantaInProcess(self, graph, butler):
        """Execute all quanta in the current process.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` that is to be executed.
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        """
        successCount, totalCount = 0, len(graph)
        failedNodes = set()
        for qnode in graph:

            # Any failed inputs mean that the quantum has to be skipped.
            inputNodes = graph.determineInputsToQuantumNode(qnode)
            if inputNodes & failedNodes:
                _LOG.error(
                    "Upstream job failed for task <%s dataId=%s>, skipping this task.",
                    qnode.taskDef,
                    qnode.quantum.dataId,
                )
                failedNodes.add(qnode)
                quantum_report = QuantumReport(
                    status=ExecutionStatus.SKIPPED, dataId=qnode.quantum.dataId, taskLabel=qnode.taskDef.label
                )
                self.report.quantaReports.append(quantum_report)
                continue

            _LOG.debug("Executing %s", qnode)
            try:
                self.quantumExecutor.execute(qnode.taskDef, qnode.quantum, butler)
                successCount += 1
            except Exception as exc:
                if self.pdb and sys.stdin.isatty() and sys.stdout.isatty():
                    _LOG.error(
                        "Task <%s dataId=%s> failed; dropping into pdb.",
                        qnode.taskDef,
                        qnode.quantum.dataId,
                        exc_info=exc,
                    )
                    try:
                        pdb = importlib.import_module(self.pdb)
                    except ImportError as imp_exc:
                        raise MPGraphExecutorError(
                            f"Unable to import specified debugger module ({self.pdb}): {imp_exc}"
                        ) from exc
                    if not hasattr(pdb, "post_mortem"):
                        raise MPGraphExecutorError(
                            f"Specified debugger module ({self.pdb}) can't debug with post_mortem",
                        ) from exc
                    pdb.post_mortem(exc.__traceback__)
                failedNodes.add(qnode)
                self.report.status = ExecutionStatus.FAILURE
                if self.failFast:
                    raise MPGraphExecutorError(
                        f"Task <{qnode.taskDef} dataId={qnode.quantum.dataId}> failed."
                    ) from exc
                else:
                    # Note that there could be exception safety issues, which
                    # we presently ignore.
                    _LOG.error(
                        "Task <%s dataId=%s> failed; processing will continue for remaining tasks.",
                        qnode.taskDef,
                        qnode.quantum.dataId,
                        exc_info=exc,
                    )
            finally:
                # sqlalchemy keeps some objects alive until a garbage
                # collection cycle runs, which can happen at unpredictable
                # times; run a collection explicitly here.
                gc.collect()

            quantum_report = self.quantumExecutor.getReport()
            if quantum_report:
                self.report.quantaReports.append(quantum_report)

        _LOG.info(
            "Executed %d quanta successfully, %d failed and %d remain out of total %d quanta.",
            successCount,
            len(failedNodes),
            totalCount - successCount - len(failedNodes),
            totalCount,
        )

        # Raise an exception if there were any failures.
        if failedNodes:
            raise MPGraphExecutorError("One or more tasks failed during execution.")

    def _executeQuantaMP(self, graph, butler):
        """Execute all quanta in separate processes.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` that is to be executed.
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        """

        disable_implicit_threading()  # To prevent thread contention

        _LOG.debug("Using %r for multiprocessing start method", self.startMethod)

        # Re-pack input quantum data into a list of jobs.
        jobs = _JobList(graph)

        # Check that all tasks can run in a sub-process.
        for job in jobs.jobs:
            taskDef = job.qnode.taskDef
            if not taskDef.taskClass.canMultiprocess:
                raise MPGraphExecutorError(
                    f"Task {taskDef.taskName} does not support multiprocessing; use single process"
                )

        finishedCount, failedCount = 0, 0
        while jobs.pending or jobs.running:

            _LOG.debug("#pendingJobs: %s", len(jobs.pending))
            _LOG.debug("#runningJobs: %s", len(jobs.running))

            # See if any jobs have finished.
            for job in jobs.running:
                if not job.process.is_alive():
                    _LOG.debug("finished: %s", job)
                    # finished
                    exitcode = job.process.exitcode
                    quantum_report = job.report()
                    self.report.quantaReports.append(quantum_report)
                    if exitcode == 0:
                        jobs.setJobState(job, JobState.FINISHED)
                        job.cleanup()
                        _LOG.debug("success: %s took %.3f seconds", job, time.time() - job.started)
                    else:
                        if job.terminated:
                            # Was killed due to timeout.
                            if self.report.status == ExecutionStatus.SUCCESS:
                                # Do not override global FAILURE status.
                                self.report.status = ExecutionStatus.TIMEOUT
                            message = f"Timeout ({self.timeout} sec) for task {job}, task is killed"
                            jobs.setJobState(job, JobState.TIMED_OUT)
                        else:
                            self.report.status = ExecutionStatus.FAILURE
                            # failMessage() has to be called before cleanup().
                            message = job.failMessage()
                            jobs.setJobState(job, JobState.FAILED)

                        job.cleanup()
                        _LOG.debug("failed: %s", job)
                        if self.failFast or exitcode == InvalidQuantumError.EXIT_CODE:
                            # Stop all running jobs.
                            for stopJob in jobs.running:
                                if stopJob is not job:
                                    stopJob.stop()
                            if job.state is JobState.TIMED_OUT:
                                raise MPTimeoutError(f"Timeout ({self.timeout} sec) for task {job}.")
                            else:
                                raise MPGraphExecutorError(message)
                        else:
                            _LOG.error("%s; processing will continue for remaining tasks.", message)
                else:
                    # Check for timeout.
                    now = time.time()
                    if now - job.started > self.timeout:
                        # Try to kill it; there is a chance that it finishes
                        # successfully before it gets killed. Exit status is
                        # handled by the code above on the next iteration.
                        _LOG.debug("Terminating job %s due to timeout", job)
                        job.stop()

            # Fail jobs whose inputs failed; this may need several iterations
            # if the ordering is not right, and will be handled on the next
            # loop iteration.
            if jobs.failedNodes:
                for job in jobs.pending:
                    jobInputNodes = graph.determineInputsToQuantumNode(job.qnode)
                    if jobInputNodes & jobs.failedNodes:
                        quantum_report = QuantumReport(
                            status=ExecutionStatus.SKIPPED,
                            dataId=job.qnode.quantum.dataId,
                            taskLabel=job.qnode.taskDef.label,
                        )
                        self.report.quantaReports.append(quantum_report)
                        jobs.setJobState(job, JobState.FAILED_DEP)
                        _LOG.error("Upstream job failed for task %s, skipping this task.", job)

            # See if we can start more jobs.
            if len(jobs.running) < self.numProc:
                for job in jobs.pending:
                    jobInputNodes = graph.determineInputsToQuantumNode(job.qnode)
                    if jobInputNodes <= jobs.finishedNodes:
                        # All dependencies have completed, can start a new job.
                        if len(jobs.running) < self.numProc:
                            _LOG.debug("Submitting %s", job)
                            jobs.submit(job, butler, self.quantumExecutor, self.startMethod)
                        if len(jobs.running) >= self.numProc:
                            # Cannot start any more jobs, wait until something
                            # finishes.
                            break

            # Do cleanup for timed out jobs if necessary.
            jobs.cleanup()

            # Print progress message if something changed.
            newFinished, newFailed = len(jobs.finishedNodes), len(jobs.failedNodes)
            if (finishedCount, failedCount) != (newFinished, newFailed):
                finishedCount, failedCount = newFinished, newFailed
                totalCount = len(jobs.jobs)
                _LOG.info(
                    "Executed %d quanta successfully, %d failed and %d remain out of total %d quanta.",
                    finishedCount,
                    failedCount,
                    totalCount - finishedCount - failedCount,
                    totalCount,
                )

            # Here we want to wait until one of the running jobs completes,
            # but multiprocessing does not provide an API for that; for now
            # just sleep a little bit and go back to the loop.
            if jobs.running:
                time.sleep(0.1)

        if jobs.failedNodes:
            # Print the list of failed jobs.
            _LOG.error("Failed jobs:")
            for job in jobs.jobs:
                if job.state != JobState.FINISHED:
                    _LOG.error("  - %s: %s", job.state.name, job)

            # If any job failed, raise an exception.
            if jobs.failedNodes == jobs.timedOutNodes:
                raise MPTimeoutError("One or more tasks timed out during execution.")
            else:
                raise MPGraphExecutorError("One or more tasks failed or timed out during execution.")

    def getReport(self) -> Optional[Report]:
        # Docstring inherited from base class
        if self.report is None:
            raise RuntimeError("getReport() called before execute()")
        return self.report
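

# Minimal usage sketch (assumptions noted below): construct the executor with
# a concrete ``QuantumExecutor`` implementation and run a ``QuantumGraph``
# against a data butler. ``single_quantum_executor``, ``qgraph`` and ``butler``
# are hypothetical objects created elsewhere; only the constructor arguments
# and the ``execute``/``getReport`` calls come from this module.
#
#     executor = MPGraphExecutor(
#         numProc=4,
#         timeout=3600.0,
#         quantumExecutor=single_quantum_executor,
#         failFast=False,
#     )
#     executor.execute(qgraph, butler)
#     report = executor.getReport()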