Coverage for python/lsst/ctrl/mpexec/mpGraphExecutor.py: 12%
307 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-14 09:14 +0000
1# This file is part of ctrl_mpexec.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ["MPGraphExecutor", "MPGraphExecutorError", "MPTimeoutError"]
26import gc
27import importlib
28import logging
29import multiprocessing
30import pickle
31import signal
32import sys
33import threading
34import time
35from collections.abc import Iterable
36from enum import Enum
37from typing import Literal
39from lsst.daf.butler.cli.cliLog import CliLog
40from lsst.pipe.base import InvalidQuantumError, TaskDef
41from lsst.pipe.base.graph.graph import QuantumGraph, QuantumNode
42from lsst.utils.threads import disable_implicit_threading
44from .executionGraphFixup import ExecutionGraphFixup
45from .quantumGraphExecutor import QuantumExecutor, QuantumGraphExecutor
46from .reports import ExecutionStatus, QuantumReport, Report
48_LOG = logging.getLogger(__name__)
class JobState(Enum):
    """Possible states for the executing task.

    - PENDING: job has not started yet
    - RUNNING: job is currently executing
    - FINISHED: job finished successfully
    - FAILED: job execution failed (process returned non-zero status)
    - TIMED_OUT: job is killed due to too long execution time
    - FAILED_DEP: one of the dependencies of this job has failed/timed out
    """

    PENDING = 1
    RUNNING = 2
    FINISHED = 3
    FAILED = 4
    TIMED_OUT = 5
    FAILED_DEP = 6
