Coverage for python/lsst/ctrl/mpexec/mpGraphExecutor.py: 12%
307 statements
« prev ^ index » next coverage.py v7.2.5, created at 2023-05-18 09:18 +0000
1# This file is part of ctrl_mpexec.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ["MPGraphExecutor", "MPGraphExecutorError", "MPTimeoutError"]
26import gc
27import importlib
28import logging
29import multiprocessing
30import pickle
31import signal
32import sys
33import threading
34import time
35from collections.abc import Iterable
36from enum import Enum
37from typing import Literal, Optional
39from lsst.daf.butler.cli.cliLog import CliLog
40from lsst.pipe.base import InvalidQuantumError, TaskDef
41from lsst.pipe.base.graph.graph import QuantumGraph, QuantumNode
42from lsst.utils.threads import disable_implicit_threading
44from .executionGraphFixup import ExecutionGraphFixup
45from .quantumGraphExecutor import QuantumExecutor, QuantumGraphExecutor
46from .reports import ExecutionStatus, QuantumReport, Report
48_LOG = logging.getLogger(__name__)
class JobState(Enum):
    """Possible states for an executing task job."""

    PENDING = 1  # job has not started yet
    RUNNING = 2  # job is currently executing
    FINISHED = 3  # job finished successfully
    FAILED = 4  # job execution failed (process returned non-zero status)
    TIMED_OUT = 5  # job was killed due to exceeding the execution timeout
    FAILED_DEP = 6  # one of the dependencies of this job has failed/timed out
