Coverage for python/lsst/ctrl/mpexec/mpGraphExecutor.py: 13%
311 statements
« prev ^ index » next coverage.py v7.2.3, created at 2023-04-22 02:21 -0700
1# This file is part of ctrl_mpexec.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ["MPGraphExecutor", "MPGraphExecutorError", "MPTimeoutError"]
26import gc
27import importlib
28import logging
29import multiprocessing
30import pickle
31import signal
32import sys
33import threading
34import time
35import warnings
36from collections.abc import Iterable
37from enum import Enum
38from typing import Literal, Optional
40from lsst.daf.butler import UnresolvedRefWarning
41from lsst.daf.butler.cli.cliLog import CliLog
42from lsst.pipe.base import InvalidQuantumError, TaskDef
43from lsst.pipe.base.graph.graph import QuantumGraph, QuantumNode
44from lsst.utils.threads import disable_implicit_threading
46from .executionGraphFixup import ExecutionGraphFixup
47from .quantumGraphExecutor import QuantumExecutor, QuantumGraphExecutor
48from .reports import ExecutionStatus, QuantumReport, Report
50_LOG = logging.getLogger(__name__)
class JobState(Enum):
    """Possible states for an executing task.

    - PENDING: job has not started yet
    - RUNNING: job is currently executing
    - FINISHED: job finished successfully
    - FAILED: job execution failed (process returned non-zero status)
    - TIMED_OUT: job is killed due to too long execution time
    - FAILED_DEP: one of the dependencies of this job has failed/timed out
    """

    PENDING = 1
    RUNNING = 2
    FINISHED = 3
    FAILED = 4
    TIMED_OUT = 5
    FAILED_DEP = 6
