Coverage for python/lsst/ctrl/mpexec/mpGraphExecutor.py: 14% (309 statements)

# This file is part of ctrl_mpexec.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

__all__ = ["MPGraphExecutor", "MPGraphExecutorError", "MPTimeoutError"]

import gc
import importlib
import logging
import multiprocessing
import pickle
import signal
import sys
import threading
import time
from collections.abc import Iterable
from enum import Enum
from typing import Literal

from lsst.daf.butler.cli.cliLog import CliLog
from lsst.pipe.base import InvalidQuantumError
from lsst.pipe.base.graph.graph import QuantumGraph, QuantumNode
from lsst.pipe.base.pipeline_graph import TaskNode
from lsst.utils.threads import disable_implicit_threading

from .executionGraphFixup import ExecutionGraphFixup
from .quantumGraphExecutor import QuantumExecutor, QuantumGraphExecutor
from .reports import ExecutionStatus, QuantumReport, Report

_LOG = logging.getLogger(__name__)


# Possible states for the executing task:
# - PENDING: job has not started yet
# - RUNNING: job is currently executing
# - FINISHED: job finished successfully
# - FAILED: job execution failed (process returned a non-zero status)
# - TIMED_OUT: job was killed because it ran for too long
# - FAILED_DEP: one of the dependencies of this job has failed/timed out
JobState = Enum("JobState", "PENDING RUNNING FINISHED FAILED TIMED_OUT FAILED_DEP")
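# (The functional Enum API above is equivalent to declaring a class with
# members PENDING = 1 through FAILED_DEP = 6.)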


class _Job:
    """Class representing a job running a single task.

    Parameters
    ----------
    qnode : `~lsst.pipe.base.QuantumNode`
        Quantum and some associated information.
    """

    def __init__(self, qnode: QuantumNode):
        self.qnode = qnode
        self.process: multiprocessing.process.BaseProcess | None = None
        self._state = JobState.PENDING
        self.started: float = 0.0
        self._rcv_conn: multiprocessing.connection.Connection | None = None
        self._terminated = False

    @property
    def state(self) -> JobState:
        """Job processing state (JobState)."""
        return self._state

    @property
    def terminated(self) -> bool:
        """Return `True` if the job was killed by the stop() method and the
        child process returned a negative exit code (`bool`).
        """
        if self._terminated:
            assert self.process is not None, "Process must be started"
            if self.process.exitcode is not None:
                return self.process.exitcode < 0
        return False

    def start(
        self,
        quantumExecutor: QuantumExecutor,
        startMethod: Literal["spawn"] | Literal["forkserver"],
    ) -> None:
        """Start a process which runs the task.

        Parameters
        ----------
        quantumExecutor : `QuantumExecutor`
            Executor for a single quantum.
        startMethod : `str`
            Start method from the `multiprocessing` module.
        """
        # Unpickling of the quantum has to happen after the butler/executor,
        # which is why it is pickled manually here.
        quantum_pickle = pickle.dumps(self.qnode.quantum)
        task_node = self.qnode.task_node
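        # One-way pipe (duplex=False): the child can only send and the parent
        # can only receive.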
        self._rcv_conn, snd_conn = multiprocessing.Pipe(False)
        logConfigState = CliLog.configState

        mp_ctx = multiprocessing.get_context(startMethod)
        self.process = mp_ctx.Process(  # type: ignore[attr-defined]
            target=_Job._executeJob,
            args=(quantumExecutor, task_node, quantum_pickle, logConfigState, snd_conn),
            name=f"task-{self.qnode.quantum.dataId}",
        )
        # mypy is getting confused by multiprocessing.
        assert self.process is not None
        self.process.start()
        self.started = time.time()
        self._state = JobState.RUNNING

    @staticmethod
    def _executeJob(
        quantumExecutor: QuantumExecutor,
        task_node: TaskNode,
        quantum_pickle: bytes,
        logConfigState: list,
        snd_conn: multiprocessing.connection.Connection,
    ) -> None:
        """Execute a job with arguments.

        Parameters
        ----------
        quantumExecutor : `QuantumExecutor`
            Executor for a single quantum.
        task_node : `lsst.pipe.base.pipeline_graph.TaskNode`
            Task definition structure.
        quantum_pickle : `bytes`
            Quantum for this task execution in pickled form.
        logConfigState : `list`
            Logging configuration state of the parent process; replayed here
            when running in a newly spawned process.
        snd_conn : `multiprocessing.Connection`
            Connection to send job report to parent process.
        """
        # This terrible hack is a workaround for Python threading bug:
        # https://github.com/python/cpython/issues/102512. Should be removed
        # when fix for that bug is deployed. Inspired by
        # https://github.com/QubesOS/qubes-core-admin-client/pull/236/files.
        thread = threading.current_thread()
        if isinstance(thread, threading._DummyThread):
            if getattr(thread, "_tstate_lock", "") is None:
                thread._set_tstate_lock()  # type: ignore[attr-defined]

        if logConfigState and not CliLog.configState:
            # Means we are in a newly spawned Python process and have to
            # re-initialize logging.
            CliLog.replayConfigState(logConfigState)

        quantum = pickle.loads(quantum_pickle)
        try:
            quantumExecutor.execute(task_node, quantum)
        finally:
            # If sending fails we do not want this new exception to be exposed.
            try:
                report = quantumExecutor.getReport()
                snd_conn.send(report)
            except Exception:
                pass

    def stop(self) -> None:
        """Stop the process."""
        assert self.process is not None, "Process must be started"
        self.process.terminate()
        # Give it one second to finish, then KILL.
        for _ in range(10):
            time.sleep(0.1)
            if not self.process.is_alive():
                break
        else:
            _LOG.debug("Killing process %s", self.process.name)
            self.process.kill()
        self._terminated = True

    def cleanup(self) -> None:
        """Release process resources; has to be called for each finished
        process.
        """
        if self.process and not self.process.is_alive():
            self.process.close()
            self.process = None
            self._rcv_conn = None

    def report(self) -> QuantumReport:
        """Return the task report; must be called after the process finishes
        and before cleanup().
        """
        assert self.process is not None, "Process must be started"
        assert self._rcv_conn is not None, "Process must be started"
        try:
            report = self._rcv_conn.recv()
            report.exitCode = self.process.exitcode
        except Exception:
            # Likely due to the process being killed, but there may be other
            # reasons. Exit code should not be None; this is to keep mypy
            # happy.
            exitcode = self.process.exitcode if self.process.exitcode is not None else -1
            assert self.qnode.quantum.dataId is not None, "Quantum DataId cannot be None"
            report = QuantumReport.from_exit_code(
                exitCode=exitcode,
                dataId=self.qnode.quantum.dataId,
                taskLabel=self.qnode.task_node.label,
            )
        if self.terminated:
            # Means it was killed; assume it is due to a timeout.
            report.status = ExecutionStatus.TIMEOUT
        return report

    def failMessage(self) -> str:
        """Return a message describing the task failure."""
        assert self.process is not None, "Process must be started"
        assert self.process.exitcode is not None, "Process has to finish"
        exitcode = self.process.exitcode
        if exitcode < 0:
            # A negative exit code means the process was killed by a signal.
            signum = -exitcode
            msg = f"Task {self} failed, killed by signal {signum}"
            # Just in case this is some very odd signal, expect ValueError.
            try:
                strsignal = signal.strsignal(signum)
                msg = f"{msg} ({strsignal})"
            except ValueError:
                pass
        elif exitcode > 0:
            msg = f"Task {self} failed, exit code={exitcode}"
        else:
            msg = ""
        return msg

    def __str__(self) -> str:
        return f"<{self.qnode.task_node.label} dataId={self.qnode.quantum.dataId}>"
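
# A minimal lifecycle sketch for _Job (hypothetical driver code; the real
# driver is MPGraphExecutor._executeQuantaMP below):
#
#     job = _Job(qnode)
#     job.start(quantum_executor, "spawn")
#     while job.process.is_alive():
#         time.sleep(0.1)
#     quantum_report = job.report()  # must be called before cleanup()
#     job.cleanup()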


class _JobList:
    """Simple list of _Job instances with a few convenience methods.

    Parameters
    ----------
    iterable : iterable of `~lsst.pipe.base.QuantumNode`
        Sequence of Quanta to execute. This has to be ordered according to
        task dependencies.
    """

    def __init__(self, iterable: Iterable[QuantumNode]):
        self.jobs = [_Job(qnode) for qnode in iterable]
        self.pending = self.jobs[:]
        self.running: list[_Job] = []
        self.finishedNodes: set[QuantumNode] = set()
        self.failedNodes: set[QuantumNode] = set()
        self.timedOutNodes: set[QuantumNode] = set()

    def submit(
        self,
        job: _Job,
        quantumExecutor: QuantumExecutor,
        startMethod: Literal["spawn"] | Literal["forkserver"],
    ) -> None:
        """Submit one more job for execution.

        Parameters
        ----------
        job : `_Job`
            Job to submit.
        quantumExecutor : `QuantumExecutor`
            Executor for a single quantum.
        startMethod : `str`
            Start method from the `multiprocessing` module.
        """
        # This will raise if the job is not in the pending list.
        self.pending.remove(job)
        job.start(quantumExecutor, startMethod)
        self.running.append(job)

    def setJobState(self, job: _Job, state: JobState) -> None:
        """Update job state.

        Parameters
        ----------
        job : `_Job`
            Job to update.
        state : `JobState`
            New job state; note that only FINISHED, FAILED, TIMED_OUT, or
            FAILED_DEP states are acceptable.
        """
        allowedStates = (JobState.FINISHED, JobState.FAILED, JobState.TIMED_OUT, JobState.FAILED_DEP)
        assert state in allowedStates, f"State {state} not allowed here"

        # Remove the job from the pending/running lists.
        if job.state == JobState.PENDING:
            self.pending.remove(job)
        elif job.state == JobState.RUNNING:
            self.running.remove(job)

        qnode = job.qnode
        # It should not be in any of these, but just in case.
        self.finishedNodes.discard(qnode)
        self.failedNodes.discard(qnode)
        self.timedOutNodes.discard(qnode)

        job._state = state
        if state == JobState.FINISHED:
            self.finishedNodes.add(qnode)
        elif state == JobState.FAILED:
            self.failedNodes.add(qnode)
        elif state == JobState.FAILED_DEP:
            self.failedNodes.add(qnode)
        elif state == JobState.TIMED_OUT:
            self.failedNodes.add(qnode)
            self.timedOutNodes.add(qnode)
        else:
            raise ValueError(f"Unexpected state value: {state}")

    def cleanup(self) -> None:
        """Do periodic cleanup for jobs that did not finish correctly.

        If timed-out jobs are killed but take too long to stop, the regular
        cleanup does not work for them. Here we check all timed-out jobs
        periodically and clean up those that have managed to die by this
        time.
        """
        for job in self.jobs:
            if job.state == JobState.TIMED_OUT and job.process is not None:
                job.cleanup()


class MPGraphExecutorError(Exception):
    """Exception class for errors raised by MPGraphExecutor."""

    pass


class MPTimeoutError(MPGraphExecutorError):
    """Exception raised when task execution times out."""

    pass


class MPGraphExecutor(QuantumGraphExecutor):
    """Implementation of QuantumGraphExecutor using same-host multiprocess
    execution of Quanta.

    Parameters
    ----------
    numProc : `int`
        Number of processes to use for executing tasks.
    timeout : `float`
        Time in seconds to wait for tasks to finish.
    quantumExecutor : `QuantumExecutor`
        Executor for a single quantum. For multiprocess-style execution when
        ``numProc`` is greater than one, this instance must support pickle.
    startMethod : `str`, optional
        Start method from the `multiprocessing` module; `None` selects the
        best one for the current platform.
    failFast : `bool`, optional
        If set to ``True`` then stop processing on the first error from any
        task.
    pdb : `str`, optional
        Debugger to import and use (via the ``post_mortem`` function) in the
        event of an exception.
    executionGraphFixup : `ExecutionGraphFixup`, optional
        Instance used for modification of the execution graph.
    """
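
    # A minimal usage sketch (hypothetical ``quantum_executor`` and
    # ``quantum_graph`` objects; the real types are `QuantumExecutor` and
    # `~lsst.pipe.base.QuantumGraph`):
    #
    #     executor = MPGraphExecutor(numProc=4, timeout=3600.0,
    #                                quantumExecutor=quantum_executor)
    #     executor.execute(quantum_graph)
    #     report = executor.getReport()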

    def __init__(
        self,
        numProc: int,
        timeout: float,
        quantumExecutor: QuantumExecutor,
        *,
        startMethod: Literal["spawn"] | Literal["forkserver"] | None = None,
        failFast: bool = False,
        pdb: str | None = None,
        executionGraphFixup: ExecutionGraphFixup | None = None,
    ):
        self.numProc = numProc
        self.timeout = timeout
        self.quantumExecutor = quantumExecutor
        self.failFast = failFast
        self.pdb = pdb
        self.executionGraphFixup = executionGraphFixup
        self.report: Report | None = None

        # We set the default start method to spawn for all platforms.
        if startMethod is None:
            startMethod = "spawn"
        self.startMethod = startMethod

    def execute(self, graph: QuantumGraph) -> None:
        # Docstring inherited from QuantumGraphExecutor.execute
        graph = self._fixupQuanta(graph)
        self.report = Report(qgraphSummary=graph.getSummary())
        try:
            if self.numProc > 1:
                self._executeQuantaMP(graph, self.report)
            else:
                self._executeQuantaInProcess(graph, self.report)
        except Exception as exc:
            self.report.set_exception(exc)
            raise

    def _fixupQuanta(self, graph: QuantumGraph) -> QuantumGraph:
        """Call fixup code to modify the execution graph.

        Parameters
        ----------
        graph : `~lsst.pipe.base.QuantumGraph`
            `~lsst.pipe.base.QuantumGraph` to modify.

        Returns
        -------
        graph : `~lsst.pipe.base.QuantumGraph`
            Modified `~lsst.pipe.base.QuantumGraph`.

        Raises
        ------
        MPGraphExecutorError
            Raised if the execution graph cannot be ordered after
            modification, i.e. it has dependency cycles.
        """
        if not self.executionGraphFixup:
            return graph

        _LOG.debug("Call execution graph fixup method")
        graph = self.executionGraphFixup.fixupQuanta(graph)

        # Detect if there is now a cycle created within the graph.
        if graph.findCycle():
            raise MPGraphExecutorError("Updated execution graph has dependency cycle.")

        return graph
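
    # A minimal fixup sketch (hypothetical subclass; the only method used
    # above is ``fixupQuanta``, which must return the possibly modified
    # graph):
    #
    #     class NoOpFixup(ExecutionGraphFixup):
    #         def fixupQuanta(self, graph: QuantumGraph) -> QuantumGraph:
    #             return graph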

    def _executeQuantaInProcess(self, graph: QuantumGraph, report: Report) -> None:
        """Execute all Quanta in the current process.

        Parameters
        ----------
        graph : `~lsst.pipe.base.QuantumGraph`
            `~lsst.pipe.base.QuantumGraph` that is to be executed.
        report : `Report`
            Object for reporting execution status.
        """
        successCount, totalCount = 0, len(graph)
        failedNodes: set[QuantumNode] = set()
        for qnode in graph:
            assert qnode.quantum.dataId is not None, "Quantum DataId cannot be None"
            task_node = qnode.task_node

            # Any failed inputs mean that the quantum has to be skipped.
            inputNodes = graph.determineInputsToQuantumNode(qnode)
            if inputNodes & failedNodes:
                _LOG.error(
                    "Upstream job failed for task <%s dataId=%s>, skipping this task.",
                    task_node.label,
                    qnode.quantum.dataId,
                )
                failedNodes.add(qnode)
                failed_quantum_report = QuantumReport(
                    status=ExecutionStatus.SKIPPED,
                    dataId=qnode.quantum.dataId,
                    taskLabel=task_node.label,
                )
                report.quantaReports.append(failed_quantum_report)
                continue

            _LOG.debug("Executing %s", qnode)
            try:
                self.quantumExecutor.execute(task_node, qnode.quantum)
                successCount += 1
            except Exception as exc:
                if self.pdb and sys.stdin.isatty() and sys.stdout.isatty():
                    _LOG.error(
                        "Task <%s dataId=%s> failed; dropping into pdb.",
                        task_node.label,
                        qnode.quantum.dataId,
                        exc_info=exc,
                    )
                    try:
                        pdb = importlib.import_module(self.pdb)
                    except ImportError as imp_exc:
                        raise MPGraphExecutorError(
                            f"Unable to import specified debugger module ({self.pdb}): {imp_exc}"
                        ) from exc
                    if not hasattr(pdb, "post_mortem"):
                        raise MPGraphExecutorError(
                            f"Specified debugger module ({self.pdb}) can't debug with post_mortem",
                        ) from exc
                    pdb.post_mortem(exc.__traceback__)
                failedNodes.add(qnode)
                report.status = ExecutionStatus.FAILURE
                if self.failFast:
                    raise MPGraphExecutorError(
                        f"Task <{task_node.label} dataId={qnode.quantum.dataId}> failed."
                    ) from exc
                else:
                    # Note that there could be exception safety issues, which
                    # we presently ignore.
                    _LOG.error(
                        "Task <%s dataId=%s> failed; processing will continue for remaining tasks.",
                        task_node.label,
                        qnode.quantum.dataId,
                        exc_info=exc,
                    )
            finally:
                # sqlalchemy has some objects that can last until a garbage
                # collection cycle is run, which can happen at unpredictable
                # times; run a collection loop here explicitly.
                gc.collect()

            quantum_report = self.quantumExecutor.getReport()
            if quantum_report:
                report.quantaReports.append(quantum_report)

        _LOG.info(
            "Executed %d quanta successfully, %d failed and %d remain out of total %d quanta.",
            successCount,
            len(failedNodes),
            totalCount - successCount - len(failedNodes),
            totalCount,
        )

        # Raise an exception if there were any failures.
        if failedNodes:
            raise MPGraphExecutorError("One or more tasks failed during execution.")

    def _executeQuantaMP(self, graph: QuantumGraph, report: Report) -> None:
        """Execute all Quanta in separate processes.

        Parameters
        ----------
        graph : `~lsst.pipe.base.QuantumGraph`
            `~lsst.pipe.base.QuantumGraph` that is to be executed.
        report : `Report`
            Object for reporting execution status.
        """
        disable_implicit_threading()  # To prevent thread contention

        _LOG.debug("Using %r for multiprocessing start method", self.startMethod)

        # Re-pack the input quantum data into a job list.
        jobs = _JobList(graph)

        # Check that all tasks can run in a sub-process.
        for job in jobs.jobs:
            task_node = job.qnode.task_node
            if not task_node.task_class.canMultiprocess:
                raise MPGraphExecutorError(
                    f"Task {task_node.label!r} does not support multiprocessing; use single process"
                )

        finishedCount, failedCount = 0, 0
        while jobs.pending or jobs.running:
            _LOG.debug("#pendingJobs: %s", len(jobs.pending))
            _LOG.debug("#runningJobs: %s", len(jobs.running))

            # See if any jobs have finished.
            for job in jobs.running:
                assert job.process is not None, "Process cannot be None"
                if not job.process.is_alive():
                    _LOG.debug("finished: %s", job)
                    # Finished.
                    exitcode = job.process.exitcode
                    quantum_report = job.report()
                    report.quantaReports.append(quantum_report)
                    if exitcode == 0:
                        jobs.setJobState(job, JobState.FINISHED)
                        job.cleanup()
                        _LOG.debug("success: %s took %.3f seconds", job, time.time() - job.started)
                    else:
                        if job.terminated:
                            # Was killed due to timeout.
                            if report.status == ExecutionStatus.SUCCESS:
                                # Do not override global FAILURE status.
                                report.status = ExecutionStatus.TIMEOUT
                            message = f"Timeout ({self.timeout} sec) for task {job}, task is killed"
                            jobs.setJobState(job, JobState.TIMED_OUT)
                        else:
                            report.status = ExecutionStatus.FAILURE
                            # failMessage() has to be called before cleanup().
                            message = job.failMessage()
                            jobs.setJobState(job, JobState.FAILED)

                        job.cleanup()
                        _LOG.debug("failed: %s", job)
                        if self.failFast or exitcode == InvalidQuantumError.EXIT_CODE:
                            # Stop all running jobs.
                            for stopJob in jobs.running:
                                if stopJob is not job:
                                    stopJob.stop()
                            if job.state is JobState.TIMED_OUT:
                                raise MPTimeoutError(f"Timeout ({self.timeout} sec) for task {job}.")
                            else:
                                raise MPGraphExecutorError(message)
                        else:
                            _LOG.error("%s; processing will continue for remaining tasks.", message)
                else:
                    # Check for timeout.
                    now = time.time()
                    if now - job.started > self.timeout:
                        # Try to kill it, and there is a chance that it
                        # finishes successfully before it gets killed. Exit
                        # status is handled by the code above on the next
                        # iteration.
                        _LOG.debug("Terminating job %s due to timeout", job)
                        job.stop()

            # Fail jobs whose inputs failed; this may need several iterations
            # if the order is not right, and will be done in the next loop.
            if jobs.failedNodes:
                for job in jobs.pending:
                    jobInputNodes = graph.determineInputsToQuantumNode(job.qnode)
                    assert job.qnode.quantum.dataId is not None, "Quantum DataId cannot be None"
                    if jobInputNodes & jobs.failedNodes:
                        quantum_report = QuantumReport(
                            status=ExecutionStatus.SKIPPED,
                            dataId=job.qnode.quantum.dataId,
                            taskLabel=job.qnode.task_node.label,
                        )
                        report.quantaReports.append(quantum_report)
                        jobs.setJobState(job, JobState.FAILED_DEP)
                        _LOG.error("Upstream job failed for task %s, skipping this task.", job)

            # See if we can start more jobs.
            if len(jobs.running) < self.numProc:
                for job in jobs.pending:
                    jobInputNodes = graph.determineInputsToQuantumNode(job.qnode)
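                    # Set inclusion check: this job can start only when every
                    # one of its input nodes has already finished.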
                    if jobInputNodes <= jobs.finishedNodes:
                        # All dependencies have completed; we can start a new
                        # job.
                        if len(jobs.running) < self.numProc:
                            _LOG.debug("Submitting %s", job)
                            jobs.submit(job, self.quantumExecutor, self.startMethod)
                        if len(jobs.running) >= self.numProc:
                            # Cannot start any more jobs, wait until something
                            # finishes.
                            break

            # Do cleanup for timed out jobs if necessary.
            jobs.cleanup()

            # Print a progress message if something changed.
            newFinished, newFailed = len(jobs.finishedNodes), len(jobs.failedNodes)
            if (finishedCount, failedCount) != (newFinished, newFailed):
                finishedCount, failedCount = newFinished, newFailed
                totalCount = len(jobs.jobs)
                _LOG.info(
                    "Executed %d quanta successfully, %d failed and %d remain out of total %d quanta.",
                    finishedCount,
                    failedCount,
                    totalCount - finishedCount - failedCount,
                    totalCount,
                )

            # Here we want to wait until one of the running jobs completes,
            # but multiprocessing does not provide an API for that; for now
            # just sleep a little bit and go back to the loop.
            if jobs.running:
                time.sleep(0.1)

        if jobs.failedNodes:
            # Print the list of failed jobs.
            _LOG.error("Failed jobs:")
            for job in jobs.jobs:
                if job.state != JobState.FINISHED:
                    _LOG.error(" - %s: %s", job.state.name, job)

            # If any job failed, raise an exception.
            if jobs.failedNodes == jobs.timedOutNodes:
                raise MPTimeoutError("One or more tasks timed out during execution.")
            else:
                raise MPGraphExecutorError("One or more tasks failed or timed out during execution.")

    def getReport(self) -> Report | None:
        # Docstring inherited from base class
        if self.report is None:
            raise RuntimeError("getReport() called before execute()")
        return self.report