# Coverage: python/lsst/ctrl/mpexec/mpGraphExecutor.py, 12% of 302 statements
# (coverage.py v6.5.0, created at 2023-03-25 02:09 -0700).

# This file is part of ctrl_mpexec.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

__all__ = ["MPGraphExecutor", "MPGraphExecutorError", "MPTimeoutError"]

import gc
import importlib
import logging
import multiprocessing
import pickle
import signal
import sys
import time
from collections.abc import Iterable
from enum import Enum
from typing import Literal, Optional

from lsst.daf.butler.cli.cliLog import CliLog
from lsst.pipe.base import InvalidQuantumError, TaskDef
from lsst.pipe.base.graph.graph import QuantumGraph, QuantumNode
from lsst.utils.threads import disable_implicit_threading

from .executionGraphFixup import ExecutionGraphFixup
from .quantumGraphExecutor import QuantumExecutor, QuantumGraphExecutor
from .reports import ExecutionStatus, QuantumReport, Report

_LOG = logging.getLogger(__name__)


# Possible states for the executing task:
# - PENDING: job has not started yet
# - RUNNING: job is currently executing
# - FINISHED: job finished successfully
# - FAILED: job execution failed (process returned non-zero status)
# - TIMED_OUT: job was killed because it exceeded the execution timeout
# - FAILED_DEP: one of the dependencies of this job has failed/timed out
JobState = Enum("JobState", "PENDING RUNNING FINISHED FAILED TIMED_OUT FAILED_DEP")


class _Job:
    """Class representing a job running a single task.

    Parameters
    ----------
    qnode : `~lsst.pipe.base.QuantumNode`
        Quantum and some associated information.
    """

    def __init__(self, qnode: QuantumNode):
        self.qnode = qnode
        self.process: Optional[multiprocessing.process.BaseProcess] = None
        self._state = JobState.PENDING
        self.started: float = 0.0
        self._rcv_conn: Optional[multiprocessing.connection.Connection] = None
        self._terminated = False

    @property
    def state(self) -> JobState:
        """Job processing state (`JobState`)."""
        return self._state

    @property
    def terminated(self) -> bool:
        """`True` if the job was killed by the `stop()` method and the child
        process returned a negative exit code (`bool`).
        """
        if self._terminated:
            assert self.process is not None, "Process must be started"
            if self.process.exitcode is not None:
                return self.process.exitcode < 0
        return False

    def start(
        self,
        quantumExecutor: QuantumExecutor,
        startMethod: Literal["spawn", "fork", "forkserver"] | None = None,
    ) -> None:
        """Start a process which runs the task.

        Parameters
        ----------
        quantumExecutor : `QuantumExecutor`
            Executor for a single quantum.
        startMethod : `str`, optional
            Start method from the `multiprocessing` module.
        """
        # Unpickling of the quantum has to happen after the butler/executor
        # are set up, which is why it is pickled manually here.
        quantum_pickle = pickle.dumps(self.qnode.quantum)
        taskDef = self.qnode.taskDef
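        # One-way pipe (duplex=False): the child process sends its report
        # back through ``snd_conn`` and the parent reads it from
        # ``self._rcv_conn``.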
        self._rcv_conn, snd_conn = multiprocessing.Pipe(False)
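        # Capture the current logging configuration so that a spawned child
        # process can replay it (see ``_executeJob``).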
        logConfigState = CliLog.configState

        mp_ctx = multiprocessing.get_context(startMethod)
        self.process = mp_ctx.Process(
            target=_Job._executeJob,
            args=(quantumExecutor, taskDef, quantum_pickle, logConfigState, snd_conn),
            name=f"task-{self.qnode.quantum.dataId}",
        )
        self.process.start()
        self.started = time.time()
        self._state = JobState.RUNNING

    @staticmethod
    def _executeJob(
        quantumExecutor: QuantumExecutor,
        taskDef: TaskDef,
        quantum_pickle: bytes,
        logConfigState: list,
        snd_conn: multiprocessing.connection.Connection,
    ) -> None:
        """Execute a job with arguments.

        Parameters
        ----------
        quantumExecutor : `QuantumExecutor`
            Executor for a single quantum.
        taskDef : `TaskDef`
            Task definition structure.
        quantum_pickle : `bytes`
            Quantum for this task execution in pickled form.
        logConfigState : `list`
            Logging configuration state to replay in the new process.
        snd_conn : `multiprocessing.connection.Connection`
            Connection to send the job report to the parent process.
        """
        if logConfigState and not CliLog.configState:
            # This means we are in a new spawned Python process and have to
            # re-initialize logging.
            CliLog.replayConfigState(logConfigState)

        quantum = pickle.loads(quantum_pickle)
        try:
            quantumExecutor.execute(taskDef, quantum)
        finally:
            # If sending fails we do not want this new exception to be
            # exposed.
            try:
                report = quantumExecutor.getReport()
                snd_conn.send(report)
            except Exception:
                pass

    def stop(self) -> None:
        """Stop the process."""
        assert self.process is not None, "Process must be started"
        self.process.terminate()
        # Give it one second to finish; if it is still alive after that,
        # kill it.
        for _ in range(10):
            time.sleep(0.1)
            if not self.process.is_alive():
                break
        else:
            _LOG.debug("Killing process %s", self.process.name)
            self.process.kill()
        self._terminated = True

    def cleanup(self) -> None:
        """Release process resources; has to be called for each finished
        process.
        """
        if self.process and not self.process.is_alive():
            self.process.close()
            self.process = None
            self._rcv_conn = None

    def report(self) -> QuantumReport:
        """Return the task report; should be called after the process
        finishes and before `cleanup()`.
        """
        assert self.process is not None, "Process must be started"
        assert self._rcv_conn is not None, "Process must be started"
        try:
            report = self._rcv_conn.recv()
            report.exitCode = self.process.exitcode
        except Exception:
            # Likely because the process was killed, but there may be other
            # reasons. Exit code should not be None; this is to keep mypy
            # happy.
            exitcode = self.process.exitcode if self.process.exitcode is not None else -1
            assert self.qnode.quantum.dataId is not None, "Quantum DataId cannot be None"
            report = QuantumReport.from_exit_code(
                exitCode=exitcode,
                dataId=self.qnode.quantum.dataId,
                taskLabel=self.qnode.taskDef.label,
            )
        if self.terminated:
            # Means it was killed; assume it was due to a timeout.
            report.status = ExecutionStatus.TIMEOUT
        return report

    def failMessage(self) -> str:
        """Return a message describing the task failure."""
        assert self.process is not None, "Process must be started"
        assert self.process.exitcode is not None, "Process has to finish"
        exitcode = self.process.exitcode
        if exitcode < 0:
            # A negative exit code means the process was killed by a signal.
            signum = -exitcode
            msg = f"Task {self} failed, killed by signal {signum}"
            # Just in case this is some very odd signal, expect ValueError.
            try:
                strsignal = signal.strsignal(signum)
                msg = f"{msg} ({strsignal})"
            except ValueError:
                pass
        elif exitcode > 0:
            msg = f"Task {self} failed, exit code={exitcode}"
        else:
            msg = ""
        return msg

    def __str__(self) -> str:
        return f"<{self.qnode.taskDef} dataId={self.qnode.quantum.dataId}>"


class _JobList:
    """Simple list of `_Job` instances with a few convenience methods.

    Parameters
    ----------
    iterable : iterable of `~lsst.pipe.base.QuantumNode`
        Sequence of Quanta to execute. This has to be ordered according to
        task dependencies.
    """

    def __init__(self, iterable: Iterable[QuantumNode]):
        self.jobs = [_Job(qnode) for qnode in iterable]
        self.pending = self.jobs[:]
        self.running: list[_Job] = []
        self.finishedNodes: set[QuantumNode] = set()
        self.failedNodes: set[QuantumNode] = set()
        self.timedOutNodes: set[QuantumNode] = set()

    def submit(
        self,
        job: _Job,
        quantumExecutor: QuantumExecutor,
        startMethod: Literal["spawn", "fork", "forkserver"] | None = None,
    ) -> None:
        """Submit one more job for execution.

        Parameters
        ----------
        job : `_Job`
            Job to submit.
        quantumExecutor : `QuantumExecutor`
            Executor for a single quantum.
        startMethod : `str`, optional
            Start method from the `multiprocessing` module.
        """
        # This raises ValueError if the job is not in the pending list.
        self.pending.remove(job)
        job.start(quantumExecutor, startMethod)
        self.running.append(job)

    def setJobState(self, job: _Job, state: JobState) -> None:
        """Update job state.

        Parameters
        ----------
        job : `_Job`
            Job to update.
        state : `JobState`
            New job state; note that only FINISHED, FAILED, TIMED_OUT, or
            FAILED_DEP is acceptable.
        """
        allowedStates = (JobState.FINISHED, JobState.FAILED, JobState.TIMED_OUT, JobState.FAILED_DEP)
        assert state in allowedStates, f"State {state} not allowed here"

        # Remove the job from the pending/running lists.
        if job.state == JobState.PENDING:
            self.pending.remove(job)
        elif job.state == JobState.RUNNING:
            self.running.remove(job)

        qnode = job.qnode
        # It should not be in any of these, but just in case.
        self.finishedNodes.discard(qnode)
        self.failedNodes.discard(qnode)
        self.timedOutNodes.discard(qnode)

        job._state = state
        if state == JobState.FINISHED:
            self.finishedNodes.add(qnode)
        elif state == JobState.FAILED:
            self.failedNodes.add(qnode)
        elif state == JobState.FAILED_DEP:
            self.failedNodes.add(qnode)
        elif state == JobState.TIMED_OUT:
            self.failedNodes.add(qnode)
            self.timedOutNodes.add(qnode)
        else:
            raise ValueError(f"Unexpected state value: {state}")

    def cleanup(self) -> None:
        """Do periodic cleanup for jobs that did not finish correctly.

        If timed-out jobs were killed but take too long to stop then regular
        cleanup will not work for them. Here we check all timed-out jobs
        periodically and do cleanup if they have managed to die by this time.
        """
        for job in self.jobs:
            if job.state == JobState.TIMED_OUT and job.process is not None:
                job.cleanup()


class MPGraphExecutorError(Exception):
    """Exception class for errors raised by MPGraphExecutor."""

    pass


class MPTimeoutError(MPGraphExecutorError):
    """Exception raised when task execution times out."""

    pass


class MPGraphExecutor(QuantumGraphExecutor):
    """Implementation of QuantumGraphExecutor using same-host multiprocess
    execution of Quanta.

    Parameters
    ----------
    numProc : `int`
        Number of processes to use for executing tasks.
    timeout : `float`
        Time in seconds to wait for tasks to finish.
    quantumExecutor : `QuantumExecutor`
        Executor for a single quantum. For multiprocess-style execution when
        ``numProc`` is greater than one this instance must support pickle.
    startMethod : `str`, optional
        Start method from the `multiprocessing` module; `None` selects the
        best one for the current platform.
    failFast : `bool`, optional
        If set to ``True`` then stop processing on the first error from any
        task.
    pdb : `str`, optional
        Debugger to import and use (via the ``post_mortem`` function) in the
        event of an exception.
    executionGraphFixup : `ExecutionGraphFixup`, optional
        Instance used for modification of the execution graph.
    """

    def __init__(
        self,
        numProc: int,
        timeout: float,
        quantumExecutor: QuantumExecutor,
        *,
        startMethod: Literal["spawn", "fork", "forkserver"] | None = None,
        failFast: bool = False,
        pdb: Optional[str] = None,
        executionGraphFixup: Optional[ExecutionGraphFixup] = None,
    ):
        self.numProc = numProc
        self.timeout = timeout
        self.quantumExecutor = quantumExecutor
        self.failFast = failFast
        self.pdb = pdb
        self.executionGraphFixup = executionGraphFixup
        self.report: Optional[Report] = None

        # Default start method is "spawn" for macOS and "fork" for Linux;
        # None for all other platforms, which uses the multiprocessing
        # default.
        if startMethod is None:
            methods = dict(linux="fork", darwin="spawn")
            startMethod = methods.get(sys.platform)  # type: ignore
        self.startMethod = startMethod

    def execute(self, graph: QuantumGraph) -> None:
        # Docstring inherited from QuantumGraphExecutor.execute
        graph = self._fixupQuanta(graph)
        self.report = Report()
        try:
            if self.numProc > 1:
                self._executeQuantaMP(graph, self.report)
            else:
                self._executeQuantaInProcess(graph, self.report)
        except Exception as exc:
            self.report.set_exception(exc)
            raise

    def _fixupQuanta(self, graph: QuantumGraph) -> QuantumGraph:
        """Call fixup code to modify the execution graph.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` to modify.

        Returns
        -------
        graph : `QuantumGraph`
            Modified `QuantumGraph`.

        Raises
        ------
        MPGraphExecutorError
            Raised if the execution graph cannot be ordered after
            modification, i.e. it has dependency cycles.
        """
        if not self.executionGraphFixup:
            return graph

        _LOG.debug("Call execution graph fixup method")
        graph = self.executionGraphFixup.fixupQuanta(graph)

        # Detect if there is now a cycle created within the graph.
        if graph.findCycle():
            raise MPGraphExecutorError("Updated execution graph has dependency cycle.")

        return graph

    def _executeQuantaInProcess(self, graph: QuantumGraph, report: Report) -> None:
        """Execute all Quanta in the current process.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` that is to be executed.
        report : `Report`
            Object for reporting execution status.
        """
        successCount, totalCount = 0, len(graph)
        failedNodes: set[QuantumNode] = set()
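        # Graph iteration is assumed to be topologically ordered, so all of
        # a quantum's input quanta are visited before the quantum itself
        # (see the ordering requirement in the `_JobList` docstring).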
        for qnode in graph:
            assert qnode.quantum.dataId is not None, "Quantum DataId cannot be None"

            # Any failed inputs mean that the quantum has to be skipped.
            inputNodes = graph.determineInputsToQuantumNode(qnode)
            if inputNodes & failedNodes:
                _LOG.error(
                    "Upstream job failed for task <%s dataId=%s>, skipping this task.",
                    qnode.taskDef,
                    qnode.quantum.dataId,
                )
                failedNodes.add(qnode)
                failed_quantum_report = QuantumReport(
                    status=ExecutionStatus.SKIPPED, dataId=qnode.quantum.dataId, taskLabel=qnode.taskDef.label
                )
                report.quantaReports.append(failed_quantum_report)
                continue

            _LOG.debug("Executing %s", qnode)
            try:
                self.quantumExecutor.execute(qnode.taskDef, qnode.quantum)
                successCount += 1
            except Exception as exc:
                if self.pdb and sys.stdin.isatty() and sys.stdout.isatty():
                    _LOG.error(
                        "Task <%s dataId=%s> failed; dropping into pdb.",
                        qnode.taskDef,
                        qnode.quantum.dataId,
                        exc_info=exc,
                    )
                    try:
                        pdb = importlib.import_module(self.pdb)
                    except ImportError as imp_exc:
                        raise MPGraphExecutorError(
                            f"Unable to import specified debugger module ({self.pdb}): {imp_exc}"
                        ) from exc
                    if not hasattr(pdb, "post_mortem"):
                        raise MPGraphExecutorError(
                            f"Specified debugger module ({self.pdb}) can't debug with post_mortem",
                        ) from exc
                    pdb.post_mortem(exc.__traceback__)
                failedNodes.add(qnode)
                report.status = ExecutionStatus.FAILURE
                if self.failFast:
                    raise MPGraphExecutorError(
                        f"Task <{qnode.taskDef} dataId={qnode.quantum.dataId}> failed."
                    ) from exc
                else:
                    # Note that there could be exception safety issues, which
                    # we presently ignore.
                    _LOG.error(
                        "Task <%s dataId=%s> failed; processing will continue for remaining tasks.",
                        qnode.taskDef,
                        qnode.quantum.dataId,
                        exc_info=exc,
                    )
            finally:
                # SQLAlchemy keeps some objects alive until a garbage
                # collection cycle runs, which can happen at unpredictable
                # times; run a collection explicitly here.
                gc.collect()

            quantum_report = self.quantumExecutor.getReport()
            if quantum_report:
                report.quantaReports.append(quantum_report)

        _LOG.info(
            "Executed %d quanta successfully, %d failed and %d remain out of total %d quanta.",
            successCount,
            len(failedNodes),
            totalCount - successCount - len(failedNodes),
            totalCount,
        )

        # Raise an exception if there were any failures.
        if failedNodes:
            raise MPGraphExecutorError("One or more tasks failed during execution.")

    def _executeQuantaMP(self, graph: QuantumGraph, report: Report) -> None:
        """Execute all Quanta in separate processes.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` that is to be executed.
        report : `Report`
            Object for reporting execution status.
        """
        disable_implicit_threading()  # To prevent thread contention

        _LOG.debug("Using %r for multiprocessing start method", self.startMethod)

        # Re-pack input quantum data into a job list.
        jobs = _JobList(graph)

        # Check that all tasks can run in a sub-process.
        for job in jobs.jobs:
            taskDef = job.qnode.taskDef
            if not taskDef.taskClass.canMultiprocess:
                raise MPGraphExecutorError(
                    f"Task {taskDef.taskName} does not support multiprocessing; use single process"
                )

        finishedCount, failedCount = 0, 0
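        # Each pass through the loop below: harvest finished processes, mark
        # pending jobs whose inputs failed as FAILED_DEP, start new jobs up
        # to ``numProc``, clean up timed-out jobs, and sleep briefly before
        # the next pass.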
        while jobs.pending or jobs.running:
            _LOG.debug("#pendingJobs: %s", len(jobs.pending))
            _LOG.debug("#runningJobs: %s", len(jobs.running))

            # See if any jobs have finished.
            for job in jobs.running:
                assert job.process is not None, "Process cannot be None"
                if not job.process.is_alive():
                    _LOG.debug("finished: %s", job)
                    # Finished.
                    exitcode = job.process.exitcode
                    quantum_report = job.report()
                    report.quantaReports.append(quantum_report)
                    if exitcode == 0:
                        jobs.setJobState(job, JobState.FINISHED)
                        job.cleanup()
                        _LOG.debug("success: %s took %.3f seconds", job, time.time() - job.started)
                    else:
                        if job.terminated:
                            # Was killed due to timeout.
                            if report.status == ExecutionStatus.SUCCESS:
                                # Do not override global FAILURE status.
                                report.status = ExecutionStatus.TIMEOUT
                            message = f"Timeout ({self.timeout} sec) for task {job}, task is killed"
                            jobs.setJobState(job, JobState.TIMED_OUT)
                        else:
                            report.status = ExecutionStatus.FAILURE
                            # failMessage() has to be called before cleanup().
                            message = job.failMessage()
                            jobs.setJobState(job, JobState.FAILED)

                        job.cleanup()
                        _LOG.debug("failed: %s", job)
                        if self.failFast or exitcode == InvalidQuantumError.EXIT_CODE:
                            # Stop all running jobs.
                            for stopJob in jobs.running:
                                if stopJob is not job:
                                    stopJob.stop()
                            if job.state is JobState.TIMED_OUT:
                                raise MPTimeoutError(f"Timeout ({self.timeout} sec) for task {job}.")
                            else:
                                raise MPGraphExecutorError(message)
                        else:
                            _LOG.error("%s; processing will continue for remaining tasks.", message)
                else:
                    # Check for timeout.
                    now = time.time()
                    if now - job.started > self.timeout:
                        # Try to kill it; there is a chance that it finishes
                        # successfully before it gets killed. Exit status is
                        # handled by the code above on the next iteration.
                        _LOG.debug("Terminating job %s due to timeout", job)
                        job.stop()

            # Fail jobs whose inputs failed. This may need several iterations
            # if the ordering is not right; in that case it will be handled
            # on the next pass through the loop.
            if jobs.failedNodes:
                for job in jobs.pending:
                    jobInputNodes = graph.determineInputsToQuantumNode(job.qnode)
                    assert job.qnode.quantum.dataId is not None, "Quantum DataId cannot be None"
                    if jobInputNodes & jobs.failedNodes:
                        quantum_report = QuantumReport(
                            status=ExecutionStatus.SKIPPED,
                            dataId=job.qnode.quantum.dataId,
                            taskLabel=job.qnode.taskDef.label,
                        )
                        report.quantaReports.append(quantum_report)
                        jobs.setJobState(job, JobState.FAILED_DEP)
                        _LOG.error("Upstream job failed for task %s, skipping this task.", job)

            # See if we can start more jobs.
            if len(jobs.running) < self.numProc:
                for job in jobs.pending:
                    jobInputNodes = graph.determineInputsToQuantumNode(job.qnode)
                    if jobInputNodes <= jobs.finishedNodes:
                        # All dependencies have completed; can start a new
                        # job.
                        if len(jobs.running) < self.numProc:
                            _LOG.debug("Submitting %s", job)
                            jobs.submit(job, self.quantumExecutor, self.startMethod)
                        if len(jobs.running) >= self.numProc:
                            # Cannot start any more jobs, wait until
                            # something finishes.
                            break

            # Do cleanup for timed-out jobs if necessary.
            jobs.cleanup()

            # Print a progress message if something changed.
            newFinished, newFailed = len(jobs.finishedNodes), len(jobs.failedNodes)
            if (finishedCount, failedCount) != (newFinished, newFailed):
                finishedCount, failedCount = newFinished, newFailed
                totalCount = len(jobs.jobs)
                _LOG.info(
                    "Executed %d quanta successfully, %d failed and %d remain out of total %d quanta.",
                    finishedCount,
                    failedCount,
                    totalCount - finishedCount - failedCount,
                    totalCount,
                )

            # Here we want to wait until one of the running jobs completes,
            # but multiprocessing does not provide an API for that; for now
            # just sleep a little bit and go back to the loop.
            if jobs.running:
                time.sleep(0.1)

        if jobs.failedNodes:
            # Print the list of failed jobs.
            _LOG.error("Failed jobs:")
            for job in jobs.jobs:
                if job.state != JobState.FINISHED:
                    _LOG.error(" - %s: %s", job.state.name, job)

            # If any job failed raise an exception; report MPTimeoutError
            # only when every failure was a timeout.
            if jobs.failedNodes == jobs.timedOutNodes:
                raise MPTimeoutError("One or more tasks timed out during execution.")
            else:
                raise MPGraphExecutorError("One or more tasks failed or timed out during execution.")

    def getReport(self) -> Optional[Report]:
        # Docstring inherited from base class
        if self.report is None:
            raise RuntimeError("getReport() called before execute()")
        return self.report
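

# A minimal usage sketch, not part of the module. It assumes a concrete
# `QuantumExecutor` implementation (here a hypothetical `my_quantum_executor`)
# and a pre-built `QuantumGraph` named `qgraph`; the package-level import
# path is also an assumption:
#
#     from lsst.ctrl.mpexec import MPGraphExecutor
#
#     executor = MPGraphExecutor(
#         numProc=4,  # run up to four quanta in parallel child processes
#         timeout=3600.0,  # kill any quantum running longer than an hour
#         quantumExecutor=my_quantum_executor,  # must be picklable when numProc > 1
#     )
#     executor.execute(qgraph)  # raises MPGraphExecutorError on any failure
#     report = executor.getReport()  # per-quantum QuantumReport entries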