Coverage for python/lsst/ctrl/mpexec/mpGraphExecutor.py: 12%
302 statements
# This file is part of ctrl_mpexec.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

__all__ = ["MPGraphExecutor", "MPGraphExecutorError", "MPTimeoutError"]

import gc
import importlib
import logging
import multiprocessing
import pickle
import signal
import sys
import time
from collections.abc import Iterable
from enum import Enum
from typing import Literal, Optional

from lsst.daf.butler.cli.cliLog import CliLog
from lsst.pipe.base import InvalidQuantumError, TaskDef
from lsst.pipe.base.graph.graph import QuantumGraph, QuantumNode
from lsst.utils.threads import disable_implicit_threading

from .executionGraphFixup import ExecutionGraphFixup
from .quantumGraphExecutor import QuantumExecutor, QuantumGraphExecutor
from .reports import ExecutionStatus, QuantumReport, Report

_LOG = logging.getLogger(__name__)


# Possible states for the executing task:
# - PENDING: job has not started yet
# - RUNNING: job is currently executing
# - FINISHED: job finished successfully
# - FAILED: job execution failed (process returned non-zero status)
# - TIMED_OUT: job was killed because it exceeded the execution timeout
# - FAILED_DEP: one of the dependencies of this job has failed/timed out
JobState = Enum("JobState", "PENDING RUNNING FINISHED FAILED TIMED_OUT FAILED_DEP")


class _Job:
    """Class representing a job running a single task.

    Parameters
    ----------
    qnode : `~lsst.pipe.base.QuantumNode`
        Quantum and some associated information.
    """

    def __init__(self, qnode: QuantumNode):
        self.qnode = qnode
        self.process: Optional[multiprocessing.process.BaseProcess] = None
        self._state = JobState.PENDING
        self.started: float = 0.0
        self._rcv_conn: Optional[multiprocessing.connection.Connection] = None
        self._terminated = False

    @property
    def state(self) -> JobState:
        """Job processing state (`JobState`)."""
        return self._state

    @property
    def terminated(self) -> bool:
        """Return `True` if the job was killed by the `stop()` method and the
        child process returned a negative exit code (`bool`)."""
        if self._terminated:
            assert self.process is not None, "Process must be started"
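            # A negative exit code from multiprocessing means the child
            # process was terminated by a signal.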
            if self.process.exitcode is not None:
                return self.process.exitcode < 0
        return False

    def start(
        self,
        quantumExecutor: QuantumExecutor,
        startMethod: Literal["spawn", "fork", "forkserver"] | None = None,
    ) -> None:
        """Start a process which runs the task.

        Parameters
        ----------
        quantumExecutor : `QuantumExecutor`
            Executor for a single quantum.
        startMethod : `str`, optional
            Start method from the `multiprocessing` module.
        """
        # Unpickling of the quantum has to happen after the butler/executor
        # are set up in the child process, which is why it is pickled
        # manually here.
        quantum_pickle = pickle.dumps(self.qnode.quantum)
        taskDef = self.qnode.taskDef
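        # A duplex=False Pipe is one-way: the parent keeps the receiving end
        # and the child gets the sending end for the report.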
        self._rcv_conn, snd_conn = multiprocessing.Pipe(False)
        logConfigState = CliLog.configState

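        # get_context() returns a context object with the requested start
        # method without changing the global multiprocessing default.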
        mp_ctx = multiprocessing.get_context(startMethod)
        self.process = mp_ctx.Process(
            target=_Job._executeJob,
            args=(quantumExecutor, taskDef, quantum_pickle, logConfigState, snd_conn),
            name=f"task-{self.qnode.quantum.dataId}",
        )
        self.process.start()
        self.started = time.time()
        self._state = JobState.RUNNING

    @staticmethod
    def _executeJob(
        quantumExecutor: QuantumExecutor,
        taskDef: TaskDef,
        quantum_pickle: bytes,
        logConfigState: list,
        snd_conn: multiprocessing.connection.Connection,
    ) -> None:
        """Execute a job with arguments.

        Parameters
        ----------
        quantumExecutor : `QuantumExecutor`
            Executor for a single quantum.
        taskDef : `~lsst.pipe.base.TaskDef`
            Task definition structure.
        quantum_pickle : `bytes`
            Quantum for this task execution in pickled form.
        logConfigState : `list`
            Logging configuration state captured in the parent process,
            replayed here when running in a freshly spawned process.
        snd_conn : `multiprocessing.connection.Connection`
            Connection to send job report to parent process.
        """
        if logConfigState and not CliLog.configState:
            # We are in a newly spawned Python process and have to
            # re-initialize logging.
            CliLog.replayConfigState(logConfigState)

        quantum = pickle.loads(quantum_pickle)
        try:
            quantumExecutor.execute(taskDef, quantum)
        finally:
            # If sending fails we do not want this new exception to be
            # exposed.
            try:
                report = quantumExecutor.getReport()
                snd_conn.send(report)
            except Exception:
                pass

    def stop(self) -> None:
        """Stop the process."""
        assert self.process is not None, "Process must be started"
        self.process.terminate()
        # Give it one second to finish; if it is still alive after that,
        # kill it.
        for i in range(10):
            time.sleep(0.1)
            if not self.process.is_alive():
                break
        else:
            _LOG.debug("Killing process %s", self.process.name)
            self.process.kill()
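        # Record that we initiated termination; the ``terminated`` property
        # combines this flag with the child's negative exit code.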
        self._terminated = True

    def cleanup(self) -> None:
        """Release process resources; has to be called for each finished
        process.
        """
        if self.process and not self.process.is_alive():
            self.process.close()
            self.process = None
            self._rcv_conn = None

    def report(self) -> QuantumReport:
        """Return the task report; must be called after the process finishes
        and before `cleanup()`.
        """
        assert self.process is not None, "Process must be started"
        assert self._rcv_conn is not None, "Process must be started"
        try:
            report = self._rcv_conn.recv()
            report.exitCode = self.process.exitcode
        except Exception:
            # Likely because the process was killed, but there may be other
            # reasons. Exit code should not be None; this is to keep mypy
            # happy.
            exitcode = self.process.exitcode if self.process.exitcode is not None else -1
            assert self.qnode.quantum.dataId is not None, "Quantum DataId cannot be None"
            report = QuantumReport.from_exit_code(
                exitCode=exitcode,
                dataId=self.qnode.quantum.dataId,
                taskLabel=self.qnode.taskDef.label,
            )
        if self.terminated:
            # The process was killed by stop(); assume it was due to a
            # timeout.
            report.status = ExecutionStatus.TIMEOUT
        return report

    def failMessage(self) -> str:
        """Return a message describing the task failure."""
        assert self.process is not None, "Process must be started"
        assert self.process.exitcode is not None, "Process has to finish"
        exitcode = self.process.exitcode
        if exitcode < 0:
            # A negative exit code means the process was killed by a signal.
            signum = -exitcode
            msg = f"Task {self} failed, killed by signal {signum}"
            # Just in case this is some very odd signal, expect ValueError.
            try:
                strsignal = signal.strsignal(signum)
                msg = f"{msg} ({strsignal})"
            except ValueError:
                pass
        elif exitcode > 0:
            msg = f"Task {self} failed, exit code={exitcode}"
        else:
            msg = ""
        return msg

    def __str__(self) -> str:
        return f"<{self.qnode.taskDef} dataId={self.qnode.quantum.dataId}>"


class _JobList:
    """Simple list of `_Job` instances with a few convenience methods.

    Parameters
    ----------
    iterable : iterable of `~lsst.pipe.base.QuantumNode`
        Sequence of Quanta to execute. This has to be ordered according to
        task dependencies.
    """

    def __init__(self, iterable: Iterable[QuantumNode]):
        self.jobs = [_Job(qnode) for qnode in iterable]
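        # ``pending`` is a shallow copy so that jobs can be removed from it
        # as they are submitted, while ``jobs`` keeps the complete list.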
        self.pending = self.jobs[:]
        self.running: list[_Job] = []
        self.finishedNodes: set[QuantumNode] = set()
        self.failedNodes: set[QuantumNode] = set()
        self.timedOutNodes: set[QuantumNode] = set()

    def submit(
        self,
        job: _Job,
        quantumExecutor: QuantumExecutor,
        startMethod: Literal["spawn", "fork", "forkserver"] | None = None,
    ) -> None:
        """Submit one more job for execution.

        Parameters
        ----------
        job : `_Job`
            Job to submit.
        quantumExecutor : `QuantumExecutor`
            Executor for a single quantum.
        startMethod : `str`, optional
            Start method from the `multiprocessing` module.
        """
        # This raises ValueError if the job is not in the pending list.
        self.pending.remove(job)
        job.start(quantumExecutor, startMethod)
        self.running.append(job)

    def setJobState(self, job: _Job, state: JobState) -> None:
        """Update job state.

        Parameters
        ----------
        job : `_Job`
            Job to update.
        state : `JobState`
            New job state; only the FINISHED, FAILED, TIMED_OUT, or
            FAILED_DEP states are acceptable.
        """
        allowedStates = (JobState.FINISHED, JobState.FAILED, JobState.TIMED_OUT, JobState.FAILED_DEP)
        assert state in allowedStates, f"State {state} not allowed here"

        # Remove the job from the pending/running lists.
        if job.state == JobState.PENDING:
            self.pending.remove(job)
        elif job.state == JobState.RUNNING:
            self.running.remove(job)

        qnode = job.qnode
        # It should not be in any of these sets, but discard just in case.
        self.finishedNodes.discard(qnode)
        self.failedNodes.discard(qnode)
        self.timedOutNodes.discard(qnode)

        job._state = state
        if state == JobState.FINISHED:
            self.finishedNodes.add(qnode)
        elif state == JobState.FAILED:
            self.failedNodes.add(qnode)
        elif state == JobState.FAILED_DEP:
            self.failedNodes.add(qnode)
        elif state == JobState.TIMED_OUT:
            self.failedNodes.add(qnode)
            self.timedOutNodes.add(qnode)
        else:
            raise ValueError(f"Unexpected state value: {state}")

    def cleanup(self) -> None:
        """Do periodic cleanup for jobs that did not finish correctly.

        If timed out jobs are killed but take too long to stop then regular
        cleanup will not work for them. Here we check all timed out jobs
        periodically and do cleanup if they have managed to die by this time.
        """
        for job in self.jobs:
            if job.state == JobState.TIMED_OUT and job.process is not None:
                job.cleanup()


class MPGraphExecutorError(Exception):
    """Exception class for errors raised by MPGraphExecutor."""

    pass


class MPTimeoutError(MPGraphExecutorError):
    """Exception raised when task execution times out."""

    pass


class MPGraphExecutor(QuantumGraphExecutor):
    """Implementation of QuantumGraphExecutor using same-host multiprocess
    execution of Quanta.

    Parameters
    ----------
    numProc : `int`
        Number of processes to use for executing tasks.
    timeout : `float`
        Time in seconds to wait for tasks to finish.
    quantumExecutor : `QuantumExecutor`
        Executor for a single quantum. For multiprocess-style execution when
        ``numProc`` is greater than one this instance must support pickle.
    startMethod : `str`, optional
        Start method from the `multiprocessing` module; `None` selects the
        best one for the current platform.
    failFast : `bool`, optional
        If set to ``True`` then stop processing on the first error from any
        task.
    pdb : `str`, optional
        Debugger to import and use (via the ``post_mortem`` function) in the
        event of an exception.
    executionGraphFixup : `ExecutionGraphFixup`, optional
        Instance used for modification of the execution graph.
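
    Examples
    --------
    A minimal usage sketch; ``MyQuantumExecutor`` stands in for any concrete
    `QuantumExecutor` implementation and ``quantum_graph`` for a previously
    constructed `~lsst.pipe.base.QuantumGraph`::

        executor = MPGraphExecutor(
            numProc=4,
            timeout=3600.0,
            quantumExecutor=MyQuantumExecutor(),
        )
        executor.execute(quantum_graph)
        report = executor.getReport()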
358 """

    def __init__(
        self,
        numProc: int,
        timeout: float,
        quantumExecutor: QuantumExecutor,
        *,
        startMethod: Literal["spawn", "fork", "forkserver"] | None = None,
        failFast: bool = False,
        pdb: Optional[str] = None,
        executionGraphFixup: Optional[ExecutionGraphFixup] = None,
    ):
        self.numProc = numProc
        self.timeout = timeout
        self.quantumExecutor = quantumExecutor
        self.failFast = failFast
        self.pdb = pdb
        self.executionGraphFixup = executionGraphFixup
        self.report: Optional[Report] = None

        # Default the start method to spawn on macOS and fork on Linux;
        # leave it as None on other platforms so multiprocessing picks its
        # own default.
        if startMethod is None:
            methods = dict(linux="fork", darwin="spawn")
            startMethod = methods.get(sys.platform)  # type: ignore
        self.startMethod = startMethod

    def execute(self, graph: QuantumGraph) -> None:
        # Docstring inherited from QuantumGraphExecutor.execute
        graph = self._fixupQuanta(graph)
        self.report = Report()
        try:
            if self.numProc > 1:
                self._executeQuantaMP(graph, self.report)
            else:
                self._executeQuantaInProcess(graph, self.report)
        except Exception as exc:
            self.report.set_exception(exc)
            raise

    def _fixupQuanta(self, graph: QuantumGraph) -> QuantumGraph:
        """Call fixup code to modify the execution graph.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` to modify.

        Returns
        -------
        graph : `QuantumGraph`
            Modified `QuantumGraph`.

        Raises
        ------
        MPGraphExecutorError
            Raised if the execution graph cannot be ordered after
            modification, i.e. it has dependency cycles.
        """
        if not self.executionGraphFixup:
            return graph

        _LOG.debug("Call execution graph fixup method")
        graph = self.executionGraphFixup.fixupQuanta(graph)

        # Detect if the fixup created a cycle within the graph.
        if graph.findCycle():
            raise MPGraphExecutorError("Updated execution graph has dependency cycle.")

        return graph

    def _executeQuantaInProcess(self, graph: QuantumGraph, report: Report) -> None:
        """Execute all Quanta in the current process.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` that is to be executed.
        report : `Report`
            Object for reporting execution status.
        """
        successCount, totalCount = 0, len(graph)
        failedNodes: set[QuantumNode] = set()
        for qnode in graph:
            assert qnode.quantum.dataId is not None, "Quantum DataId cannot be None"

            # Any failed inputs mean that the quantum has to be skipped.
            inputNodes = graph.determineInputsToQuantumNode(qnode)
            if inputNodes & failedNodes:
                _LOG.error(
                    "Upstream job failed for task <%s dataId=%s>, skipping this task.",
                    qnode.taskDef,
                    qnode.quantum.dataId,
                )
                failedNodes.add(qnode)
                failed_quantum_report = QuantumReport(
                    status=ExecutionStatus.SKIPPED, dataId=qnode.quantum.dataId, taskLabel=qnode.taskDef.label
                )
                report.quantaReports.append(failed_quantum_report)
                continue

            _LOG.debug("Executing %s", qnode)
            try:
                self.quantumExecutor.execute(qnode.taskDef, qnode.quantum)
                successCount += 1
            except Exception as exc:
                if self.pdb and sys.stdin.isatty() and sys.stdout.isatty():
                    _LOG.error(
                        "Task <%s dataId=%s> failed; dropping into pdb.",
                        qnode.taskDef,
                        qnode.quantum.dataId,
                        exc_info=exc,
                    )
                    try:
                        pdb = importlib.import_module(self.pdb)
                    except ImportError as imp_exc:
                        raise MPGraphExecutorError(
                            f"Unable to import specified debugger module ({self.pdb}): {imp_exc}"
                        ) from exc
                    if not hasattr(pdb, "post_mortem"):
                        raise MPGraphExecutorError(
                            f"Specified debugger module ({self.pdb}) can't debug with post_mortem",
                        ) from exc
                    pdb.post_mortem(exc.__traceback__)
                failedNodes.add(qnode)
                report.status = ExecutionStatus.FAILURE
                if self.failFast:
                    raise MPGraphExecutorError(
                        f"Task <{qnode.taskDef} dataId={qnode.quantum.dataId}> failed."
                    ) from exc
                else:
                    # Note that there could be exception safety issues, which
                    # we presently ignore.
                    _LOG.error(
                        "Task <%s dataId=%s> failed; processing will continue for remaining tasks.",
                        qnode.taskDef,
                        qnode.quantum.dataId,
                        exc_info=exc,
                    )
            finally:
                # SQLAlchemy has some objects that can last until a garbage
                # collection cycle is run, which can happen at unpredictable
                # times; run a collection here explicitly.
                gc.collect()

            quantum_report = self.quantumExecutor.getReport()
            if quantum_report:
                report.quantaReports.append(quantum_report)

        _LOG.info(
            "Executed %d quanta successfully, %d failed and %d remain out of total %d quanta.",
            successCount,
            len(failedNodes),
            totalCount - successCount - len(failedNodes),
            totalCount,
        )

        # Raise an exception if there were any failures.
        if failedNodes:
            raise MPGraphExecutorError("One or more tasks failed during execution.")

    def _executeQuantaMP(self, graph: QuantumGraph, report: Report) -> None:
        """Execute all Quanta in separate processes.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` that is to be executed.
        report : `Report`
            Object for reporting execution status.
        """
        disable_implicit_threading()  # To prevent thread contention

        _LOG.debug("Using %r for multiprocessing start method", self.startMethod)

        # Re-pack input quantum data into a job list.
        jobs = _JobList(graph)

        # Check that all tasks can run in a sub-process.
        for job in jobs.jobs:
            taskDef = job.qnode.taskDef
            if not taskDef.taskClass.canMultiprocess:
                raise MPGraphExecutorError(
                    f"Task {taskDef.taskName} does not support multiprocessing; use single process"
                )

        finishedCount, failedCount = 0, 0
        while jobs.pending or jobs.running:
            _LOG.debug("#pendingJobs: %s", len(jobs.pending))
            _LOG.debug("#runningJobs: %s", len(jobs.running))

            # See if any jobs have finished.
            for job in jobs.running:
                assert job.process is not None, "Process cannot be None"
                if not job.process.is_alive():
                    _LOG.debug("finished: %s", job)
                    # finished
                    exitcode = job.process.exitcode
                    quantum_report = job.report()
                    report.quantaReports.append(quantum_report)
                    if exitcode == 0:
                        jobs.setJobState(job, JobState.FINISHED)
                        job.cleanup()
                        _LOG.debug("success: %s took %.3f seconds", job, time.time() - job.started)
                    else:
                        if job.terminated:
                            # Was killed due to timeout.
                            if report.status == ExecutionStatus.SUCCESS:
                                # Do not override global FAILURE status.
                                report.status = ExecutionStatus.TIMEOUT
                            message = f"Timeout ({self.timeout} sec) for task {job}, task is killed"
                            jobs.setJobState(job, JobState.TIMED_OUT)
                        else:
                            report.status = ExecutionStatus.FAILURE
                            # failMessage() has to be called before cleanup().
                            message = job.failMessage()
                            jobs.setJobState(job, JobState.FAILED)

                        job.cleanup()
                        _LOG.debug("failed: %s", job)
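                        # An InvalidQuantumError exit code always triggers
                        # the fail-fast path, even when ``failFast`` is not
                        # set.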
                        if self.failFast or exitcode == InvalidQuantumError.EXIT_CODE:
                            # Stop all remaining running jobs.
                            for stopJob in jobs.running:
                                if stopJob is not job:
                                    stopJob.stop()
                            if job.state is JobState.TIMED_OUT:
                                raise MPTimeoutError(f"Timeout ({self.timeout} sec) for task {job}.")
                            else:
                                raise MPGraphExecutorError(message)
                        else:
                            _LOG.error("%s; processing will continue for remaining tasks.", message)
                else:
                    # Check for timeout.
                    now = time.time()
                    if now - job.started > self.timeout:
                        # Try to kill it; there is a chance that it finishes
                        # successfully before it gets killed. Exit status is
                        # handled by the code above on the next iteration.
                        _LOG.debug("Terminating job %s due to timeout", job)
                        job.stop()

            # Fail jobs whose inputs failed; this may need several iterations
            # if the ordering is not right, and will be finished on
            # subsequent passes of the loop.
            if jobs.failedNodes:
                for job in jobs.pending:
                    jobInputNodes = graph.determineInputsToQuantumNode(job.qnode)
                    assert job.qnode.quantum.dataId is not None, "Quantum DataId cannot be None"
                    if jobInputNodes & jobs.failedNodes:
                        quantum_report = QuantumReport(
                            status=ExecutionStatus.SKIPPED,
                            dataId=job.qnode.quantum.dataId,
                            taskLabel=job.qnode.taskDef.label,
                        )
                        report.quantaReports.append(quantum_report)
                        jobs.setJobState(job, JobState.FAILED_DEP)
                        _LOG.error("Upstream job failed for task %s, skipping this task.", job)

            # See if we can start more jobs.
            if len(jobs.running) < self.numProc:
                for job in jobs.pending:
                    jobInputNodes = graph.determineInputsToQuantumNode(job.qnode)
                    if jobInputNodes <= jobs.finishedNodes:
                        # All dependencies have completed; we can start a
                        # new job.
                        if len(jobs.running) < self.numProc:
                            _LOG.debug("Submitting %s", job)
                            jobs.submit(job, self.quantumExecutor, self.startMethod)
                        if len(jobs.running) >= self.numProc:
                            # Cannot start any more jobs; wait until
                            # something finishes.
                            break

            # Do cleanup for timed out jobs if necessary.
            jobs.cleanup()

            # Print a progress message if something changed.
            newFinished, newFailed = len(jobs.finishedNodes), len(jobs.failedNodes)
            if (finishedCount, failedCount) != (newFinished, newFailed):
                finishedCount, failedCount = newFinished, newFailed
                totalCount = len(jobs.jobs)
                _LOG.info(
                    "Executed %d quanta successfully, %d failed and %d remain out of total %d quanta.",
                    finishedCount,
                    failedCount,
                    totalCount - finishedCount - failedCount,
                    totalCount,
                )

            # Here we want to wait until one of the running jobs completes,
            # but multiprocessing does not provide an API for that; for now
            # just sleep a little bit and go back to the loop.
            if jobs.running:
                time.sleep(0.1)

        if jobs.failedNodes:
            # Print the list of failed jobs.
            _LOG.error("Failed jobs:")
            for job in jobs.jobs:
                if job.state != JobState.FINISHED:
                    _LOG.error("  - %s: %s", job.state.name, job)

            # If any job failed, raise an exception.
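            # Timed-out nodes are also recorded as failed, so equality means
            # every failure was a timeout; raise MPTimeoutError in that case
            # so callers can catch it specifically.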
            if jobs.failedNodes == jobs.timedOutNodes:
                raise MPTimeoutError("One or more tasks timed out during execution.")
            else:
                raise MPGraphExecutorError("One or more tasks failed or timed out during execution.")

    def getReport(self) -> Optional[Report]:
        # Docstring inherited from base class
        if self.report is None:
            raise RuntimeError("getReport() called before execute()")
        return self.report