61class _Job:
62 """Class representing a job running single task.
64 Parameters
65 ----------
66 qnode: `~lsst.pipe.base.QuantumNode`
67 Quantum and some associated information.
68 """
70 def __init__(self, qnode: QuantumNode):
71 self.qnode = qnode
72 self.process: Optional[multiprocessing.process.BaseProcess] = None
73 self._state = JobState.PENDING
74 self.started: float = 0.0
75 self._rcv_conn: Optional[multiprocessing.connection.Connection] = None
76 self._terminated = False
78 @property
79 def state(self) -> JobState:
80 """Job processing state (JobState)"""
81 return self._state
83 @property
84 def terminated(self) -> bool:
85 """Return True if job was killed by stop() method and negative exit
86 code is returned from child process. (`bool`)"""
87 if self._terminated:
88 assert self.process is not None, "Process must be started"
89 if self.process.exitcode is not None:
90 return self.process.exitcode < 0
91 return False
93 def start(
94 self,
95 quantumExecutor: QuantumExecutor,
96 startMethod: Literal["spawn"] | Literal["fork"] | Literal["forkserver"] | None = None,
97 ) -> None:
98 """Start process which runs the task.
100 Parameters
101 ----------
102 quantumExecutor : `QuantumExecutor`
103 Executor for single quantum.
104 startMethod : `str`, optional
105 Start method from `multiprocessing` module.
106 """
107 # Unpickling of quantum has to happen after butler/executor, this is
108 # why it is pickled manually here.
109 quantum_pickle = pickle.dumps(self.qnode.quantum)
110 taskDef = self.qnode.taskDef
111 self._rcv_conn, snd_conn = multiprocessing.Pipe(False)
112 logConfigState = CliLog.configState
114 mp_ctx = multiprocessing.get_context(startMethod)
115 self.process = mp_ctx.Process(
116 target=_Job._executeJob,
117 args=(quantumExecutor, taskDef, quantum_pickle, logConfigState, snd_conn),
118 name=f"task-{self.qnode.quantum.dataId}",
119 )
120 self.process.start()
121 self.started = time.time()
122 self._state = JobState.RUNNING
124 @staticmethod
125 def _executeJob(
126 quantumExecutor: QuantumExecutor,
127 taskDef: TaskDef,
128 quantum_pickle: bytes,
129 logConfigState: list,
130 snd_conn: multiprocessing.connection.Connection,
131 ) -> None:
132 """Execute a job with arguments.
134 Parameters
135 ----------
136 quantumExecutor : `QuantumExecutor`
137 Executor for single quantum.
138 taskDef : `bytes`
139 Task definition structure.
140 quantum_pickle : `bytes`
141 Quantum for this task execution in pickled form.
142 snd_conn : `multiprocessing.Connection`
143 Connection to send job report to parent process.
144 """
145 # This terrible hack is a workaround for Python threading bug:
146 # https://github.com/python/cpython/issues/102512. Should be removed
147 # when fix for that bug is deployed. Inspired by
148 # https://github.com/QubesOS/qubes-core-admin-client/pull/236/files.
149 thread = threading.current_thread()
150 if isinstance(thread, threading._DummyThread):
151 if getattr(thread, "_tstate_lock", "") is None:
152 thread._set_tstate_lock() # type: ignore[attr-defined]
154 if logConfigState and not CliLog.configState:
155 # means that we are in a new spawned Python process and we have to
156 # re-initialize logging
157 CliLog.replayConfigState(logConfigState)
159 quantum = pickle.loads(quantum_pickle)
160 try:
161 quantumExecutor.execute(taskDef, quantum)
162 finally:
163 # If sending fails we do not want this new exception to be exposed.
164 try:
165 report = quantumExecutor.getReport()
166 snd_conn.send(report)
167 except Exception:
168 pass
170 def stop(self) -> None:
171 """Stop the process."""
172 assert self.process is not None, "Process must be started"
173 self.process.terminate()
174 # give it 1 second to finish or KILL
175 for i in range(10):
176 time.sleep(0.1)
177 if not self.process.is_alive():
178 break
179 else:
180 _LOG.debug("Killing process %s", self.process.name)
181 self.process.kill()
182 self._terminated = True
184 def cleanup(self) -> None:
185 """Release processes resources, has to be called for each finished
186 process.
187 """
188 if self.process and not self.process.is_alive():
189 self.process.close()
190 self.process = None
191 self._rcv_conn = None
193 def report(self) -> QuantumReport:
194 """Return task report, should be called after process finishes and
195 before cleanup().
196 """
197 assert self.process is not None, "Process must be started"
198 assert self._rcv_conn is not None, "Process must be started"
199 try:
200 report = self._rcv_conn.recv()
201 report.exitCode = self.process.exitcode
202 except Exception:
203 # Likely due to the process killed, but there may be other reasons.
204 # Exit code should not be None, this is to keep mypy happy.
205 exitcode = self.process.exitcode if self.process.exitcode is not None else -1
206 assert self.qnode.quantum.dataId is not None, "Quantum DataId cannot be None"
207 report = QuantumReport.from_exit_code(
208 exitCode=exitcode,
209 dataId=self.qnode.quantum.dataId,
210 taskLabel=self.qnode.taskDef.label,
211 )
212 if self.terminated:
213 # Means it was killed, assume it's due to timeout
214 report.status = ExecutionStatus.TIMEOUT
215 return report
217 def failMessage(self) -> str:
218 """Return a message describing task failure"""
219 assert self.process is not None, "Process must be started"
220 assert self.process.exitcode is not None, "Process has to finish"
221 exitcode = self.process.exitcode
222 if exitcode < 0:
223 # Negative exit code means it is killed by signal
224 signum = -exitcode
225 msg = f"Task {self} failed, killed by signal {signum}"
226 # Just in case this is some very odd signal, expect ValueError
227 try:
228 strsignal = signal.strsignal(signum)
229 msg = f"{msg} ({strsignal})"
230 except ValueError:
231 pass
232 elif exitcode > 0:
233 msg = f"Task {self} failed, exit code={exitcode}"
234 else:
235 msg = ""
236 return msg
238 def __str__(self) -> str:
239 return f"<{self.qnode.taskDef} dataId={self.qnode.quantum.dataId}>"
class _JobList:
    """Simple list of _Job instances with few convenience methods.

    Parameters
    ----------
    iterable : iterable of `~lsst.pipe.base.QuantumNode`
        Sequence of Quanta to execute. This has to be ordered according to
        task dependencies.
    """

    def __init__(self, iterable: Iterable[QuantumNode]):
        self.jobs = [_Job(qnode) for qnode in iterable]
        # All jobs start out pending; this is a copy so jobs can be removed
        # as they are submitted without losing the full list.
        self.pending = list(self.jobs)
        self.running: list[_Job] = []
        self.finishedNodes: set[QuantumNode] = set()
        self.failedNodes: set[QuantumNode] = set()
        self.timedOutNodes: set[QuantumNode] = set()

    def submit(
        self,
        job: _Job,
        quantumExecutor: QuantumExecutor,
        startMethod: Literal["spawn", "fork", "forkserver"] | None = None,
    ) -> None:
        """Submit one more job for execution

        Parameters
        ----------
        job : `_Job`
            Job to submit.
        quantumExecutor : `QuantumExecutor`
            Executor for single quantum.
        startMethod : `str`, optional
            Start method from `multiprocessing` module.
        """
        # this will raise if job is not in pending list
        self.pending.remove(job)
        job.start(quantumExecutor, startMethod)
        self.running.append(job)

    def setJobState(self, job: _Job, state: JobState) -> None:
        """Update job state.

        Parameters
        ----------
        job : `_Job`
            Job whose state is to be updated.
        state : `JobState`
            New job state, note that only FINISHED, FAILED, TIMED_OUT, or
            FAILED_DEP state is acceptable.
        """
        allowedStates = (JobState.FINISHED, JobState.FAILED, JobState.TIMED_OUT, JobState.FAILED_DEP)
        assert state in allowedStates, f"State {state} not allowed here"

        # remove job from pending/running lists
        if job.state == JobState.PENDING:
            self.pending.remove(job)
        elif job.state == JobState.RUNNING:
            self.running.remove(job)

        qnode = job.qnode
        # it should not be in any of these, but just in case
        self.finishedNodes.discard(qnode)
        self.failedNodes.discard(qnode)
        self.timedOutNodes.discard(qnode)

        job._state = state
        if state == JobState.FINISHED:
            self.finishedNodes.add(qnode)
        elif state == JobState.FAILED:
            self.failedNodes.add(qnode)
        elif state == JobState.FAILED_DEP:
            self.failedNodes.add(qnode)
        elif state == JobState.TIMED_OUT:
            # Timed-out jobs count as failed as well.
            self.failedNodes.add(qnode)
            self.timedOutNodes.add(qnode)
        else:
            raise ValueError(f"Unexpected state value: {state}")

    def cleanup(self) -> None:
        """Do periodic cleanup for jobs that did not finish correctly.

        If timed out jobs are killed but take too long to stop then regular
        cleanup will not work for them. Here we check all timed out jobs
        periodically and do cleanup if they managed to die by this time.
        """
        for job in self.jobs:
            if job.state == JobState.TIMED_OUT and job.process is not None:
                job.cleanup()
