Coverage for python/lsst/ctrl/mpexec/mpGraphExecutor.py: 14% of 307 statements (coverage.py v7.3.2, created at 2023-12-12 12:21 +0000)

# This file is part of ctrl_mpexec.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

__all__ = ["MPGraphExecutor", "MPGraphExecutorError", "MPTimeoutError"]

import gc
import importlib
import logging
import multiprocessing
import pickle
import signal
import sys
import threading
import time
from collections.abc import Iterable
from enum import Enum
from typing import Literal

from lsst.daf.butler.cli.cliLog import CliLog
from lsst.pipe.base import InvalidQuantumError, TaskDef
from lsst.pipe.base.graph.graph import QuantumGraph, QuantumNode
from lsst.utils.threads import disable_implicit_threading

from .executionGraphFixup import ExecutionGraphFixup
from .quantumGraphExecutor import QuantumExecutor, QuantumGraphExecutor
from .reports import ExecutionStatus, QuantumReport, Report

_LOG = logging.getLogger(__name__)


# Possible states for the executing task:
# - PENDING: job has not started yet
# - RUNNING: job is currently executing
# - FINISHED: job finished successfully
# - FAILED: job execution failed (process returned non-zero status)
# - TIMED_OUT: job was killed because it exceeded the execution time limit
# - FAILED_DEP: one of the dependencies of this job has failed/timed out
JobState = Enum("JobState", "PENDING RUNNING FINISHED FAILED TIMED_OUT FAILED_DEP")


class _Job:
    """Class representing a job running a single task.

    Parameters
    ----------
    qnode : `~lsst.pipe.base.QuantumNode`
        Quantum and some associated information.
    """

    def __init__(self, qnode: QuantumNode):
        self.qnode = qnode
        self.process: multiprocessing.process.BaseProcess | None = None
        self._state = JobState.PENDING
        self.started: float = 0.0
        self._rcv_conn: multiprocessing.connection.Connection | None = None
        self._terminated = False

    @property
    def state(self) -> JobState:
        """Job processing state (`JobState`)."""
        return self._state

    @property
    def terminated(self) -> bool:
        """Return `True` if the job was killed by the stop() method and a
        negative exit code was returned from the child process (`bool`).
        """
        if self._terminated:
            assert self.process is not None, "Process must be started"
            if self.process.exitcode is not None:
                return self.process.exitcode < 0
        return False

    def start(
        self,
        quantumExecutor: QuantumExecutor,
        startMethod: Literal["spawn"] | Literal["forkserver"],
    ) -> None:
        """Start a process which runs the task.

        Parameters
        ----------
        quantumExecutor : `QuantumExecutor`
            Executor for a single quantum.
        startMethod : `str`
            Start method from the `multiprocessing` module.
        """
        # Unpickling of the quantum has to happen after the butler/executor
        # are set up in the subprocess, which is why it is pickled manually
        # here.
        quantum_pickle = pickle.dumps(self.qnode.quantum)
        taskDef = self.qnode.taskDef
        self._rcv_conn, snd_conn = multiprocessing.Pipe(False)
        logConfigState = CliLog.configState

        mp_ctx = multiprocessing.get_context(startMethod)
        self.process = mp_ctx.Process(  # type: ignore[attr-defined]
            target=_Job._executeJob,
            args=(quantumExecutor, taskDef, quantum_pickle, logConfigState, snd_conn),
            name=f"task-{self.qnode.quantum.dataId}",
        )
        # mypy is getting confused by multiprocessing.
        assert self.process is not None
        self.process.start()
        self.started = time.time()
        self._state = JobState.RUNNING

    @staticmethod
    def _executeJob(
        quantumExecutor: QuantumExecutor,
        taskDef: TaskDef,
        quantum_pickle: bytes,
        logConfigState: list,
        snd_conn: multiprocessing.connection.Connection,
    ) -> None:
        """Execute a job with arguments.

        Parameters
        ----------
        quantumExecutor : `QuantumExecutor`
            Executor for a single quantum.
        taskDef : `~lsst.pipe.base.TaskDef`
            Task definition structure.
        quantum_pickle : `bytes`
            Quantum for this task execution in pickled form.
        logConfigState : `list`
            Logging configuration state to replay in the new process.
        snd_conn : `multiprocessing.connection.Connection`
            Connection to send job report to parent process.
        """
        # This terrible hack is a workaround for Python threading bug:
        # https://github.com/python/cpython/issues/102512. Should be removed
        # when fix for that bug is deployed. Inspired by
        # https://github.com/QubesOS/qubes-core-admin-client/pull/236/files.
        thread = threading.current_thread()
        if isinstance(thread, threading._DummyThread):
            if getattr(thread, "_tstate_lock", "") is None:
                thread._set_tstate_lock()  # type: ignore[attr-defined]

        if logConfigState and not CliLog.configState:
            # This means we are in a new spawned Python process and we have to
            # re-initialize logging.
            CliLog.replayConfigState(logConfigState)

        quantum = pickle.loads(quantum_pickle)
        try:
            quantumExecutor.execute(taskDef, quantum)
        finally:
            # If sending fails we do not want this new exception to be
            # exposed.
            try:
                report = quantumExecutor.getReport()
                snd_conn.send(report)
            except Exception:
                pass

    def stop(self) -> None:
        """Stop the process."""
        assert self.process is not None, "Process must be started"
        self.process.terminate()
        # Give it one second to finish, then KILL it.
        for _ in range(10):
            time.sleep(0.1)
            if not self.process.is_alive():
                break
        else:
            _LOG.debug("Killing process %s", self.process.name)
            self.process.kill()
        self._terminated = True

    def cleanup(self) -> None:
        """Release process resources; has to be called for each finished
        process.
        """
        if self.process and not self.process.is_alive():
            self.process.close()
            self.process = None
            self._rcv_conn = None

    def report(self) -> QuantumReport:
        """Return the task report; should be called after the process
        finishes and before cleanup().
        """
        assert self.process is not None, "Process must be started"
        assert self._rcv_conn is not None, "Process must be started"
        try:
            report = self._rcv_conn.recv()
            report.exitCode = self.process.exitcode
        except Exception:
            # Likely due to the process being killed, but there may be other
            # reasons. Exit code should not be None; this is to keep mypy
            # happy.
            exitcode = self.process.exitcode if self.process.exitcode is not None else -1
            assert self.qnode.quantum.dataId is not None, "Quantum DataId cannot be None"
            report = QuantumReport.from_exit_code(
                exitCode=exitcode,
                dataId=self.qnode.quantum.dataId,
                taskLabel=self.qnode.taskDef.label,
            )
        if self.terminated:
            # Means it was killed; assume it is due to timeout.
            report.status = ExecutionStatus.TIMEOUT
        return report

    def failMessage(self) -> str:
        """Return a message describing the task failure."""
        assert self.process is not None, "Process must be started"
        assert self.process.exitcode is not None, "Process has to finish"
        exitcode = self.process.exitcode
        if exitcode < 0:
            # A negative exit code means the process was killed by a signal.
            signum = -exitcode
            msg = f"Task {self} failed, killed by signal {signum}"
            # Just in case this is some very odd signal, expect ValueError.
            try:
                strsignal = signal.strsignal(signum)
                msg = f"{msg} ({strsignal})"
            except ValueError:
                pass
        elif exitcode > 0:
            msg = f"Task {self} failed, exit code={exitcode}"
        else:
            msg = ""
        return msg

    def __str__(self) -> str:
        return f"<{self.qnode.taskDef} dataId={self.qnode.quantum.dataId}>"
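

# The parent/child handshake that _Job builds on reduces to the following
# minimal sketch (illustrative only, not executed by this module; the
# function and variable names are invented for the example): the child
# process sends a result object through the write end of a one-way Pipe,
# and the parent receives it after the child exits, as _Job.report() does.
#
#     import multiprocessing
#
#     def _child(snd_conn):
#         snd_conn.send("report")  # stands in for the QuantumReport
#
#     if __name__ == "__main__":
#         mp_ctx = multiprocessing.get_context("spawn")
#         rcv_conn, snd_conn = mp_ctx.Pipe(False)  # one-way: rcv <- snd
#         proc = mp_ctx.Process(target=_child, args=(snd_conn,))
#         proc.start()
#         proc.join()
#         print(rcv_conn.recv())  # -> "report"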


class _JobList:
    """Simple list of _Job instances with a few convenience methods.

    Parameters
    ----------
    iterable : iterable of `~lsst.pipe.base.QuantumNode`
        Sequence of Quanta to execute. This has to be ordered according to
        task dependencies.
    """

    def __init__(self, iterable: Iterable[QuantumNode]):
        self.jobs = [_Job(qnode) for qnode in iterable]
        self.pending = self.jobs[:]
        self.running: list[_Job] = []
        self.finishedNodes: set[QuantumNode] = set()
        self.failedNodes: set[QuantumNode] = set()
        self.timedOutNodes: set[QuantumNode] = set()

    def submit(
        self,
        job: _Job,
        quantumExecutor: QuantumExecutor,
        startMethod: Literal["spawn"] | Literal["forkserver"],
    ) -> None:
        """Submit one more job for execution.

        Parameters
        ----------
        job : `_Job`
            Job to submit.
        quantumExecutor : `QuantumExecutor`
            Executor for a single quantum.
        startMethod : `str`
            Start method from the `multiprocessing` module.
        """
        # This will raise if the job is not in the pending list.
        self.pending.remove(job)
        job.start(quantumExecutor, startMethod)
        self.running.append(job)

    def setJobState(self, job: _Job, state: JobState) -> None:
        """Update job state.

        Parameters
        ----------
        job : `_Job`
            Job to update.
        state : `JobState`
            New job state; note that only FINISHED, FAILED, TIMED_OUT, or
            FAILED_DEP states are acceptable.
        """
        allowedStates = (JobState.FINISHED, JobState.FAILED, JobState.TIMED_OUT, JobState.FAILED_DEP)
        assert state in allowedStates, f"State {state} not allowed here"

        # Remove the job from the pending/running lists.
        if job.state == JobState.PENDING:
            self.pending.remove(job)
        elif job.state == JobState.RUNNING:
            self.running.remove(job)

        qnode = job.qnode
        # It should not be in any of these, but just in case.
        self.finishedNodes.discard(qnode)
        self.failedNodes.discard(qnode)
        self.timedOutNodes.discard(qnode)

        job._state = state
        if state == JobState.FINISHED:
            self.finishedNodes.add(qnode)
        elif state == JobState.FAILED:
            self.failedNodes.add(qnode)
        elif state == JobState.FAILED_DEP:
            self.failedNodes.add(qnode)
        elif state == JobState.TIMED_OUT:
            self.failedNodes.add(qnode)
            self.timedOutNodes.add(qnode)
        else:
            raise ValueError(f"Unexpected state value: {state}")

    def cleanup(self) -> None:
        """Do periodic cleanup for jobs that did not finish correctly.

        If timed out jobs are killed but take too long to stop then regular
        cleanup will not work for them. Here we check all timed out jobs
        periodically and do cleanup if they managed to die by this time.
        """
        for job in self.jobs:
            if job.state == JobState.TIMED_OUT and job.process is not None:
                job.cleanup()
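

# A sketch of how the executor drives a _JobList (illustrative comments only;
# "executor" and "start_method" stand in for the values that MPGraphExecutor
# passes in):
#
#     jobs = _JobList(graph)               # quanta in dependency order
#     jobs.submit(jobs.pending[0], executor, start_method)
#     ...                                  # poll jobs.running for completion
#     jobs.setJobState(job, JobState.FINISHED)
#     jobs.cleanup()                       # reap timed-out processes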


class MPGraphExecutorError(Exception):
    """Exception class for errors raised by MPGraphExecutor."""

    pass


class MPTimeoutError(MPGraphExecutorError):
    """Exception raised when task execution times out."""

    pass


class MPGraphExecutor(QuantumGraphExecutor):
    """Implementation of QuantumGraphExecutor using same-host multiprocess
    execution of Quanta.

    Parameters
    ----------
    numProc : `int`
        Number of processes to use for executing tasks.
    timeout : `float`
        Time in seconds to wait for tasks to finish.
    quantumExecutor : `QuantumExecutor`
        Executor for a single quantum. For multiprocess-style execution when
        ``numProc`` is greater than one this instance must support pickle.
    startMethod : `str`, optional
        Start method from the `multiprocessing` module; `None` selects the
        best one for the current platform.
    failFast : `bool`, optional
        If set to ``True`` then stop processing on first error from any task.
    pdb : `str`, optional
        Debugger to import and use (via the ``post_mortem`` function) in the
        event of an exception.
    executionGraphFixup : `ExecutionGraphFixup`, optional
        Instance used for modification of execution graph.
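
    Examples
    --------
    A minimal usage sketch; ``my_executor`` stands in for a concrete
    `QuantumExecutor` implementation and ``qgraph`` for a prebuilt
    `~lsst.pipe.base.QuantumGraph`, both of which are assumptions of this
    example::

        executor = MPGraphExecutor(numProc=4, timeout=3600.0, quantumExecutor=my_executor)
        executor.execute(qgraph)
        report = executor.getReport()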
    """

    def __init__(
        self,
        numProc: int,
        timeout: float,
        quantumExecutor: QuantumExecutor,
        *,
        startMethod: Literal["spawn"] | Literal["forkserver"] | None = None,
        failFast: bool = False,
        pdb: str | None = None,
        executionGraphFixup: ExecutionGraphFixup | None = None,
    ):
        self.numProc = numProc
        self.timeout = timeout
        self.quantumExecutor = quantumExecutor
        self.failFast = failFast
        self.pdb = pdb
        self.executionGraphFixup = executionGraphFixup
        self.report: Report | None = None

        # We set the default start method to "spawn" for all platforms.
        if startMethod is None:
            startMethod = "spawn"
        self.startMethod = startMethod

    def execute(self, graph: QuantumGraph) -> None:
        # Docstring inherited from QuantumGraphExecutor.execute
        graph = self._fixupQuanta(graph)
        self.report = Report()
        try:
            if self.numProc > 1:
                self._executeQuantaMP(graph, self.report)
            else:
                self._executeQuantaInProcess(graph, self.report)
        except Exception as exc:
            self.report.set_exception(exc)
            raise

    def _fixupQuanta(self, graph: QuantumGraph) -> QuantumGraph:
        """Call fixup code to modify execution graph.

        Parameters
        ----------
        graph : `~lsst.pipe.base.QuantumGraph`
            `~lsst.pipe.base.QuantumGraph` to modify.

        Returns
        -------
        graph : `~lsst.pipe.base.QuantumGraph`
            Modified `~lsst.pipe.base.QuantumGraph`.

        Raises
        ------
        MPGraphExecutorError
            Raised if execution graph cannot be ordered after modification,
            i.e. it has dependency cycles.
        """
        if not self.executionGraphFixup:
            return graph

        _LOG.debug("Call execution graph fixup method")
        graph = self.executionGraphFixup.fixupQuanta(graph)

        # Detect if there is now a cycle created within the graph.
        if graph.findCycle():
            raise MPGraphExecutorError("Updated execution graph has dependency cycle.")

        return graph

    def _executeQuantaInProcess(self, graph: QuantumGraph, report: Report) -> None:
        """Execute all Quanta in the current process.

        Parameters
        ----------
        graph : `~lsst.pipe.base.QuantumGraph`
            `~lsst.pipe.base.QuantumGraph` that is to be executed.
        report : `Report`
            Object for reporting execution status.
        """
        successCount, totalCount = 0, len(graph)
        failedNodes: set[QuantumNode] = set()
        for qnode in graph:
            assert qnode.quantum.dataId is not None, "Quantum DataId cannot be None"

            # Any failed inputs mean that the quantum has to be skipped.
            inputNodes = graph.determineInputsToQuantumNode(qnode)
            if inputNodes & failedNodes:
                _LOG.error(
                    "Upstream job failed for task <%s dataId=%s>, skipping this task.",
                    qnode.taskDef,
                    qnode.quantum.dataId,
                )
                failedNodes.add(qnode)
                failed_quantum_report = QuantumReport(
                    status=ExecutionStatus.SKIPPED, dataId=qnode.quantum.dataId, taskLabel=qnode.taskDef.label
                )
                report.quantaReports.append(failed_quantum_report)
                continue

            _LOG.debug("Executing %s", qnode)
            try:
                self.quantumExecutor.execute(qnode.taskDef, qnode.quantum)
                successCount += 1
            except Exception as exc:
                if self.pdb and sys.stdin.isatty() and sys.stdout.isatty():
                    _LOG.error(
                        "Task <%s dataId=%s> failed; dropping into pdb.",
                        qnode.taskDef,
                        qnode.quantum.dataId,
                        exc_info=exc,
                    )
                    try:
                        pdb = importlib.import_module(self.pdb)
                    except ImportError as imp_exc:
                        raise MPGraphExecutorError(
                            f"Unable to import specified debugger module ({self.pdb}): {imp_exc}"
                        ) from exc
                    if not hasattr(pdb, "post_mortem"):
                        raise MPGraphExecutorError(
                            f"Specified debugger module ({self.pdb}) can't debug with post_mortem",
                        ) from exc
                    pdb.post_mortem(exc.__traceback__)
                failedNodes.add(qnode)
                report.status = ExecutionStatus.FAILURE
                if self.failFast:
                    raise MPGraphExecutorError(
                        f"Task <{qnode.taskDef} dataId={qnode.quantum.dataId}> failed."
                    ) from exc
                else:
                    # Note that there could be exception safety issues, which
                    # we presently ignore.
                    _LOG.error(
                        "Task <%s dataId=%s> failed; processing will continue for remaining tasks.",
                        qnode.taskDef,
                        qnode.quantum.dataId,
                        exc_info=exc,
                    )
            finally:
                # sqlalchemy keeps some objects alive until a garbage
                # collection cycle runs, which can happen at unpredictable
                # times, so run a collection explicitly here.
                gc.collect()

            quantum_report = self.quantumExecutor.getReport()
            if quantum_report:
                report.quantaReports.append(quantum_report)

        _LOG.info(
            "Executed %d quanta successfully, %d failed and %d remain out of total %d quanta.",
            successCount,
            len(failedNodes),
            totalCount - successCount - len(failedNodes),
            totalCount,
        )

        # Raise an exception if there were any failures.
        if failedNodes:
            raise MPGraphExecutorError("One or more tasks failed during execution.")

    def _executeQuantaMP(self, graph: QuantumGraph, report: Report) -> None:
        """Execute all Quanta in separate processes.

        Parameters
        ----------
        graph : `~lsst.pipe.base.QuantumGraph`
            `~lsst.pipe.base.QuantumGraph` that is to be executed.
        report : `Report`
            Object for reporting execution status.
        """
        disable_implicit_threading()  # To prevent thread contention

        _LOG.debug("Using %r for multiprocessing start method", self.startMethod)

        # Re-pack the input quantum data into a jobs list.
        jobs = _JobList(graph)

        # Check that all tasks can run in a sub-process.
        for job in jobs.jobs:
            taskDef = job.qnode.taskDef
            if not taskDef.taskClass.canMultiprocess:
                raise MPGraphExecutorError(
                    f"Task {taskDef.taskName} does not support multiprocessing; use single process"
                )

        finishedCount, failedCount = 0, 0
        while jobs.pending or jobs.running:
            _LOG.debug("#pendingJobs: %s", len(jobs.pending))
            _LOG.debug("#runningJobs: %s", len(jobs.running))

            # See if any jobs have finished.
            for job in jobs.running:
                assert job.process is not None, "Process cannot be None"
                if not job.process.is_alive():
                    _LOG.debug("finished: %s", job)
                    # finished
                    exitcode = job.process.exitcode
                    quantum_report = job.report()
                    report.quantaReports.append(quantum_report)
                    if exitcode == 0:
                        jobs.setJobState(job, JobState.FINISHED)
                        job.cleanup()
                        _LOG.debug("success: %s took %.3f seconds", job, time.time() - job.started)
                    else:
                        if job.terminated:
                            # Was killed due to timeout.
                            if report.status == ExecutionStatus.SUCCESS:
                                # Do not override global FAILURE status.
                                report.status = ExecutionStatus.TIMEOUT
                            message = f"Timeout ({self.timeout} sec) for task {job}, task is killed"
                            jobs.setJobState(job, JobState.TIMED_OUT)
                        else:
                            report.status = ExecutionStatus.FAILURE
                            # failMessage() has to be called before cleanup().
                            message = job.failMessage()
                            jobs.setJobState(job, JobState.FAILED)

                        job.cleanup()
                        _LOG.debug("failed: %s", job)
                        if self.failFast or exitcode == InvalidQuantumError.EXIT_CODE:
                            # Stop all running jobs.
                            for stopJob in jobs.running:
                                if stopJob is not job:
                                    stopJob.stop()
                            if job.state is JobState.TIMED_OUT:
                                raise MPTimeoutError(f"Timeout ({self.timeout} sec) for task {job}.")
                            else:
                                raise MPGraphExecutorError(message)
                        else:
                            _LOG.error("%s; processing will continue for remaining tasks.", message)
                else:
                    # Check for timeout.
                    now = time.time()
                    if now - job.started > self.timeout:
                        # Try to kill it; there is a chance that it finishes
                        # successfully before it gets killed. Exit status is
                        # handled by the code above on the next iteration.
                        _LOG.debug("Terminating job %s due to timeout", job)
                        job.stop()

            # Fail jobs whose inputs failed; this may need several iterations
            # if the order is not right, and the rest will be handled on the
            # next loop iteration.
            if jobs.failedNodes:
                for job in jobs.pending:
                    jobInputNodes = graph.determineInputsToQuantumNode(job.qnode)
                    assert job.qnode.quantum.dataId is not None, "Quantum DataId cannot be None"
                    if jobInputNodes & jobs.failedNodes:
                        quantum_report = QuantumReport(
                            status=ExecutionStatus.SKIPPED,
                            dataId=job.qnode.quantum.dataId,
                            taskLabel=job.qnode.taskDef.label,
                        )
                        report.quantaReports.append(quantum_report)
                        jobs.setJobState(job, JobState.FAILED_DEP)
                        _LOG.error("Upstream job failed for task %s, skipping this task.", job)

            # See if we can start more jobs.
            if len(jobs.running) < self.numProc:
                for job in jobs.pending:
                    jobInputNodes = graph.determineInputsToQuantumNode(job.qnode)
                    if jobInputNodes <= jobs.finishedNodes:
                        # All dependencies have completed; can start new job.
                        if len(jobs.running) < self.numProc:
                            _LOG.debug("Submitting %s", job)
                            jobs.submit(job, self.quantumExecutor, self.startMethod)
                        if len(jobs.running) >= self.numProc:
                            # Cannot start any more jobs, wait until something
                            # finishes.
                            break

            # Do cleanup for timed out jobs if necessary.
            jobs.cleanup()

            # Print a progress message if something changed.
            newFinished, newFailed = len(jobs.finishedNodes), len(jobs.failedNodes)
            if (finishedCount, failedCount) != (newFinished, newFailed):
                finishedCount, failedCount = newFinished, newFailed
                totalCount = len(jobs.jobs)
                _LOG.info(
                    "Executed %d quanta successfully, %d failed and %d remain out of total %d quanta.",
                    finishedCount,
                    failedCount,
                    totalCount - finishedCount - failedCount,
                    totalCount,
                )

            # Here we want to wait until one of the running jobs completes,
            # but multiprocessing does not provide an API for that; for now
            # just sleep a little bit and go back to the loop.
            if jobs.running:
                time.sleep(0.1)

        if jobs.failedNodes:
            # Print the list of failed jobs.
            _LOG.error("Failed jobs:")
            for job in jobs.jobs:
                if job.state != JobState.FINISHED:
                    _LOG.error("  - %s: %s", job.state.name, job)

            # If any job failed raise an exception.
            if jobs.failedNodes == jobs.timedOutNodes:
                raise MPTimeoutError("One or more tasks timed out during execution.")
            else:
                raise MPGraphExecutorError("One or more tasks failed or timed out during execution.")

    def getReport(self) -> Report | None:
        # Docstring inherited from base class
        if self.report is None:
            raise RuntimeError("getReport() called before execute()")
        return self.report
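

# A sketch of consuming the execution report after a run (illustrative only;
# assumes "executor" was constructed and run as in the MPGraphExecutor
# docstring example):
#
#     report = executor.getReport()
#     if report is not None:
#         print(report.status)              # ExecutionStatus member
#         print(len(report.quantaReports))  # one entry per executed quantum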