class _Job:
    """Class representing a job running single task.

    Parameters
    ----------
    qnode: `~lsst.pipe.base.QuantumNode`
        Quantum and some associated information.
    """

    def __init__(self, qnode: QuantumNode):
        self.qnode = qnode
        # Child process running the task; None until start() and again after
        # cleanup().
        self.process: Optional[multiprocessing.process.BaseProcess] = None
        self._state = JobState.PENDING
        # Wall-clock time (time.time()) at which the process was started;
        # used by the caller for timeout detection.
        self.started: float = 0.0
        # Receiving end of the pipe over which the child sends its report.
        self._rcv_conn: Optional[multiprocessing.connection.Connection] = None
        # Set by stop(); distinguishes deliberate termination from crashes.
        self._terminated = False

    @property
    def state(self) -> JobState:
        """Job processing state (JobState)"""
        return self._state

    @property
    def terminated(self) -> bool:
        """Return True if job was killed by stop() method and negative exit
        code is returned from child process. (`bool`)"""
        if self._terminated:
            assert self.process is not None, "Process must be started"
            if self.process.exitcode is not None:
                # Negative exit code means the process died from a signal,
                # which is what terminate()/kill() produce.
                return self.process.exitcode < 0
        return False

    def start(
        self,
        quantumExecutor: QuantumExecutor,
        startMethod: Literal["spawn"] | Literal["fork"] | Literal["forkserver"] | None = None,
    ) -> None:
        """Start process which runs the task.

        Parameters
        ----------
        quantumExecutor : `QuantumExecutor`
            Executor for single quantum.
        startMethod : `str`, optional
            Start method from `multiprocessing` module.
        """
        # Unpickling of quantum has to happen after butler/executor, this is
        # why it is pickled manually here.
        quantum_pickle = pickle.dumps(self.qnode.quantum)
        taskDef = self.qnode.taskDef
        # duplex=False: child end only sends, parent end only receives.
        self._rcv_conn, snd_conn = multiprocessing.Pipe(False)
        logConfigState = CliLog.configState

        mp_ctx = multiprocessing.get_context(startMethod)
        self.process = mp_ctx.Process(
            target=_Job._executeJob,
            args=(quantumExecutor, taskDef, quantum_pickle, logConfigState, snd_conn),
            name=f"task-{self.qnode.quantum.dataId}",
        )
        self.process.start()
        self.started = time.time()
        self._state = JobState.RUNNING

    @staticmethod
    def _executeJob(
        quantumExecutor: QuantumExecutor,
        taskDef: TaskDef,
        quantum_pickle: bytes,
        logConfigState: list,
        snd_conn: multiprocessing.connection.Connection,
    ) -> None:
        """Execute a job with arguments.

        This is the entry point of the child process.

        Parameters
        ----------
        quantumExecutor : `QuantumExecutor`
            Executor for single quantum.
        taskDef : `~lsst.pipe.base.TaskDef`
            Task definition structure.
        quantum_pickle : `bytes`
            Quantum for this task execution in pickled form.
        logConfigState : `list`
            Logging configuration captured in the parent process, replayed
            here when running in a freshly spawned interpreter.
        snd_conn : `multiprocessing.Connection`
            Connection to send job report to parent process.
        """
        # This terrible hack is a workaround for Python threading bug:
        # https://github.com/python/cpython/issues/102512. Should be removed
        # when fix for that bug is deployed. Inspired by
        # https://github.com/QubesOS/qubes-core-admin-client/pull/236/files.
        thread = threading.current_thread()
        if isinstance(thread, threading._DummyThread):
            if getattr(thread, "_tstate_lock", "") is None:
                thread._set_tstate_lock()  # type: ignore[attr-defined]

        if logConfigState and not CliLog.configState:
            # means that we are in a new spawned Python process and we have to
            # re-initialize logging
            CliLog.replayConfigState(logConfigState)

        with warnings.catch_warnings():
            # Loading the pickle file can trigger unresolved ref warnings
            # so we must hide them.
            warnings.simplefilter("ignore", category=UnresolvedRefWarning)
            quantum = pickle.loads(quantum_pickle)
        try:
            quantumExecutor.execute(taskDef, quantum)
        finally:
            # If sending fails we do not want this new exception to be exposed.
            try:
                report = quantumExecutor.getReport()
                snd_conn.send(report)
            except Exception:
                pass

    def stop(self) -> None:
        """Stop the process."""
        assert self.process is not None, "Process must be started"
        self.process.terminate()
        # give it 1 second to finish or KILL
        for i in range(10):
            time.sleep(0.1)
            if not self.process.is_alive():
                break
        else:
            # Did not exit after SIGTERM within ~1 second; escalate.
            _LOG.debug("Killing process %s", self.process.name)
            self.process.kill()
        self._terminated = True

    def cleanup(self) -> None:
        """Release processes resources, has to be called for each finished
        process.
        """
        if self.process and not self.process.is_alive():
            self.process.close()
            self.process = None
            self._rcv_conn = None

    def report(self) -> QuantumReport:
        """Return task report, should be called after process finishes and
        before cleanup().
        """
        assert self.process is not None, "Process must be started"
        assert self._rcv_conn is not None, "Process must be started"
        try:
            report = self._rcv_conn.recv()
            report.exitCode = self.process.exitcode
        except Exception:
            # Likely due to the process killed, but there may be other reasons.
            # Exit code should not be None, this is to keep mypy happy.
            exitcode = self.process.exitcode if self.process.exitcode is not None else -1
            assert self.qnode.quantum.dataId is not None, "Quantum DataId cannot be None"
            report = QuantumReport.from_exit_code(
                exitCode=exitcode,
                dataId=self.qnode.quantum.dataId,
                taskLabel=self.qnode.taskDef.label,
            )
        if self.terminated:
            # Means it was killed, assume it's due to timeout
            report.status = ExecutionStatus.TIMEOUT
        return report

    def failMessage(self) -> str:
        """Return a message describing task failure"""
        assert self.process is not None, "Process must be started"
        assert self.process.exitcode is not None, "Process has to finish"
        exitcode = self.process.exitcode
        if exitcode < 0:
            # Negative exit code means it is killed by signal
            signum = -exitcode
            msg = f"Task {self} failed, killed by signal {signum}"
            # Just in case this is some very odd signal, expect ValueError
            try:
                strsignal = signal.strsignal(signum)
                msg = f"{msg} ({strsignal})"
            except ValueError:
                pass
        elif exitcode > 0:
            msg = f"Task {self} failed, exit code={exitcode}"
        else:
            msg = ""
        return msg

    def __str__(self) -> str:
        return f"<{self.qnode.taskDef} dataId={self.qnode.quantum.dataId}>"