class _Job:
    """Class representing a job running single task.

    Parameters
    ----------
    qnode: `~lsst.pipe.base.QuantumNode`
        Quantum and some associated information.
    """

    def __init__(self, qnode: QuantumNode):
        self.qnode = qnode
        # Child process executing the quantum; created by start(), dropped
        # again by cleanup().
        self.process: multiprocessing.process.BaseProcess | None = None
        # Current state; also updated externally by _JobList.setJobState().
        self._state = JobState.PENDING
        # time.time() when the child process was started; used for timeouts.
        self.started: float = 0.0
        # Receiving end of the pipe over which the child sends its report.
        self._rcv_conn: multiprocessing.connection.Connection | None = None
        # Set by stop(); lets `terminated` distinguish kills from failures.
        self._terminated = False

    @property
    def state(self) -> JobState:
        """Job processing state (JobState)"""
        return self._state

    @property
    def terminated(self) -> bool:
        """Return `True` if job was killed by stop() method and negative exit
        code is returned from child process. (`bool`)
        """
        if self._terminated:
            assert self.process is not None, "Process must be started"
            if self.process.exitcode is not None:
                return self.process.exitcode < 0
        return False

    def start(
        self,
        quantumExecutor: QuantumExecutor,
        startMethod: Literal["spawn"] | Literal["fork"] | Literal["forkserver"] | None = None,
    ) -> None:
        """Start process which runs the task.

        Parameters
        ----------
        quantumExecutor : `QuantumExecutor`
            Executor for single quantum.
        startMethod : `str`, optional
            Start method from `multiprocessing` module.
        """
        # Unpickling of quantum has to happen after butler/executor, this is
        # why it is pickled manually here.
        quantum_pickle = pickle.dumps(self.qnode.quantum)
        taskDef = self.qnode.taskDef
        # duplex=False: child writes to snd_conn, parent reads _rcv_conn.
        self._rcv_conn, snd_conn = multiprocessing.Pipe(False)
        # Captured so a spawned child can replay logging configuration.
        logConfigState = CliLog.configState

        # NOTE(review): the parent keeps its copy of ``snd_conn`` open, so
        # ``report()``'s recv() cannot see EOF if the child dies without
        # sending — confirm recv() is only reached when data is available or
        # an exception path covers it.
        mp_ctx = multiprocessing.get_context(startMethod)
        self.process = mp_ctx.Process(
            target=_Job._executeJob,
            args=(quantumExecutor, taskDef, quantum_pickle, logConfigState, snd_conn),
            name=f"task-{self.qnode.quantum.dataId}",
        )
        self.process.start()
        self.started = time.time()
        self._state = JobState.RUNNING

    @staticmethod
    def _executeJob(
        quantumExecutor: QuantumExecutor,
        taskDef: TaskDef,
        quantum_pickle: bytes,
        logConfigState: list,
        snd_conn: multiprocessing.connection.Connection,
    ) -> None:
        """Execute a job with arguments; runs in the child process.

        Parameters
        ----------
        quantumExecutor : `QuantumExecutor`
            Executor for single quantum.
        taskDef : `~lsst.pipe.base.TaskDef`
            Task definition structure.
        quantum_pickle : `bytes`
            Quantum for this task execution in pickled form.
        logConfigState : `list`
            Logging configuration state captured in the parent process.
        snd_conn : `multiprocessing.Connection`
            Connection to send job report to parent process.
        """
        # This terrible hack is a workaround for Python threading bug:
        # https://github.com/python/cpython/issues/102512. Should be removed
        # when fix for that bug is deployed. Inspired by
        # https://github.com/QubesOS/qubes-core-admin-client/pull/236/files.
        thread = threading.current_thread()
        if isinstance(thread, threading._DummyThread):
            if getattr(thread, "_tstate_lock", "") is None:
                thread._set_tstate_lock()  # type: ignore[attr-defined]

        if logConfigState and not CliLog.configState:
            # means that we are in a new spawned Python process and we have to
            # re-initialize logging
            CliLog.replayConfigState(logConfigState)

        # Quantum is pickled separately so it is unpickled only after the
        # executor has been set up in this process; see start().
        quantum = pickle.loads(quantum_pickle)
        try:
            quantumExecutor.execute(taskDef, quantum)
        finally:
            # If sending fails we do not want this new exception to be exposed.
            try:
                report = quantumExecutor.getReport()
                snd_conn.send(report)
            except Exception:
                pass

    def stop(self) -> None:
        """Stop the process: terminate, then kill if it does not die soon."""
        assert self.process is not None, "Process must be started"
        self.process.terminate()
        # give it 1 second to finish or KILL
        for i in range(10):
            time.sleep(0.1)
            if not self.process.is_alive():
                break
        else:
            # Still alive after SIGTERM grace period; escalate to SIGKILL.
            _LOG.debug("Killing process %s", self.process.name)
            self.process.kill()
        self._terminated = True

    def cleanup(self) -> None:
        """Release processes resources, has to be called for each finished
        process.
        """
        if self.process and not self.process.is_alive():
            self.process.close()
            self.process = None
            self._rcv_conn = None

    def report(self) -> QuantumReport:
        """Return task report, should be called after process finishes and
        before cleanup().
        """
        assert self.process is not None, "Process must be started"
        assert self._rcv_conn is not None, "Process must be started"
        try:
            report = self._rcv_conn.recv()
            report.exitCode = self.process.exitcode
        except Exception:
            # Likely due to the process killed, but there may be other reasons.
            # Exit code should not be None, this is to keep mypy happy.
            exitcode = self.process.exitcode if self.process.exitcode is not None else -1
            assert self.qnode.quantum.dataId is not None, "Quantum DataId cannot be None"
            report = QuantumReport.from_exit_code(
                exitCode=exitcode,
                dataId=self.qnode.quantum.dataId,
                taskLabel=self.qnode.taskDef.label,
            )
        if self.terminated:
            # Means it was killed, assume it's due to timeout
            report.status = ExecutionStatus.TIMEOUT
        return report

    def failMessage(self) -> str:
        """Return a message describing task failure"""
        assert self.process is not None, "Process must be started"
        assert self.process.exitcode is not None, "Process has to finish"
        exitcode = self.process.exitcode
        if exitcode < 0:
            # Negative exit code means it is killed by signal
            signum = -exitcode
            msg = f"Task {self} failed, killed by signal {signum}"
            # Just in case this is some very odd signal, expect ValueError
            try:
                strsignal = signal.strsignal(signum)
                msg = f"{msg} ({strsignal})"
            except ValueError:
                pass
        elif exitcode > 0:
            msg = f"Task {self} failed, exit code={exitcode}"
        else:
            msg = ""
        return msg

    def __str__(self) -> str:
        return f"<{self.qnode.taskDef} dataId={self.qnode.quantum.dataId}>"