class MPGraphExecutorError(Exception):
    """Exception class for errors raised by MPGraphExecutor."""
class MPTimeoutError(MPGraphExecutorError):
    """Exception raised when task execution times out."""
class MPGraphExecutor(QuantumGraphExecutor):
    """Implementation of QuantumGraphExecutor using same-host multiprocess
    execution of Quanta.

    Parameters
    ----------
    numProc : `int`
        Number of processes to use for executing tasks.
    timeout : `float`
        Time in seconds to wait for tasks to finish.
    quantumExecutor : `QuantumExecutor`
        Executor for single quantum. For multiprocess-style execution when
        ``numProc`` is greater than one this instance must support pickle.
    startMethod : `str`, optional
        Start method from `multiprocessing` module, `None` selects the best
        one for current platform.
    failFast : `bool`, optional
        If set to ``True`` then stop processing on first error from any task.
    pdb : `str`, optional
        Debugger to import and use (via the ``post_mortem`` function) in the
        event of an exception.
    executionGraphFixup : `ExecutionGraphFixup`, optional
        Instance used for modification of execution graph.
    """

    def __init__(
        self,
        numProc: int,
        timeout: float,
        quantumExecutor: QuantumExecutor,
        *,
        startMethod: Literal["spawn", "fork", "forkserver"] | None = None,
        failFast: bool = False,
        pdb: str | None = None,
        executionGraphFixup: ExecutionGraphFixup | None = None,
    ):
        self.numProc = numProc
        self.timeout = timeout
        self.quantumExecutor = quantumExecutor
        self.failFast = failFast
        self.pdb = pdb
        self.executionGraphFixup = executionGraphFixup
        self.report: Report | None = None

        # We set default start method as spawn for MacOS and fork for Linux;
        # None for all other platforms to use multiprocessing default.
        if startMethod is None:
            methods = {"linux": "fork", "darwin": "spawn"}
            startMethod = methods.get(sys.platform)  # type: ignore
        self.startMethod = startMethod

    def execute(self, graph: QuantumGraph) -> None:
        # Docstring inherited from QuantumGraphExecutor.execute
        graph = self._fixupQuanta(graph)
        self.report = Report()
        try:
            if self.numProc > 1:
                self._executeQuantaMP(graph, self.report)
            else:
                self._executeQuantaInProcess(graph, self.report)
        except Exception as exc:
            # Record the failure in the report before propagating it.
            self.report.set_exception(exc)
            raise

    def _fixupQuanta(self, graph: QuantumGraph) -> QuantumGraph:
        """Call fixup code to modify execution graph.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` to modify

        Returns
        -------
        graph : `QuantumGraph`
            Modified `QuantumGraph`.

        Raises
        ------
        MPGraphExecutorError
            Raised if execution graph cannot be ordered after modification,
            i.e. it has dependency cycles.
        """
        if not self.executionGraphFixup:
            return graph

        _LOG.debug("Call execution graph fixup method")
        graph = self.executionGraphFixup.fixupQuanta(graph)

        # Detect if there is now a cycle created within the graph
        if graph.findCycle():
            raise MPGraphExecutorError("Updated execution graph has dependency cycle.")

        return graph

    def _executeQuantaInProcess(self, graph: QuantumGraph, report: Report) -> None:
        """Execute all Quanta in current process.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` that is to be executed
        report : `Report`
            Object for reporting execution status.
        """
        successCount, totalCount = 0, len(graph)
        failedNodes: set[QuantumNode] = set()
        for qnode in graph:
            assert qnode.quantum.dataId is not None, "Quantum DataId cannot be None"

            # Any failed inputs mean that the quantum has to be skipped.
            inputNodes = graph.determineInputsToQuantumNode(qnode)
            if inputNodes & failedNodes:
                _LOG.error(
                    "Upstream job failed for task <%s dataId=%s>, skipping this task.",
                    qnode.taskDef,
                    qnode.quantum.dataId,
                )
                failedNodes.add(qnode)
                failed_quantum_report = QuantumReport(
                    status=ExecutionStatus.SKIPPED, dataId=qnode.quantum.dataId, taskLabel=qnode.taskDef.label
                )
                report.quantaReports.append(failed_quantum_report)
                continue

            _LOG.debug("Executing %s", qnode)
            try:
                self.quantumExecutor.execute(qnode.taskDef, qnode.quantum)
                successCount += 1
            except Exception as exc:
                # Only drop into the debugger when attached to a terminal.
                if self.pdb and sys.stdin.isatty() and sys.stdout.isatty():
                    _LOG.error(
                        "Task <%s dataId=%s> failed; dropping into pdb.",
                        qnode.taskDef,
                        qnode.quantum.dataId,
                        exc_info=exc,
                    )
                    try:
                        pdb = importlib.import_module(self.pdb)
                    except ImportError as imp_exc:
                        raise MPGraphExecutorError(
                            f"Unable to import specified debugger module ({self.pdb}): {imp_exc}"
                        ) from exc
                    if not hasattr(pdb, "post_mortem"):
                        raise MPGraphExecutorError(
                            f"Specified debugger module ({self.pdb}) can't debug with post_mortem",
                        ) from exc
                    pdb.post_mortem(exc.__traceback__)
                failedNodes.add(qnode)
                report.status = ExecutionStatus.FAILURE
                if self.failFast:
                    raise MPGraphExecutorError(
                        f"Task <{qnode.taskDef} dataId={qnode.quantum.dataId}> failed."
                    ) from exc
                else:
                    # Note that there could be exception safety issues, which
                    # we presently ignore.
                    _LOG.error(
                        "Task <%s dataId=%s> failed; processing will continue for remaining tasks.",
                        qnode.taskDef,
                        qnode.quantum.dataId,
                        exc_info=exc,
                    )
            finally:
                # sqlalchemy has some objects that can last until a garbage
                # collection cycle is run, which can happen at unpredictable
                # times, run a collection loop here explicitly.
                gc.collect()

            quantum_report = self.quantumExecutor.getReport()
            if quantum_report:
                report.quantaReports.append(quantum_report)

        _LOG.info(
            "Executed %d quanta successfully, %d failed and %d remain out of total %d quanta.",
            successCount,
            len(failedNodes),
            totalCount - successCount - len(failedNodes),
            totalCount,
        )

        # Raise an exception if there were any failures.
        if failedNodes:
            raise MPGraphExecutorError("One or more tasks failed during execution.")

    def _executeQuantaMP(self, graph: QuantumGraph, report: Report) -> None:
        """Execute all Quanta in separate processes.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` that is to be executed.
        report : `Report`
            Object for reporting execution status.
        """
        disable_implicit_threading()  # To prevent thread contention

        _LOG.debug("Using %r for multiprocessing start method", self.startMethod)

        # re-pack input quantum data into jobs list
        jobs = _JobList(graph)

        # check that all tasks can run in sub-process
        for job in jobs.jobs:
            taskDef = job.qnode.taskDef
            if not taskDef.taskClass.canMultiprocess:
                raise MPGraphExecutorError(
                    f"Task {taskDef.taskName} does not support multiprocessing; use single process"
                )

        finishedCount, failedCount = 0, 0
        while jobs.pending or jobs.running:
            _LOG.debug("#pendingJobs: %s", len(jobs.pending))
            _LOG.debug("#runningJobs: %s", len(jobs.running))

            # See if any jobs have finished. Iterate over a snapshot because
            # setJobState() removes finished jobs from jobs.running, and
            # mutating the list during iteration would skip the next element
            # until the following poll cycle.
            for job in list(jobs.running):
                assert job.process is not None, "Process cannot be None"
                if not job.process.is_alive():
                    _LOG.debug("finished: %s", job)
                    # finished
                    exitcode = job.process.exitcode
                    quantum_report = job.report()
                    report.quantaReports.append(quantum_report)
                    if exitcode == 0:
                        jobs.setJobState(job, JobState.FINISHED)
                        job.cleanup()
                        _LOG.debug("success: %s took %.3f seconds", job, time.time() - job.started)
                    else:
                        if job.terminated:
                            # Was killed due to timeout.
                            if report.status == ExecutionStatus.SUCCESS:
                                # Do not override global FAILURE status
                                report.status = ExecutionStatus.TIMEOUT
                            message = f"Timeout ({self.timeout} sec) for task {job}, task is killed"
                            jobs.setJobState(job, JobState.TIMED_OUT)
                        else:
                            report.status = ExecutionStatus.FAILURE
                            # failMessage() has to be called before cleanup()
                            message = job.failMessage()
                            jobs.setJobState(job, JobState.FAILED)

                        job.cleanup()
                        _LOG.debug("failed: %s", job)
                        if self.failFast or exitcode == InvalidQuantumError.EXIT_CODE:
                            # stop all running jobs
                            for stopJob in jobs.running:
                                if stopJob is not job:
                                    stopJob.stop()
                            if job.state is JobState.TIMED_OUT:
                                raise MPTimeoutError(f"Timeout ({self.timeout} sec) for task {job}.")
                            else:
                                raise MPGraphExecutorError(message)
                        else:
                            _LOG.error("%s; processing will continue for remaining tasks.", message)
                else:
                    # check for timeout
                    now = time.time()
                    if now - job.started > self.timeout:
                        # Try to kill it, and there is a chance that it
                        # finishes successfully before it gets killed. Exit
                        # status is handled by the code above on next
                        # iteration.
                        _LOG.debug("Terminating job %s due to timeout", job)
                        job.stop()

            # Fail jobs whose inputs failed; this may need several iterations
            # if the order is not right, will be done in the next loop.
            # Iterate over a snapshot because setJobState() removes the job
            # from jobs.pending.
            if jobs.failedNodes:
                for job in list(jobs.pending):
                    jobInputNodes = graph.determineInputsToQuantumNode(job.qnode)
                    assert job.qnode.quantum.dataId is not None, "Quantum DataId cannot be None"
                    if jobInputNodes & jobs.failedNodes:
                        quantum_report = QuantumReport(
                            status=ExecutionStatus.SKIPPED,
                            dataId=job.qnode.quantum.dataId,
                            taskLabel=job.qnode.taskDef.label,
                        )
                        report.quantaReports.append(quantum_report)
                        jobs.setJobState(job, JobState.FAILED_DEP)
                        _LOG.error("Upstream job failed for task %s, skipping this task.", job)

            # See if we can start more jobs. Iterate over a snapshot because
            # submit() removes the job from jobs.pending.
            if len(jobs.running) < self.numProc:
                for job in list(jobs.pending):
                    jobInputNodes = graph.determineInputsToQuantumNode(job.qnode)
                    if jobInputNodes <= jobs.finishedNodes:
                        # all dependencies have completed, can start new job
                        if len(jobs.running) < self.numProc:
                            _LOG.debug("Submitting %s", job)
                            jobs.submit(job, self.quantumExecutor, self.startMethod)
                        if len(jobs.running) >= self.numProc:
                            # Cannot start any more jobs, wait until something
                            # finishes.
                            break

            # Do cleanup for timed out jobs if necessary.
            jobs.cleanup()

            # Print progress message if something changed.
            newFinished, newFailed = len(jobs.finishedNodes), len(jobs.failedNodes)
            if (finishedCount, failedCount) != (newFinished, newFailed):
                finishedCount, failedCount = newFinished, newFailed
                totalCount = len(jobs.jobs)
                _LOG.info(
                    "Executed %d quanta successfully, %d failed and %d remain out of total %d quanta.",
                    finishedCount,
                    failedCount,
                    totalCount - finishedCount - failedCount,
                    totalCount,
                )

            # Here we want to wait until one of the running jobs completes
            # but multiprocessing does not provide an API for that, for now
            # just sleep a little bit and go back to the loop.
            if jobs.running:
                time.sleep(0.1)

        if jobs.failedNodes:
            # print list of failed jobs
            _LOG.error("Failed jobs:")
            for job in jobs.jobs:
                if job.state != JobState.FINISHED:
                    _LOG.error("  - %s: %s", job.state.name, job)

            # if any job failed raise an exception
            if jobs.failedNodes == jobs.timedOutNodes:
                raise MPTimeoutError("One or more tasks timed out during execution.")
            else:
                raise MPGraphExecutorError("One or more tasks failed or timed out during execution.")

    def getReport(self) -> Report | None:
        # Docstring inherited from base class
        if self.report is None:
            raise RuntimeError("getReport() called before execute()")
        return self.report