248class _JobList:
249 """Simple list of _Job instances with few convenience methods.
251 Parameters
252 ----------
253 iterable : iterable of `~lsst.pipe.base.QuantumNode`
254 Sequence of Quanta to execute. This has to be ordered according to
255 task dependencies.
256 """
258 def __init__(self, iterable: Iterable[QuantumNode]):
259 self.jobs = [_Job(qnode) for qnode in iterable]
260 self.pending = self.jobs[:]
261 self.running: list[_Job] = []
262 self.finishedNodes: set[QuantumNode] = set()
263 self.failedNodes: set[QuantumNode] = set()
264 self.timedOutNodes: set[QuantumNode] = set()
266 def submit(
267 self,
268 job: _Job,
269 quantumExecutor: QuantumExecutor,
270 startMethod: Literal["spawn"] | Literal["fork"] | Literal["forkserver"] | None = None,
271 ) -> None:
272 """Submit one more job for execution
274 Parameters
275 ----------
276 job : `_Job`
277 Job to submit.
278 quantumExecutor : `QuantumExecutor`
279 Executor for single quantum.
280 startMethod : `str`, optional
281 Start method from `multiprocessing` module.
282 """
283 # this will raise if job is not in pending list
284 self.pending.remove(job)
285 job.start(quantumExecutor, startMethod)
286 self.running.append(job)
288 def setJobState(self, job: _Job, state: JobState) -> None:
289 """Update job state.
291 Parameters
292 ----------
293 job : `_Job`
294 Job to submit.
295 state : `JobState`
296 New job state, note that only FINISHED, FAILED, TIMED_OUT, or
297 FAILED_DEP state is acceptable.
298 """
299 allowedStates = (JobState.FINISHED, JobState.FAILED, JobState.TIMED_OUT, JobState.FAILED_DEP)
300 assert state in allowedStates, f"State {state} not allowed here"
302 # remove job from pending/running lists
303 if job.state == JobState.PENDING:
304 self.pending.remove(job)
305 elif job.state == JobState.RUNNING:
306 self.running.remove(job)
308 qnode = job.qnode
309 # it should not be in any of these, but just in case
310 self.finishedNodes.discard(qnode)
311 self.failedNodes.discard(qnode)
312 self.timedOutNodes.discard(qnode)
314 job._state = state
315 if state == JobState.FINISHED:
316 self.finishedNodes.add(qnode)
317 elif state == JobState.FAILED:
318 self.failedNodes.add(qnode)
319 elif state == JobState.FAILED_DEP:
320 self.failedNodes.add(qnode)
321 elif state == JobState.TIMED_OUT:
322 self.failedNodes.add(qnode)
323 self.timedOutNodes.add(qnode)
324 else:
325 raise ValueError(f"Unexpected state value: {state}")
327 def cleanup(self) -> None:
328 """Do periodic cleanup for jobs that did not finish correctly.
330 If timed out jobs are killed but take too long to stop then regular
331 cleanup will not work for them. Here we check all timed out jobs
332 periodically and do cleanup if they managed to die by this time.
333 """
334 for job in self.jobs:
335 if job.state == JobState.TIMED_OUT and job.process is not None:
336 job.cleanup()
class MPGraphExecutorError(Exception):
    """Exception class for errors raised by MPGraphExecutor."""


class MPTimeoutError(MPGraphExecutorError):
    """Exception raised when task execution times out."""