class _JobList:
    """Collection of `_Job` instances with bookkeeping helpers.

    Parameters
    ----------
    iterable : iterable of `~lsst.pipe.base.QuantumNode`
        Sequence of Quanta to execute. This has to be ordered according to
        task dependencies.
    """

    def __init__(self, iterable: Iterable[QuantumNode]):
        # One job per quantum, kept in dependency order.
        self.jobs = [_Job(qnode) for qnode in iterable]
        self.pending = list(self.jobs)
        self.running: list[_Job] = []
        self.finishedNodes: set[QuantumNode] = set()
        self.failedNodes: set[QuantumNode] = set()
        self.timedOutNodes: set[QuantumNode] = set()

    def submit(
        self,
        job: _Job,
        quantumExecutor: QuantumExecutor,
        startMethod: Literal["spawn"] | Literal["fork"] | Literal["forkserver"] | None = None,
    ) -> None:
        """Start one more job and move it to the running list.

        Parameters
        ----------
        job : `_Job`
            Job to submit.
        quantumExecutor : `QuantumExecutor`
            Executor for single quantum.
        startMethod : `str`, optional
            Start method from `multiprocessing` module.
        """
        # remove() raises ValueError if the job is not in the pending list.
        self.pending.remove(job)
        job.start(quantumExecutor, startMethod)
        self.running.append(job)

    def setJobState(self, job: _Job, state: JobState) -> None:
        """Update job state.

        Parameters
        ----------
        job : `_Job`
            Job to update.
        state : `JobState`
            New job state; only FINISHED, FAILED, TIMED_OUT, or FAILED_DEP
            is acceptable here.
        """
        allowedStates = (JobState.FINISHED, JobState.FAILED, JobState.TIMED_OUT, JobState.FAILED_DEP)
        assert state in allowedStates, f"State {state} not allowed here"

        # Drop the job from whichever scheduling list currently holds it.
        if job.state == JobState.PENDING:
            self.pending.remove(job)
        elif job.state == JobState.RUNNING:
            self.running.remove(job)

        qnode = job.qnode
        # The node should not be in any of these sets, but just in case.
        for nodeSet in (self.finishedNodes, self.failedNodes, self.timedOutNodes):
            nodeSet.discard(qnode)

        job._state = state
        if state == JobState.FINISHED:
            self.finishedNodes.add(qnode)
        elif state in (JobState.FAILED, JobState.FAILED_DEP):
            self.failedNodes.add(qnode)
        elif state == JobState.TIMED_OUT:
            # Timed-out jobs count as failed and are also tracked separately.
            self.failedNodes.add(qnode)
            self.timedOutNodes.add(qnode)
        else:
            raise ValueError(f"Unexpected state value: {state}")

    def cleanup(self) -> None:
        """Do periodic cleanup for jobs that did not finish correctly.

        If timed out jobs are killed but take too long to stop then regular
        cleanup will not work for them. Here we check all timed out jobs
        periodically and do cleanup if they managed to die by this time.
        """
        for job in self.jobs:
            if job.process is not None and job.state == JobState.TIMED_OUT:
                job.cleanup()
class MPGraphExecutorError(Exception):
    """Base exception for errors raised by `MPGraphExecutor` during graph
    execution.
    """
class MPTimeoutError(MPGraphExecutorError):
    """Exception raised when execution of a task exceeds the configured
    timeout.
    """
