Coverage for python/lsst/ctrl/mpexec/mpGraphExecutor.py: 14%
307 statements
coverage.py v7.3.2, created at 2023-11-17 10:53 +0000
# This file is part of ctrl_mpexec.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

__all__ = ["MPGraphExecutor", "MPGraphExecutorError", "MPTimeoutError"]

import gc
import importlib
import logging
import multiprocessing
import pickle
import signal
import sys
import threading
import time
from collections.abc import Iterable
from enum import Enum
from typing import Literal

from lsst.daf.butler.cli.cliLog import CliLog
from lsst.pipe.base import InvalidQuantumError, TaskDef
from lsst.pipe.base.graph.graph import QuantumGraph, QuantumNode
from lsst.utils.threads import disable_implicit_threading

from .executionGraphFixup import ExecutionGraphFixup
from .quantumGraphExecutor import QuantumExecutor, QuantumGraphExecutor
from .reports import ExecutionStatus, QuantumReport, Report

_LOG = logging.getLogger(__name__)
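
# A minimal usage sketch (illustrative only): construct the executor with a
# concrete QuantumExecutor implementation and run a pre-built QuantumGraph.
# The names ``single_quantum_executor`` and ``qgraph`` below are assumptions,
# not objects defined in this module:
#
#     executor = MPGraphExecutor(
#         numProc=4, timeout=3600.0, quantumExecutor=single_quantum_executor
#     )
#     executor.execute(qgraph)
#     report = executor.getReport()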


# Possible states for the executing task:
# - PENDING: job has not started yet
# - RUNNING: job is currently executing
# - FINISHED: job finished successfully
# - FAILED: job execution failed (process returned non-zero status)
# - TIMED_OUT: job was killed because it ran longer than the timeout
# - FAILED_DEP: one of the dependencies of this job has failed/timed out
JobState = Enum("JobState", "PENDING RUNNING FINISHED FAILED TIMED_OUT FAILED_DEP")
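# Typical transitions in this module: PENDING -> RUNNING -> FINISHED, FAILED,
# or TIMED_OUT; a PENDING job moves straight to FAILED_DEP when one of its
# upstream quanta fails before it ever starts.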


class _Job:
    """Class representing a job running a single task.

    Parameters
    ----------
    qnode : `~lsst.pipe.base.QuantumNode`
        Quantum and some associated information.
    """

    def __init__(self, qnode: QuantumNode):
        self.qnode = qnode
        self.process: multiprocessing.process.BaseProcess | None = None
        self._state = JobState.PENDING
        self.started: float = 0.0
        self._rcv_conn: multiprocessing.connection.Connection | None = None
        self._terminated = False

    @property
    def state(self) -> JobState:
        """Job processing state (`JobState`)."""
        return self._state

    @property
    def terminated(self) -> bool:
        """Return `True` if the job was killed by the stop() method and a
        negative exit code was returned from the child process (`bool`).
        """
        if self._terminated:
            assert self.process is not None, "Process must be started"
            if self.process.exitcode is not None:
                return self.process.exitcode < 0
        return False

    def start(
        self,
        quantumExecutor: QuantumExecutor,
        startMethod: Literal["spawn", "fork", "forkserver"] | None = None,
    ) -> None:
        """Start process which runs the task.

        Parameters
        ----------
        quantumExecutor : `QuantumExecutor`
            Executor for single quantum.
        startMethod : `str`, optional
            Start method from `multiprocessing` module.
        """
        # Unpickling of the quantum has to happen after the butler/executor
        # are set up in the child process, which is why it is pickled
        # manually here.
        quantum_pickle = pickle.dumps(self.qnode.quantum)
        taskDef = self.qnode.taskDef
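        # multiprocessing.Pipe(False) creates a one-way pipe: the first
        # connection returned is the receive end (kept in the parent), the
        # second is the send end handed to the child process.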
        self._rcv_conn, snd_conn = multiprocessing.Pipe(False)
        logConfigState = CliLog.configState

        mp_ctx = multiprocessing.get_context(startMethod)
        self.process = mp_ctx.Process(
            target=_Job._executeJob,
            args=(quantumExecutor, taskDef, quantum_pickle, logConfigState, snd_conn),
            name=f"task-{self.qnode.quantum.dataId}",
        )
        self.process.start()
        self.started = time.time()
        self._state = JobState.RUNNING

    @staticmethod
    def _executeJob(
        quantumExecutor: QuantumExecutor,
        taskDef: TaskDef,
        quantum_pickle: bytes,
        logConfigState: list,
        snd_conn: multiprocessing.connection.Connection,
    ) -> None:
        """Execute a job with arguments.

        Parameters
        ----------
        quantumExecutor : `QuantumExecutor`
            Executor for single quantum.
        taskDef : `TaskDef`
            Task definition structure.
        quantum_pickle : `bytes`
            Quantum for this task execution in pickled form.
        logConfigState : `list`
            Logging configuration state to replay in the new process.
        snd_conn : `multiprocessing.connection.Connection`
            Connection to send job report to parent process.
        """
        # This terrible hack is a workaround for Python threading bug:
        # https://github.com/python/cpython/issues/102512. Should be removed
        # when fix for that bug is deployed. Inspired by
        # https://github.com/QubesOS/qubes-core-admin-client/pull/236/files.
        thread = threading.current_thread()
        if isinstance(thread, threading._DummyThread):
            if getattr(thread, "_tstate_lock", "") is None:
                thread._set_tstate_lock()  # type: ignore[attr-defined]

        if logConfigState and not CliLog.configState:
            # This means we are in a newly spawned Python process and have to
            # re-initialize logging.
            CliLog.replayConfigState(logConfigState)

        quantum = pickle.loads(quantum_pickle)
        try:
            quantumExecutor.execute(taskDef, quantum)
        finally:
            # If sending fails we do not want this new exception to be
            # exposed.
            try:
                report = quantumExecutor.getReport()
                snd_conn.send(report)
            except Exception:
                pass

    def stop(self) -> None:
        """Stop the process."""
        assert self.process is not None, "Process must be started"
        self.process.terminate()
        # Give it one second to finish gracefully, otherwise kill it.
        for _ in range(10):
            time.sleep(0.1)
            if not self.process.is_alive():
                break
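        # This for/else "else" runs only when the loop completed without
        # "break", i.e. the process is still alive after about one second.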
        else:
            _LOG.debug("Killing process %s", self.process.name)
            self.process.kill()
        self._terminated = True

    def cleanup(self) -> None:
        """Release process resources; has to be called for each finished
        process.
        """
        if self.process and not self.process.is_alive():
            self.process.close()
            self.process = None
            self._rcv_conn = None

    def report(self) -> QuantumReport:
        """Return the task report; should be called after the process
        finishes and before cleanup().
        """
        assert self.process is not None, "Process must be started"
        assert self._rcv_conn is not None, "Process must be started"
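        # Connection.recv() raises EOFError when the child exited without
        # sending anything (e.g. it was killed); the except clause below then
        # builds a report from the exit code alone.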
        try:
            report = self._rcv_conn.recv()
            report.exitCode = self.process.exitcode
        except Exception:
            # Likely due to the process being killed, but there may be other
            # reasons. Exit code should not be None; this is to keep mypy
            # happy.
            exitcode = self.process.exitcode if self.process.exitcode is not None else -1
            assert self.qnode.quantum.dataId is not None, "Quantum DataId cannot be None"
            report = QuantumReport.from_exit_code(
                exitCode=exitcode,
                dataId=self.qnode.quantum.dataId,
                taskLabel=self.qnode.taskDef.label,
            )
        if self.terminated:
            # Means it was killed; assume it is due to timeout.
            report.status = ExecutionStatus.TIMEOUT
        return report

    def failMessage(self) -> str:
        """Return a message describing the task failure."""
        assert self.process is not None, "Process must be started"
        assert self.process.exitcode is not None, "Process has to finish"
        exitcode = self.process.exitcode
        if exitcode < 0:
            # A negative exit code means the process was killed by a signal.
            signum = -exitcode
            msg = f"Task {self} failed, killed by signal {signum}"
            # Just in case this is some very odd signal, expect ValueError.
            try:
                strsignal = signal.strsignal(signum)
                msg = f"{msg} ({strsignal})"
            except ValueError:
                pass
        elif exitcode > 0:
            msg = f"Task {self} failed, exit code={exitcode}"
        else:
            msg = ""
        return msg

    def __str__(self) -> str:
        return f"<{self.qnode.taskDef} dataId={self.qnode.quantum.dataId}>"


class _JobList:
    """Simple list of _Job instances with a few convenience methods.

    Parameters
    ----------
    iterable : iterable of `~lsst.pipe.base.QuantumNode`
        Sequence of Quanta to execute. This has to be ordered according to
        task dependencies.
    """

    def __init__(self, iterable: Iterable[QuantumNode]):
        self.jobs = [_Job(qnode) for qnode in iterable]
        self.pending = self.jobs[:]
        self.running: list[_Job] = []
        self.finishedNodes: set[QuantumNode] = set()
        self.failedNodes: set[QuantumNode] = set()
        self.timedOutNodes: set[QuantumNode] = set()

    def submit(
        self,
        job: _Job,
        quantumExecutor: QuantumExecutor,
        startMethod: Literal["spawn", "fork", "forkserver"] | None = None,
    ) -> None:
        """Submit one more job for execution.

        Parameters
        ----------
        job : `_Job`
            Job to submit.
        quantumExecutor : `QuantumExecutor`
            Executor for single quantum.
        startMethod : `str`, optional
            Start method from `multiprocessing` module.
        """
        # This will raise if the job is not in the pending list.
        self.pending.remove(job)
        job.start(quantumExecutor, startMethod)
        self.running.append(job)

    def setJobState(self, job: _Job, state: JobState) -> None:
        """Update job state.

        Parameters
        ----------
        job : `_Job`
            Job whose state is to be updated.
        state : `JobState`
            New job state; only FINISHED, FAILED, TIMED_OUT, or FAILED_DEP
            is acceptable.
        """
        allowedStates = (JobState.FINISHED, JobState.FAILED, JobState.TIMED_OUT, JobState.FAILED_DEP)
        assert state in allowedStates, f"State {state} not allowed here"

        # Remove the job from the pending/running lists.
        if job.state == JobState.PENDING:
            self.pending.remove(job)
        elif job.state == JobState.RUNNING:
            self.running.remove(job)

        qnode = job.qnode
        # It should not be in any of these, but just in case.
        self.finishedNodes.discard(qnode)
        self.failedNodes.discard(qnode)
        self.timedOutNodes.discard(qnode)

        job._state = state
        if state == JobState.FINISHED:
            self.finishedNodes.add(qnode)
        elif state == JobState.FAILED:
            self.failedNodes.add(qnode)
        elif state == JobState.FAILED_DEP:
            self.failedNodes.add(qnode)
        elif state == JobState.TIMED_OUT:
            self.failedNodes.add(qnode)
            self.timedOutNodes.add(qnode)
        else:
            raise ValueError(f"Unexpected state value: {state}")

    def cleanup(self) -> None:
        """Do periodic cleanup for jobs that did not finish correctly.

        If timed-out jobs are killed but take too long to stop then regular
        cleanup will not work for them. Here we check all timed-out jobs
        periodically and do cleanup if they managed to die by this time.
        """
        for job in self.jobs:
            if job.state == JobState.TIMED_OUT and job.process is not None:
                job.cleanup()


class MPGraphExecutorError(Exception):
    """Exception class for errors raised by MPGraphExecutor."""

    pass


class MPTimeoutError(MPGraphExecutorError):
    """Exception raised when task execution times out."""

    pass


class MPGraphExecutor(QuantumGraphExecutor):
    """Implementation of QuantumGraphExecutor using same-host multiprocess
    execution of Quanta.

    Parameters
    ----------
    numProc : `int`
        Number of processes to use for executing tasks.
    timeout : `float`
        Time in seconds to wait for tasks to finish.
    quantumExecutor : `QuantumExecutor`
        Executor for single quantum. For multiprocess-style execution when
        ``numProc`` is greater than one this instance must support pickle.
    startMethod : `str`, optional
        Start method from `multiprocessing` module; `None` selects the best
        one for the current platform.
    failFast : `bool`, optional
        If set to ``True`` then stop processing on first error from any task.
    pdb : `str`, optional
        Debugger to import and use (via the ``post_mortem`` function) in the
        event of an exception.
    executionGraphFixup : `ExecutionGraphFixup`, optional
        Instance used for modification of execution graph.
    """

    def __init__(
        self,
        numProc: int,
        timeout: float,
        quantumExecutor: QuantumExecutor,
        *,
        startMethod: Literal["spawn", "fork", "forkserver"] | None = None,
        failFast: bool = False,
        pdb: str | None = None,
        executionGraphFixup: ExecutionGraphFixup | None = None,
    ):
        self.numProc = numProc
        self.timeout = timeout
        self.quantumExecutor = quantumExecutor
        self.failFast = failFast
        self.pdb = pdb
        self.executionGraphFixup = executionGraphFixup
        self.report: Report | None = None

        # The default start method is spawn for macOS and fork for Linux;
        # None for all other platforms, which lets multiprocessing pick its
        # own default.
        if startMethod is None:
            methods = dict(linux="fork", darwin="spawn")
            startMethod = methods.get(sys.platform)  # type: ignore
        self.startMethod = startMethod

    def execute(self, graph: QuantumGraph) -> None:
        # Docstring inherited from QuantumGraphExecutor.execute
        graph = self._fixupQuanta(graph)
        self.report = Report()
        try:
            if self.numProc > 1:
                self._executeQuantaMP(graph, self.report)
            else:
                self._executeQuantaInProcess(graph, self.report)
        except Exception as exc:
            self.report.set_exception(exc)
            raise

    def _fixupQuanta(self, graph: QuantumGraph) -> QuantumGraph:
        """Call fixup code to modify execution graph.

        Parameters
        ----------
        graph : `~lsst.pipe.base.QuantumGraph`
            `~lsst.pipe.base.QuantumGraph` to modify.

        Returns
        -------
        graph : `~lsst.pipe.base.QuantumGraph`
            Modified `~lsst.pipe.base.QuantumGraph`.

        Raises
        ------
        MPGraphExecutorError
            Raised if execution graph cannot be ordered after modification,
            i.e. it has dependency cycles.
        """
        if not self.executionGraphFixup:
            return graph

        _LOG.debug("Call execution graph fixup method")
        graph = self.executionGraphFixup.fixupQuanta(graph)

        # Detect if there is now a cycle created within the graph.
        if graph.findCycle():
            raise MPGraphExecutorError("Updated execution graph has dependency cycle.")

        return graph

    def _executeQuantaInProcess(self, graph: QuantumGraph, report: Report) -> None:
        """Execute all Quanta in current process.

        Parameters
        ----------
        graph : `~lsst.pipe.base.QuantumGraph`
            `~lsst.pipe.base.QuantumGraph` that is to be executed.
        report : `Report`
            Object for reporting execution status.
        """
        successCount, totalCount = 0, len(graph)
        failedNodes: set[QuantumNode] = set()
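        # Graph iteration is assumed to yield quanta in dependency order
        # (the same requirement _JobList documents), so each quantum's inputs
        # have already been attempted by the time it is reached.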
        for qnode in graph:
            assert qnode.quantum.dataId is not None, "Quantum DataId cannot be None"

            # Any failed inputs mean that the quantum has to be skipped.
            inputNodes = graph.determineInputsToQuantumNode(qnode)
            if inputNodes & failedNodes:
                _LOG.error(
                    "Upstream job failed for task <%s dataId=%s>, skipping this task.",
                    qnode.taskDef,
                    qnode.quantum.dataId,
                )
                failedNodes.add(qnode)
                failed_quantum_report = QuantumReport(
                    status=ExecutionStatus.SKIPPED, dataId=qnode.quantum.dataId, taskLabel=qnode.taskDef.label
                )
                report.quantaReports.append(failed_quantum_report)
                continue

            _LOG.debug("Executing %s", qnode)
            try:
                self.quantumExecutor.execute(qnode.taskDef, qnode.quantum)
                successCount += 1
            except Exception as exc:
                if self.pdb and sys.stdin.isatty() and sys.stdout.isatty():
                    _LOG.error(
                        "Task <%s dataId=%s> failed; dropping into pdb.",
                        qnode.taskDef,
                        qnode.quantum.dataId,
                        exc_info=exc,
                    )
                    try:
                        pdb = importlib.import_module(self.pdb)
                    except ImportError as imp_exc:
                        raise MPGraphExecutorError(
                            f"Unable to import specified debugger module ({self.pdb}): {imp_exc}"
                        ) from exc
                    if not hasattr(pdb, "post_mortem"):
                        raise MPGraphExecutorError(
                            f"Specified debugger module ({self.pdb}) can't debug with post_mortem",
                        ) from exc
                    pdb.post_mortem(exc.__traceback__)
                failedNodes.add(qnode)
                report.status = ExecutionStatus.FAILURE
                if self.failFast:
                    raise MPGraphExecutorError(
                        f"Task <{qnode.taskDef} dataId={qnode.quantum.dataId}> failed."
                    ) from exc
                else:
                    # Note that there could be exception safety issues, which
                    # we presently ignore.
                    _LOG.error(
                        "Task <%s dataId=%s> failed; processing will continue for remaining tasks.",
                        qnode.taskDef,
                        qnode.quantum.dataId,
                        exc_info=exc,
                    )
            finally:
                # sqlalchemy has some objects that can last until a garbage
                # collection cycle is run, which can happen at unpredictable
                # times; run a collection loop here explicitly.
                gc.collect()

            quantum_report = self.quantumExecutor.getReport()
            if quantum_report:
                report.quantaReports.append(quantum_report)

        _LOG.info(
            "Executed %d quanta successfully, %d failed and %d remain out of total %d quanta.",
            successCount,
            len(failedNodes),
            totalCount - successCount - len(failedNodes),
            totalCount,
        )

        # Raise an exception if there were any failures.
        if failedNodes:
            raise MPGraphExecutorError("One or more tasks failed during execution.")

    def _executeQuantaMP(self, graph: QuantumGraph, report: Report) -> None:
        """Execute all Quanta in separate processes.

        Parameters
        ----------
        graph : `~lsst.pipe.base.QuantumGraph`
            `~lsst.pipe.base.QuantumGraph` that is to be executed.
        report : `Report`
            Object for reporting execution status.
        """
        disable_implicit_threading()  # To prevent thread contention

        _LOG.debug("Using %r for multiprocessing start method", self.startMethod)

        # Re-pack input quantum data into the job list.
        jobs = _JobList(graph)

        # Check that all tasks can run in a sub-process.
        for job in jobs.jobs:
            taskDef = job.qnode.taskDef
            if not taskDef.taskClass.canMultiprocess:
                raise MPGraphExecutorError(
                    f"Task {taskDef.taskName} does not support multiprocessing; use single process"
                )

        finishedCount, failedCount = 0, 0
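
        # Main scheduling loop: each pass reaps finished processes, fails
        # pending jobs whose inputs failed, starts new jobs up to the process
        # limit, and then sleeps briefly before polling again.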
        while jobs.pending or jobs.running:
            _LOG.debug("#pendingJobs: %s", len(jobs.pending))
            _LOG.debug("#runningJobs: %s", len(jobs.running))

            # See if any jobs have finished. Iterate over a copy, since
            # setJobState() removes entries from jobs.running.
            for job in jobs.running[:]:
                assert job.process is not None, "Process cannot be None"
                if not job.process.is_alive():
                    _LOG.debug("finished: %s", job)
                    # Finished; grab exit code and report.
                    exitcode = job.process.exitcode
                    quantum_report = job.report()
                    report.quantaReports.append(quantum_report)
                    if exitcode == 0:
                        jobs.setJobState(job, JobState.FINISHED)
                        job.cleanup()
                        _LOG.debug("success: %s took %.3f seconds", job, time.time() - job.started)
                    else:
                        if job.terminated:
                            # It was killed due to timeout.
                            if report.status == ExecutionStatus.SUCCESS:
                                # Do not override global FAILURE status.
                                report.status = ExecutionStatus.TIMEOUT
                            message = f"Timeout ({self.timeout} sec) for task {job}, task was killed"
                            jobs.setJobState(job, JobState.TIMED_OUT)
                        else:
                            report.status = ExecutionStatus.FAILURE
                            # failMessage() has to be called before cleanup().
                            message = job.failMessage()
                            jobs.setJobState(job, JobState.FAILED)

                        job.cleanup()
                        _LOG.debug("failed: %s", job)
                        if self.failFast or exitcode == InvalidQuantumError.EXIT_CODE:
                            # Stop all other running jobs.
                            for stopJob in jobs.running:
                                if stopJob is not job:
                                    stopJob.stop()
                            if job.state is JobState.TIMED_OUT:
                                raise MPTimeoutError(f"Timeout ({self.timeout} sec) for task {job}.")
                            else:
                                raise MPGraphExecutorError(message)
                        else:
                            _LOG.error("%s; processing will continue for remaining tasks.", message)
                else:
                    # Check for timeout.
                    now = time.time()
                    if now - job.started > self.timeout:
                        # Try to kill it; there is a chance that it finishes
                        # successfully before it gets killed. Exit status is
                        # handled by the code above on the next iteration.
                        _LOG.debug("Terminating job %s due to timeout", job)
                        job.stop()

            # Fail jobs whose inputs failed; this may need several iterations
            # if the order is not right, and is finished on later passes
            # through the loop. Iterate over a copy, since setJobState()
            # removes entries from jobs.pending.
            if jobs.failedNodes:
                for job in jobs.pending[:]:
                    jobInputNodes = graph.determineInputsToQuantumNode(job.qnode)
                    assert job.qnode.quantum.dataId is not None, "Quantum DataId cannot be None"
                    if jobInputNodes & jobs.failedNodes:
                        quantum_report = QuantumReport(
                            status=ExecutionStatus.SKIPPED,
                            dataId=job.qnode.quantum.dataId,
                            taskLabel=job.qnode.taskDef.label,
                        )
                        report.quantaReports.append(quantum_report)
                        jobs.setJobState(job, JobState.FAILED_DEP)
                        _LOG.error("Upstream job failed for task %s, skipping this task.", job)

            # See if we can start more jobs. Iterate over a copy, since
            # submit() removes entries from jobs.pending.
            if len(jobs.running) < self.numProc:
                for job in jobs.pending[:]:
                    jobInputNodes = graph.determineInputsToQuantumNode(job.qnode)
                    if jobInputNodes <= jobs.finishedNodes:
                        # All dependencies have completed; we can start a new
                        # job.
                        if len(jobs.running) < self.numProc:
                            _LOG.debug("Submitting %s", job)
                            jobs.submit(job, self.quantumExecutor, self.startMethod)
                        if len(jobs.running) >= self.numProc:
                            # Cannot start any more jobs; wait until something
                            # finishes.
                            break

            # Do cleanup for timed-out jobs if necessary.
            jobs.cleanup()

            # Print a progress message if something changed.
            newFinished, newFailed = len(jobs.finishedNodes), len(jobs.failedNodes)
            if (finishedCount, failedCount) != (newFinished, newFailed):
                finishedCount, failedCount = newFinished, newFailed
                totalCount = len(jobs.jobs)
                _LOG.info(
                    "Executed %d quanta successfully, %d failed and %d remain out of total %d quanta.",
                    finishedCount,
                    failedCount,
                    totalCount - finishedCount - failedCount,
                    totalCount,
                )

            # Here we want to wait until one of the running jobs completes,
            # but multiprocessing does not provide an API for that; for now
            # just sleep a little bit and go back to the loop.
            if jobs.running:
                time.sleep(0.1)

        if jobs.failedNodes:
            # Print the list of failed jobs.
            _LOG.error("Failed jobs:")
            for job in jobs.jobs:
                if job.state != JobState.FINISHED:
                    _LOG.error(" - %s: %s", job.state.name, job)

            # If any job failed, raise an exception.
            if jobs.failedNodes == jobs.timedOutNodes:
                raise MPTimeoutError("One or more tasks timed out during execution.")
            else:
                raise MPGraphExecutorError("One or more tasks failed or timed out during execution.")

    def getReport(self) -> Report | None:
        # Docstring inherited from base class.
        if self.report is None:
            raise RuntimeError("getReport() called before execute()")
        return self.report