class MPGraphExecutor(QuantumGraphExecutor):
    """Implementation of QuantumGraphExecutor using same-host multiprocess
    execution of Quanta.

    Parameters
    ----------
    numProc : `int`
        Number of processes to use for executing tasks.
    timeout : `float`
        Time in seconds to wait for tasks to finish.
    quantumExecutor : `QuantumExecutor`
        Executor for single quantum. For multiprocess-style execution when
        ``numProc`` is greater than one this instance must support pickle.
    startMethod : `str`, optional
        Start method from `multiprocessing` module, `None` selects the best
        one for current platform.
    failFast : `bool`, optional
        If set to ``True`` then stop processing on first error from any task.
    pdb : `str`, optional
        Debugger to import and use (via the ``post_mortem`` function) in the
        event of an exception.
    executionGraphFixup : `ExecutionGraphFixup`, optional
        Instance used for modification of execution graph.
    """

    def __init__(
        self,
        numProc: int,
        timeout: float,
        quantumExecutor: QuantumExecutor,
        *,
        startMethod: Literal["spawn"] | Literal["fork"] | Literal["forkserver"] | None = None,
        failFast: bool = False,
        pdb: Optional[str] = None,
        executionGraphFixup: Optional[ExecutionGraphFixup] = None,
    ):
        self.numProc = numProc
        self.timeout = timeout
        self.quantumExecutor = quantumExecutor
        self.failFast = failFast
        self.pdb = pdb
        self.executionGraphFixup = executionGraphFixup
        self.report: Optional[Report] = None

        # We set default start method as spawn for MacOS and fork for Linux;
        # None for all other platforms to use multiprocessing default.
        if startMethod is None:
            methods = dict(linux="fork", darwin="spawn")
            startMethod = methods.get(sys.platform)  # type: ignore
        self.startMethod = startMethod

    def execute(self, graph: QuantumGraph) -> None:
        # Docstring inherited from QuantumGraphExecutor.execute
        graph = self._fixupQuanta(graph)
        self.report = Report()
        try:
            if self.numProc > 1:
                self._executeQuantaMP(graph, self.report)
            else:
                self._executeQuantaInProcess(graph, self.report)
        except Exception as exc:
            # Record the exception in the report before propagating it.
            self.report.set_exception(exc)
            raise

    def _fixupQuanta(self, graph: QuantumGraph) -> QuantumGraph:
        """Call fixup code to modify execution graph.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` to modify

        Returns
        -------
        graph : `QuantumGraph`
            Modified `QuantumGraph`.

        Raises
        ------
        MPGraphExecutorError
            Raised if execution graph cannot be ordered after modification,
            i.e. it has dependency cycles.
        """
        if not self.executionGraphFixup:
            return graph

        _LOG.debug("Call execution graph fixup method")
        graph = self.executionGraphFixup.fixupQuanta(graph)

        # Detect if there is now a cycle created within the graph
        if graph.findCycle():
            raise MPGraphExecutorError("Updated execution graph has dependency cycle.")

        return graph

    def _executeQuantaInProcess(self, graph: QuantumGraph, report: Report) -> None:
        """Execute all Quanta in current process.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` that is to be executed
        report : `Report`
            Object for reporting execution status.

        Raises
        ------
        MPGraphExecutorError
            Raised if any task failed (immediately when ``failFast`` is set,
            otherwise after all quanta were attempted).
        """
        successCount, totalCount = 0, len(graph)
        failedNodes: set[QuantumNode] = set()
        for qnode in graph:
            assert qnode.quantum.dataId is not None, "Quantum DataId cannot be None"

            # Any failed inputs mean that the quantum has to be skipped.
            inputNodes = graph.determineInputsToQuantumNode(qnode)
            if inputNodes & failedNodes:
                _LOG.error(
                    "Upstream job failed for task <%s dataId=%s>, skipping this task.",
                    qnode.taskDef,
                    qnode.quantum.dataId,
                )
                failedNodes.add(qnode)
                failed_quantum_report = QuantumReport(
                    status=ExecutionStatus.SKIPPED, dataId=qnode.quantum.dataId, taskLabel=qnode.taskDef.label
                )
                report.quantaReports.append(failed_quantum_report)
                continue

            _LOG.debug("Executing %s", qnode)
            try:
                self.quantumExecutor.execute(qnode.taskDef, qnode.quantum)
                successCount += 1
            except Exception as exc:
                # Drop into the debugger only when one was requested and we
                # are attached to a real terminal.
                if self.pdb and sys.stdin.isatty() and sys.stdout.isatty():
                    _LOG.error(
                        "Task <%s dataId=%s> failed; dropping into pdb.",
                        qnode.taskDef,
                        qnode.quantum.dataId,
                        exc_info=exc,
                    )
                    try:
                        pdb = importlib.import_module(self.pdb)
                    except ImportError as imp_exc:
                        raise MPGraphExecutorError(
                            f"Unable to import specified debugger module ({self.pdb}): {imp_exc}"
                        ) from exc
                    if not hasattr(pdb, "post_mortem"):
                        raise MPGraphExecutorError(
                            f"Specified debugger module ({self.pdb}) can't debug with post_mortem",
                        ) from exc
                    pdb.post_mortem(exc.__traceback__)
                failedNodes.add(qnode)
                report.status = ExecutionStatus.FAILURE
                if self.failFast:
                    raise MPGraphExecutorError(
                        f"Task <{qnode.taskDef} dataId={qnode.quantum.dataId}> failed."
                    ) from exc
                else:
                    # Note that there could be exception safety issues, which
                    # we presently ignore.
                    _LOG.error(
                        "Task <%s dataId=%s> failed; processing will continue for remaining tasks.",
                        qnode.taskDef,
                        qnode.quantum.dataId,
                        exc_info=exc,
                    )
            finally:
                # sqlalchemy has some objects that can last until a garbage
                # collection cycle is run, which can happen at unpredictable
                # times, run a collection loop here explicitly.
                gc.collect()

            quantum_report = self.quantumExecutor.getReport()
            if quantum_report:
                report.quantaReports.append(quantum_report)

        _LOG.info(
            "Executed %d quanta successfully, %d failed and %d remain out of total %d quanta.",
            successCount,
            len(failedNodes),
            totalCount - successCount - len(failedNodes),
            totalCount,
        )

        # Raise an exception if there were any failures.
        if failedNodes:
            raise MPGraphExecutorError("One or more tasks failed during execution.")

    def _executeQuantaMP(self, graph: QuantumGraph, report: Report) -> None:
        """Execute all Quanta in separate processes.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` that is to be executed.
        report : `Report`
            Object for reporting execution status.

        Raises
        ------
        MPTimeoutError
            Raised if every failure was a timeout.
        MPGraphExecutorError
            Raised if any task does not support multiprocessing or if one or
            more tasks failed during execution.
        """

        disable_implicit_threading()  # To prevent thread contention

        _LOG.debug("Using %r for multiprocessing start method", self.startMethod)

        # re-pack input quantum data into jobs list
        jobs = _JobList(graph)

        # check that all tasks can run in sub-process
        for job in jobs.jobs:
            taskDef = job.qnode.taskDef
            if not taskDef.taskClass.canMultiprocess:
                raise MPGraphExecutorError(
                    f"Task {taskDef.taskName} does not support multiprocessing; use single process"
                )

        finishedCount, failedCount = 0, 0
        while jobs.pending or jobs.running:
            _LOG.debug("#pendingJobs: %s", len(jobs.pending))
            _LOG.debug("#runningJobs: %s", len(jobs.running))

            # See if any jobs have finished. Iterate over a snapshot because
            # setJobState() removes the job from jobs.running; removing the
            # current element while iterating a list silently skips the next
            # element.
            for job in list(jobs.running):
                assert job.process is not None, "Process cannot be None"
                if not job.process.is_alive():
                    _LOG.debug("finished: %s", job)
                    # finished
                    exitcode = job.process.exitcode
                    quantum_report = job.report()
                    report.quantaReports.append(quantum_report)
                    if exitcode == 0:
                        jobs.setJobState(job, JobState.FINISHED)
                        job.cleanup()
                        _LOG.debug("success: %s took %.3f seconds", job, time.time() - job.started)
                    else:
                        if job.terminated:
                            # Was killed due to timeout.
                            if report.status == ExecutionStatus.SUCCESS:
                                # Do not override global FAILURE status
                                report.status = ExecutionStatus.TIMEOUT
                            message = f"Timeout ({self.timeout} sec) for task {job}, task is killed"
                            jobs.setJobState(job, JobState.TIMED_OUT)
                        else:
                            report.status = ExecutionStatus.FAILURE
                            # failMessage() has to be called before cleanup()
                            message = job.failMessage()
                            jobs.setJobState(job, JobState.FAILED)

                        job.cleanup()
                        _LOG.debug("failed: %s", job)
                        if self.failFast or exitcode == InvalidQuantumError.EXIT_CODE:
                            # stop all running jobs
                            for stopJob in jobs.running:
                                if stopJob is not job:
                                    stopJob.stop()
                            if job.state is JobState.TIMED_OUT:
                                raise MPTimeoutError(f"Timeout ({self.timeout} sec) for task {job}.")
                            else:
                                raise MPGraphExecutorError(message)
                        else:
                            _LOG.error("%s; processing will continue for remaining tasks.", message)
                else:
                    # check for timeout
                    now = time.time()
                    if now - job.started > self.timeout:
                        # Try to kill it, and there is a chance that it
                        # finishes successfully before it gets killed. Exit
                        # status is handled by the code above on next
                        # iteration.
                        _LOG.debug("Terminating job %s due to timeout", job)
                        job.stop()

            # Fail jobs whose inputs failed, this may need several iterations
            # if the order is not right, will be done in the next loop.
            if jobs.failedNodes:
                # Snapshot: setJobState(FAILED_DEP) removes entries from
                # jobs.pending while we iterate.
                for job in list(jobs.pending):
                    jobInputNodes = graph.determineInputsToQuantumNode(job.qnode)
                    assert job.qnode.quantum.dataId is not None, "Quantum DataId cannot be None"
                    if jobInputNodes & jobs.failedNodes:
                        quantum_report = QuantumReport(
                            status=ExecutionStatus.SKIPPED,
                            dataId=job.qnode.quantum.dataId,
                            taskLabel=job.qnode.taskDef.label,
                        )
                        report.quantaReports.append(quantum_report)
                        jobs.setJobState(job, JobState.FAILED_DEP)
                        _LOG.error("Upstream job failed for task %s, skipping this task.", job)

            # see if we can start more jobs
            if len(jobs.running) < self.numProc:
                # Snapshot: submit() removes entries from jobs.pending while
                # we iterate.
                for job in list(jobs.pending):
                    jobInputNodes = graph.determineInputsToQuantumNode(job.qnode)
                    if jobInputNodes <= jobs.finishedNodes:
                        # all dependencies have completed, can start new job
                        if len(jobs.running) < self.numProc:
                            _LOG.debug("Submitting %s", job)
                            jobs.submit(job, self.quantumExecutor, self.startMethod)
                        if len(jobs.running) >= self.numProc:
                            # Cannot start any more jobs, wait until something
                            # finishes.
                            break

            # Do cleanup for timed out jobs if necessary.
            jobs.cleanup()

            # Print progress message if something changed.
            newFinished, newFailed = len(jobs.finishedNodes), len(jobs.failedNodes)
            if (finishedCount, failedCount) != (newFinished, newFailed):
                finishedCount, failedCount = newFinished, newFailed
                totalCount = len(jobs.jobs)
                _LOG.info(
                    "Executed %d quanta successfully, %d failed and %d remain out of total %d quanta.",
                    finishedCount,
                    failedCount,
                    totalCount - finishedCount - failedCount,
                    totalCount,
                )

            # Here we want to wait until one of the running jobs completes
            # but multiprocessing does not provide an API for that, for now
            # just sleep a little bit and go back to the loop.
            if jobs.running:
                time.sleep(0.1)

        if jobs.failedNodes:
            # print list of failed jobs
            _LOG.error("Failed jobs:")
            for job in jobs.jobs:
                if job.state != JobState.FINISHED:
                    _LOG.error("  - %s: %s", job.state.name, job)

            # if any job failed raise an exception
            if jobs.failedNodes == jobs.timedOutNodes:
                raise MPTimeoutError("One or more tasks timed out during execution.")
            else:
                raise MPGraphExecutorError("One or more tasks failed or timed out during execution.")

    def getReport(self) -> Optional[Report]:
        # Docstring inherited from base class
        if self.report is None:
            raise RuntimeError("getReport() called before execute()")
        return self.report