class MPGraphExecutor(QuantumGraphExecutor):
    """Implementation of QuantumGraphExecutor using same-host multiprocess
    execution of Quanta.

    Parameters
    ----------
    numProc : `int`
        Number of processes to use for executing tasks.
    timeout : `float`
        Time in seconds to wait for tasks to finish.
    quantumExecutor : `QuantumExecutor`
        Executor for single quantum. For multiprocess-style execution when
        ``numProc`` is greater than one this instance must support pickle.
    startMethod : `str`, optional
        Start method from `multiprocessing` module, `None` selects the best
        one for current platform.
    failFast : `bool`, optional
        If set to ``True`` then stop processing on first error from any task.
    pdb : `str`, optional
        Debugger to import and use (via the ``post_mortem`` function) in the
        event of an exception.
    executionGraphFixup : `ExecutionGraphFixup`, optional
        Instance used for modification of execution graph.
    """

    def __init__(
        self,
        numProc: int,
        timeout: float,
        quantumExecutor: QuantumExecutor,
        *,
        startMethod: Literal["spawn"] | Literal["fork"] | Literal["forkserver"] | None = None,
        failFast: bool = False,
        pdb: str | None = None,
        executionGraphFixup: ExecutionGraphFixup | None = None,
    ):
        self.numProc = numProc
        self.timeout = timeout
        self.quantumExecutor = quantumExecutor
        self.failFast = failFast
        self.pdb = pdb
        self.executionGraphFixup = executionGraphFixup
        self.report: Report | None = None

        # We set default start method as spawn for MacOS and fork for Linux;
        # None for all other platforms to use multiprocessing default.
        if startMethod is None:
            methods = dict(linux="fork", darwin="spawn")
            startMethod = methods.get(sys.platform)  # type: ignore
        self.startMethod = startMethod

    def execute(self, graph: QuantumGraph) -> None:
        # Docstring inherited from QuantumGraphExecutor.execute
        graph = self._fixupQuanta(graph)
        self.report = Report()
        try:
            if self.numProc > 1:
                self._executeQuantaMP(graph, self.report)
            else:
                self._executeQuantaInProcess(graph, self.report)
        except Exception as exc:
            # Record the failure in the report before propagating it.
            self.report.set_exception(exc)
            raise

    def _fixupQuanta(self, graph: QuantumGraph) -> QuantumGraph:
        """Call fixup code to modify execution graph.

        Parameters
        ----------
        graph : `~lsst.pipe.base.QuantumGraph`
            `~lsst.pipe.base.QuantumGraph` to modify.

        Returns
        -------
        graph : `~lsst.pipe.base.QuantumGraph`
            Modified `~lsst.pipe.base.QuantumGraph`.

        Raises
        ------
        MPGraphExecutorError
            Raised if execution graph cannot be ordered after modification,
            i.e. it has dependency cycles.
        """
        if not self.executionGraphFixup:
            return graph

        _LOG.debug("Call execution graph fixup method")
        graph = self.executionGraphFixup.fixupQuanta(graph)

        # Detect if there is now a cycle created within the graph
        if graph.findCycle():
            raise MPGraphExecutorError("Updated execution graph has dependency cycle.")

        return graph

    def _executeQuantaInProcess(self, graph: QuantumGraph, report: Report) -> None:
        """Execute all Quanta in current process.

        Parameters
        ----------
        graph : `~lsst.pipe.base.QuantumGraph`
            `~lsst.pipe.base.QuantumGraph` that is to be executed.
        report : `Report`
            Object for reporting execution status.
        """
        successCount, totalCount = 0, len(graph)
        failedNodes: set[QuantumNode] = set()
        for qnode in graph:
            assert qnode.quantum.dataId is not None, "Quantum DataId cannot be None"

            # Any failed inputs mean that the quantum has to be skipped.
            inputNodes = graph.determineInputsToQuantumNode(qnode)
            if inputNodes & failedNodes:
                _LOG.error(
                    "Upstream job failed for task <%s dataId=%s>, skipping this task.",
                    qnode.taskDef,
                    qnode.quantum.dataId,
                )
                failedNodes.add(qnode)
                failed_quantum_report = QuantumReport(
                    status=ExecutionStatus.SKIPPED, dataId=qnode.quantum.dataId, taskLabel=qnode.taskDef.label
                )
                report.quantaReports.append(failed_quantum_report)
                continue

            _LOG.debug("Executing %s", qnode)
            try:
                self.quantumExecutor.execute(qnode.taskDef, qnode.quantum)
                successCount += 1
            except Exception as exc:
                if self.pdb and sys.stdin.isatty() and sys.stdout.isatty():
                    # Drop into the requested debugger, but only when attached
                    # to an interactive terminal.
                    _LOG.error(
                        "Task <%s dataId=%s> failed; dropping into pdb.",
                        qnode.taskDef,
                        qnode.quantum.dataId,
                        exc_info=exc,
                    )
                    try:
                        pdb = importlib.import_module(self.pdb)
                    except ImportError as imp_exc:
                        raise MPGraphExecutorError(
                            f"Unable to import specified debugger module ({self.pdb}): {imp_exc}"
                        ) from exc
                    if not hasattr(pdb, "post_mortem"):
                        raise MPGraphExecutorError(
                            f"Specified debugger module ({self.pdb}) can't debug with post_mortem",
                        ) from exc
                    pdb.post_mortem(exc.__traceback__)
                failedNodes.add(qnode)
                report.status = ExecutionStatus.FAILURE
                if self.failFast:
                    raise MPGraphExecutorError(
                        f"Task <{qnode.taskDef} dataId={qnode.quantum.dataId}> failed."
                    ) from exc
                else:
                    # Note that there could be exception safety issues, which
                    # we presently ignore.
                    _LOG.error(
                        "Task <%s dataId=%s> failed; processing will continue for remaining tasks.",
                        qnode.taskDef,
                        qnode.quantum.dataId,
                        exc_info=exc,
                    )
            finally:
                # sqlalchemy has some objects that can last until a garbage
                # collection cycle is run, which can happen at unpredictable
                # times, run a collection loop here explicitly.
                gc.collect()

            quantum_report = self.quantumExecutor.getReport()
            if quantum_report:
                report.quantaReports.append(quantum_report)

        _LOG.info(
            "Executed %d quanta successfully, %d failed and %d remain out of total %d quanta.",
            successCount,
            len(failedNodes),
            totalCount - successCount - len(failedNodes),
            totalCount,
        )

        # Raise an exception if there were any failures.
        if failedNodes:
            raise MPGraphExecutorError("One or more tasks failed during execution.")

    def _executeQuantaMP(self, graph: QuantumGraph, report: Report) -> None:
        """Execute all Quanta in separate processes.

        Parameters
        ----------
        graph : `~lsst.pipe.base.QuantumGraph`
            `~lsst.pipe.base.QuantumGraph` that is to be executed.
        report : `Report`
            Object for reporting execution status.
        """
        disable_implicit_threading()  # To prevent thread contention

        _LOG.debug("Using %r for multiprocessing start method", self.startMethod)

        # re-pack input quantum data into jobs list
        jobs = _JobList(graph)

        # check that all tasks can run in sub-process
        for job in jobs.jobs:
            taskDef = job.qnode.taskDef
            if not taskDef.taskClass.canMultiprocess:
                raise MPGraphExecutorError(
                    f"Task {taskDef.taskName} does not support multiprocessing; use single process"
                )

        finishedCount, failedCount = 0, 0
        while jobs.pending or jobs.running:
            _LOG.debug("#pendingJobs: %s", len(jobs.pending))
            _LOG.debug("#runningJobs: %s", len(jobs.running))

            # See if any jobs have finished. Iterate over a snapshot because
            # setJobState() removes finished jobs from jobs.running, and
            # mutating the list during iteration would skip the next job
            # until the following polling pass.
            for job in list(jobs.running):
                assert job.process is not None, "Process cannot be None"
                if not job.process.is_alive():
                    _LOG.debug("finished: %s", job)
                    # finished
                    exitcode = job.process.exitcode
                    quantum_report = job.report()
                    report.quantaReports.append(quantum_report)
                    if exitcode == 0:
                        jobs.setJobState(job, JobState.FINISHED)
                        job.cleanup()
                        _LOG.debug("success: %s took %.3f seconds", job, time.time() - job.started)
                    else:
                        if job.terminated:
                            # Was killed due to timeout.
                            if report.status == ExecutionStatus.SUCCESS:
                                # Do not override global FAILURE status
                                report.status = ExecutionStatus.TIMEOUT
                            message = f"Timeout ({self.timeout} sec) for task {job}, task is killed"
                            jobs.setJobState(job, JobState.TIMED_OUT)
                        else:
                            report.status = ExecutionStatus.FAILURE
                            # failMessage() has to be called before cleanup()
                            message = job.failMessage()
                            jobs.setJobState(job, JobState.FAILED)

                        job.cleanup()
                        _LOG.debug("failed: %s", job)
                        if self.failFast or exitcode == InvalidQuantumError.EXIT_CODE:
                            # stop all running jobs
                            for stopJob in jobs.running:
                                if stopJob is not job:
                                    stopJob.stop()
                            if job.state is JobState.TIMED_OUT:
                                raise MPTimeoutError(f"Timeout ({self.timeout} sec) for task {job}.")
                            else:
                                raise MPGraphExecutorError(message)
                        else:
                            _LOG.error("%s; processing will continue for remaining tasks.", message)
                else:
                    # check for timeout
                    now = time.time()
                    if now - job.started > self.timeout:
                        # Try to kill it, and there is a chance that it
                        # finishes successfully before it gets killed. Exit
                        # status is handled by the code above on next
                        # iteration.
                        _LOG.debug("Terminating job %s due to timeout", job)
                        job.stop()

            # Fail jobs whose inputs failed, this may need several iterations
            # if the order is not right, will be done in the next loop.
            # Iterate over a snapshot: setJobState(FAILED_DEP) removes the
            # job from jobs.pending.
            if jobs.failedNodes:
                for job in list(jobs.pending):
                    jobInputNodes = graph.determineInputsToQuantumNode(job.qnode)
                    assert job.qnode.quantum.dataId is not None, "Quantum DataId cannot be None"
                    if jobInputNodes & jobs.failedNodes:
                        quantum_report = QuantumReport(
                            status=ExecutionStatus.SKIPPED,
                            dataId=job.qnode.quantum.dataId,
                            taskLabel=job.qnode.taskDef.label,
                        )
                        report.quantaReports.append(quantum_report)
                        jobs.setJobState(job, JobState.FAILED_DEP)
                        _LOG.error("Upstream job failed for task %s, skipping this task.", job)

            # see if we can start more jobs; iterate over a snapshot because
            # submit() removes the job from jobs.pending.
            if len(jobs.running) < self.numProc:
                for job in list(jobs.pending):
                    jobInputNodes = graph.determineInputsToQuantumNode(job.qnode)
                    if jobInputNodes <= jobs.finishedNodes:
                        # all dependencies have completed, can start new job
                        if len(jobs.running) < self.numProc:
                            _LOG.debug("Submitting %s", job)
                            jobs.submit(job, self.quantumExecutor, self.startMethod)
                        if len(jobs.running) >= self.numProc:
                            # Cannot start any more jobs, wait until something
                            # finishes.
                            break

            # Do cleanup for timed out jobs if necessary.
            jobs.cleanup()

            # Print progress message if something changed.
            newFinished, newFailed = len(jobs.finishedNodes), len(jobs.failedNodes)
            if (finishedCount, failedCount) != (newFinished, newFailed):
                finishedCount, failedCount = newFinished, newFailed
                totalCount = len(jobs.jobs)
                _LOG.info(
                    "Executed %d quanta successfully, %d failed and %d remain out of total %d quanta.",
                    finishedCount,
                    failedCount,
                    totalCount - finishedCount - failedCount,
                    totalCount,
                )

            # Here we want to wait until one of the running jobs completes
            # but multiprocessing does not provide an API for that, for now
            # just sleep a little bit and go back to the loop.
            if jobs.running:
                time.sleep(0.1)

        if jobs.failedNodes:
            # print list of failed jobs
            _LOG.error("Failed jobs:")
            for job in jobs.jobs:
                if job.state != JobState.FINISHED:
                    _LOG.error("  - %s: %s", job.state.name, job)

            # if any job failed raise an exception
            if jobs.failedNodes == jobs.timedOutNodes:
                raise MPTimeoutError("One or more tasks timed out during execution.")
            else:
                raise MPGraphExecutorError("One or more tasks failed or timed out during execution.")

    def getReport(self) -> Report | None:
        # Docstring inherited from base class
        if self.report is None:
            raise RuntimeError("getReport() called before execute()")
        return self.report