Coverage for python/lsst/ctrl/mpexec/mpGraphExecutor.py: 12%
305 statements
« prev ^ index » next coverage.py v6.5.0, created at 2022-11-11 02:43 -0800
1# This file is part of ctrl_mpexec.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ["MPGraphExecutor", "MPGraphExecutorError", "MPTimeoutError"]
26import gc
27import importlib
28import logging
29import multiprocessing
30import pickle
31import signal
32import sys
33import time
34from enum import Enum
35from typing import TYPE_CHECKING, Iterable, Literal, Optional
37from lsst.daf.butler.cli.cliLog import CliLog
38from lsst.pipe.base import InvalidQuantumError, TaskDef
39from lsst.pipe.base.graph.graph import QuantumGraph, QuantumNode
40from lsst.utils.threads import disable_implicit_threading
42from .executionGraphFixup import ExecutionGraphFixup
43from .quantumGraphExecutor import QuantumExecutor, QuantumGraphExecutor
44from .reports import ExecutionStatus, QuantumReport, Report
if TYPE_CHECKING:
47 from lsst.daf.butler import Butler
49_LOG = logging.getLogger(__name__)
class JobState(Enum):
    """Possible states for an executing task's job."""

    PENDING = 1  # job has not started yet
    RUNNING = 2  # job is currently executing
    FINISHED = 3  # job finished successfully
    FAILED = 4  # job execution failed (process returned non-zero status)
    TIMED_OUT = 5  # job is killed due to too long execution time
    FAILED_DEP = 6  # one of the dependencies of this job has failed/timed out
class _Job:
    """Class representing a job running single task.

    Parameters
    ----------
    qnode: `~lsst.pipe.base.QuantumNode`
        Quantum and some associated information.
    """

    def __init__(self, qnode: QuantumNode):
        self.qnode = qnode
        self.process: Optional[multiprocessing.process.BaseProcess] = None
        self._state = JobState.PENDING
        self.started: float = 0.0
        self._rcv_conn: Optional[multiprocessing.connection.Connection] = None
        self._terminated = False

    @property
    def state(self) -> JobState:
        """Job processing state (JobState)"""
        return self._state

    @property
    def terminated(self) -> bool:
        """Return True if job was killed by stop() method and negative exit
        code is returned from child process. (`bool`)"""
        if self._terminated:
            assert self.process is not None, "Process must be started"
            if self.process.exitcode is not None:
                return self.process.exitcode < 0
        return False

    def start(
        self,
        butler: Butler,
        quantumExecutor: QuantumExecutor,
        startMethod: Literal["spawn"] | Literal["fork"] | Literal["forkserver"] | None = None,
    ) -> None:
        """Start process which runs the task.

        Parameters
        ----------
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        quantumExecutor : `QuantumExecutor`
            Executor for single quantum.
        startMethod : `str`, optional
            Start method from `multiprocessing` module.
        """
        # Unpickling of quantum has to happen after butler, this is why
        # it is pickled manually here.
        quantum_pickle = pickle.dumps(self.qnode.quantum)
        taskDef = self.qnode.taskDef
        self._rcv_conn, snd_conn = multiprocessing.Pipe(False)
        logConfigState = CliLog.configState
        mp_ctx = multiprocessing.get_context(startMethod)
        self.process = mp_ctx.Process(
            target=_Job._executeJob,
            args=(quantumExecutor, taskDef, quantum_pickle, butler, logConfigState, snd_conn),
            name=f"task-{self.qnode.quantum.dataId}",
        )
        self.process.start()
        self.started = time.time()
        self._state = JobState.RUNNING

    @staticmethod
    def _executeJob(
        quantumExecutor: QuantumExecutor,
        taskDef: TaskDef,
        quantum_pickle: bytes,
        butler: Butler,
        logConfigState: list,
        snd_conn: multiprocessing.connection.Connection,
    ) -> None:
        """Execute a job with arguments.

        This runs in the child process; its only channel back to the parent
        is ``snd_conn``.

        Parameters
        ----------
        quantumExecutor : `QuantumExecutor`
            Executor for single quantum.
        taskDef : `~lsst.pipe.base.TaskDef`
            Task definition structure.
        quantum_pickle : `bytes`
            Quantum for this task execution in pickled form.
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        logConfigState : `list`
            Logging configuration state captured in the parent process; used
            to re-initialize logging in a freshly spawned Python process.
        snd_conn : `multiprocessing.Connection`
            Connection to send job report to parent process.
        """
        if logConfigState and not CliLog.configState:
            # means that we are in a new spawned Python process and we have to
            # re-initialize logging
            CliLog.replayConfigState(logConfigState)

        # have to reset connection pool to avoid sharing connections
        if butler is not None:
            butler.registry.resetConnectionPool()

        quantum = pickle.loads(quantum_pickle)
        try:
            quantumExecutor.execute(taskDef, quantum, butler)
        finally:
            # If sending fails we do not want this new exception to be exposed.
            try:
                report = quantumExecutor.getReport()
                snd_conn.send(report)
            except Exception:
                pass

    def stop(self) -> None:
        """Stop the process."""
        assert self.process is not None, "Process must be started"
        self.process.terminate()
        # give it 1 second to finish or KILL
        for _ in range(10):
            time.sleep(0.1)
            if not self.process.is_alive():
                break
        else:
            _LOG.debug("Killing process %s", self.process.name)
            self.process.kill()
        self._terminated = True

    def cleanup(self) -> None:
        """Release processes resources, has to be called for each finished
        process.
        """
        if self.process and not self.process.is_alive():
            self.process.close()
            self.process = None
            self._rcv_conn = None

    def report(self) -> QuantumReport:
        """Return task report, should be called after process finishes and
        before cleanup().

        Returns
        -------
        report : `QuantumReport`
            Report received from the child process, or one synthesized from
            the exit code if the child did not send one (e.g. it was killed).
        """
        assert self.process is not None, "Process must be started"
        assert self._rcv_conn is not None, "Process must be started"
        try:
            report = self._rcv_conn.recv()
            report.exitCode = self.process.exitcode
        except Exception:
            # Likely due to the process killed, but there may be other reasons.
            # Exit code should not be None, this is to keep mypy happy.
            exitcode = self.process.exitcode if self.process.exitcode is not None else -1
            assert self.qnode.quantum.dataId is not None, "Quantum DataId cannot be None"
            report = QuantumReport.from_exit_code(
                exitCode=exitcode,
                dataId=self.qnode.quantum.dataId,
                taskLabel=self.qnode.taskDef.label,
            )
        if self.terminated:
            # Means it was killed, assume it's due to timeout
            report.status = ExecutionStatus.TIMEOUT
        return report

    def failMessage(self) -> str:
        """Return a message describing task failure"""
        assert self.process is not None, "Process must be started"
        assert self.process.exitcode is not None, "Process has to finish"
        exitcode = self.process.exitcode
        if exitcode < 0:
            # Negative exit code means it is killed by signal
            signum = -exitcode
            msg = f"Task {self} failed, killed by signal {signum}"
            # Just in case this is some very odd signal, expect ValueError
            try:
                strsignal = signal.strsignal(signum)
                msg = f"{msg} ({strsignal})"
            except ValueError:
                pass
        elif exitcode > 0:
            msg = f"Task {self} failed, exit code={exitcode}"
        else:
            msg = ""
        return msg

    def __str__(self) -> str:
        return f"<{self.qnode.taskDef} dataId={self.qnode.quantum.dataId}>"
class _JobList:
    """Simple list of _Job instances with few convenience methods.

    Parameters
    ----------
    iterable : iterable of `~lsst.pipe.base.QuantumNode`
        Sequence of Quanta to execute. This has to be ordered according to
        task dependencies.
    """

    def __init__(self, iterable: Iterable[QuantumNode]):
        self.jobs = [_Job(node) for node in iterable]
        # Every job starts out pending; nothing is running yet.
        self.pending = list(self.jobs)
        self.running: list[_Job] = []
        self.finishedNodes: set[QuantumNode] = set()
        self.failedNodes: set[QuantumNode] = set()
        self.timedOutNodes: set[QuantumNode] = set()

    def submit(
        self,
        job: _Job,
        butler: Butler,
        quantumExecutor: QuantumExecutor,
        startMethod: Literal["spawn"] | Literal["fork"] | Literal["forkserver"] | None = None,
    ) -> None:
        """Submit one more job for execution

        Parameters
        ----------
        job : `_Job`
            Job to submit.
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        quantumExecutor : `QuantumExecutor`
            Executor for single quantum.
        startMethod : `str`, optional
            Start method from `multiprocessing` module.
        """
        # Raises ValueError if the job was never pending; that would be a
        # logic error in the caller.
        self.pending.remove(job)
        job.start(butler, quantumExecutor, startMethod)
        self.running.append(job)

    def setJobState(self, job: _Job, state: JobState) -> None:
        """Update job state.

        Parameters
        ----------
        job : `_Job`
            Job to submit.
        state : `JobState`
            New job state, note that only FINISHED, FAILED, TIMED_OUT, or
            FAILED_DEP state is acceptable.
        """
        allowedStates = (JobState.FINISHED, JobState.FAILED, JobState.TIMED_OUT, JobState.FAILED_DEP)
        assert state in allowedStates, f"State {state} not allowed here"

        # Drop the job from whichever queue currently holds it.
        if job.state is JobState.PENDING:
            self.pending.remove(job)
        elif job.state is JobState.RUNNING:
            self.running.remove(job)

        node = job.qnode
        # Defensive: the node should not be recorded anywhere yet, but make
        # sure it appears in at most one terminal set below.
        for nodeSet in (self.finishedNodes, self.failedNodes, self.timedOutNodes):
            nodeSet.discard(node)

        job._state = state
        if state is JobState.FINISHED:
            self.finishedNodes.add(node)
        elif state in (JobState.FAILED, JobState.FAILED_DEP):
            self.failedNodes.add(node)
        elif state is JobState.TIMED_OUT:
            # Timed-out jobs count as failed too.
            self.failedNodes.add(node)
            self.timedOutNodes.add(node)
        else:
            raise ValueError(f"Unexpected state value: {state}")

    def cleanup(self) -> None:
        """Do periodic cleanup for jobs that did not finish correctly.

        If timed out jobs are killed but take too long to stop then regular
        cleanup will not work for them. Here we check all timed out jobs
        periodically and do cleanup if they managed to die by this time.
        """
        for candidate in self.jobs:
            if candidate.process is not None and candidate.state is JobState.TIMED_OUT:
                candidate.cleanup()
class MPGraphExecutorError(Exception):
    """Exception class for errors raised by MPGraphExecutor."""
class MPTimeoutError(MPGraphExecutorError):
    """Exception raised when task execution times out."""
class MPGraphExecutor(QuantumGraphExecutor):
    """Implementation of QuantumGraphExecutor using same-host multiprocess
    execution of Quanta.

    Parameters
    ----------
    numProc : `int`
        Number of processes to use for executing tasks.
    timeout : `float`
        Time in seconds to wait for tasks to finish.
    quantumExecutor : `QuantumExecutor`
        Executor for single quantum. For multiprocess-style execution when
        ``numProc`` is greater than one this instance must support pickle.
    startMethod : `str`, optional
        Start method from `multiprocessing` module, `None` selects the best
        one for current platform.
    failFast : `bool`, optional
        If set to ``True`` then stop processing on first error from any task.
    pdb : `str`, optional
        Debugger to import and use (via the ``post_mortem`` function) in the
        event of an exception.
    executionGraphFixup : `ExecutionGraphFixup`, optional
        Instance used for modification of execution graph.
    """

    def __init__(
        self,
        numProc: int,
        timeout: float,
        quantumExecutor: QuantumExecutor,
        *,
        startMethod: Literal["spawn"] | Literal["fork"] | Literal["forkserver"] | None = None,
        failFast: bool = False,
        pdb: Optional[str] = None,
        executionGraphFixup: Optional[ExecutionGraphFixup] = None,
    ):
        self.numProc = numProc
        self.timeout = timeout
        self.quantumExecutor = quantumExecutor
        self.failFast = failFast
        self.pdb = pdb
        self.executionGraphFixup = executionGraphFixup
        self.report: Optional[Report] = None

        # We set default start method as spawn for MacOS and fork for Linux;
        # None for all other platforms to use multiprocessing default.
        if startMethod is None:
            methods = dict(linux="fork", darwin="spawn")
            startMethod = methods.get(sys.platform)  # type: ignore
        self.startMethod = startMethod

    def execute(self, graph: QuantumGraph, butler: Butler) -> None:
        # Docstring inherited from QuantumGraphExecutor.execute
        graph = self._fixupQuanta(graph)
        self.report = Report()
        try:
            if self.numProc > 1:
                self._executeQuantaMP(graph, butler, self.report)
            else:
                self._executeQuantaInProcess(graph, butler, self.report)
        except Exception as exc:
            self.report.set_exception(exc)
            raise

    def _fixupQuanta(self, graph: QuantumGraph) -> QuantumGraph:
        """Call fixup code to modify execution graph.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` to modify

        Returns
        -------
        graph : `QuantumGraph`
            Modified `QuantumGraph`.

        Raises
        ------
        MPGraphExecutorError
            Raised if execution graph cannot be ordered after modification,
            i.e. it has dependency cycles.
        """
        if not self.executionGraphFixup:
            return graph

        _LOG.debug("Call execution graph fixup method")
        graph = self.executionGraphFixup.fixupQuanta(graph)

        # Detect if there is now a cycle created within the graph
        if graph.findCycle():
            raise MPGraphExecutorError("Updated execution graph has dependency cycle.")

        return graph

    def _executeQuantaInProcess(self, graph: QuantumGraph, butler: Butler, report: Report) -> None:
        """Execute all Quanta in current process.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` that is to be executed
        butler : `lsst.daf.butler.Butler`
            Data butler instance
        report : `Report`
            Object for reporting execution status.
        """
        successCount, totalCount = 0, len(graph)
        failedNodes: set[QuantumNode] = set()
        for qnode in graph:

            assert qnode.quantum.dataId is not None, "Quantum DataId cannot be None"

            # Any failed inputs mean that the quantum has to be skipped.
            inputNodes = graph.determineInputsToQuantumNode(qnode)
            if inputNodes & failedNodes:
                _LOG.error(
                    "Upstream job failed for task <%s dataId=%s>, skipping this task.",
                    qnode.taskDef,
                    qnode.quantum.dataId,
                )
                failedNodes.add(qnode)
                failed_quantum_report = QuantumReport(
                    status=ExecutionStatus.SKIPPED, dataId=qnode.quantum.dataId, taskLabel=qnode.taskDef.label
                )
                report.quantaReports.append(failed_quantum_report)
                continue

            _LOG.debug("Executing %s", qnode)
            try:
                self.quantumExecutor.execute(qnode.taskDef, qnode.quantum, butler)
                successCount += 1
            except Exception as exc:
                if self.pdb and sys.stdin.isatty() and sys.stdout.isatty():
                    _LOG.error(
                        "Task <%s dataId=%s> failed; dropping into pdb.",
                        qnode.taskDef,
                        qnode.quantum.dataId,
                        exc_info=exc,
                    )
                    try:
                        pdb = importlib.import_module(self.pdb)
                    except ImportError as imp_exc:
                        raise MPGraphExecutorError(
                            f"Unable to import specified debugger module ({self.pdb}): {imp_exc}"
                        ) from exc
                    if not hasattr(pdb, "post_mortem"):
                        raise MPGraphExecutorError(
                            f"Specified debugger module ({self.pdb}) can't debug with post_mortem",
                        ) from exc
                    pdb.post_mortem(exc.__traceback__)
                failedNodes.add(qnode)
                report.status = ExecutionStatus.FAILURE
                if self.failFast:
                    raise MPGraphExecutorError(
                        f"Task <{qnode.taskDef} dataId={qnode.quantum.dataId}> failed."
                    ) from exc
                else:
                    # Note that there could be exception safety issues, which
                    # we presently ignore.
                    _LOG.error(
                        "Task <%s dataId=%s> failed; processing will continue for remaining tasks.",
                        qnode.taskDef,
                        qnode.quantum.dataId,
                        exc_info=exc,
                    )
            finally:
                # sqlalchemy has some objects that can last until a garbage
                # collection cycle is run, which can happen at unpredictable
                # times, run a collection loop here explicitly.
                gc.collect()

            quantum_report = self.quantumExecutor.getReport()
            if quantum_report:
                report.quantaReports.append(quantum_report)

        _LOG.info(
            "Executed %d quanta successfully, %d failed and %d remain out of total %d quanta.",
            successCount,
            len(failedNodes),
            totalCount - successCount - len(failedNodes),
            totalCount,
        )

        # Raise an exception if there were any failures.
        if failedNodes:
            raise MPGraphExecutorError("One or more tasks failed during execution.")

    def _executeQuantaMP(self, graph: QuantumGraph, butler: Butler, report: Report) -> None:
        """Execute all Quanta in separate processes.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` that is to be executed.
        butler : `lsst.daf.butler.Butler`
            Data butler instance
        report : `Report`
            Object for reporting execution status.
        """

        disable_implicit_threading()  # To prevent thread contention

        _LOG.debug("Using %r for multiprocessing start method", self.startMethod)

        # re-pack input quantum data into jobs list
        jobs = _JobList(graph)

        # check that all tasks can run in sub-process
        for job in jobs.jobs:
            taskDef = job.qnode.taskDef
            if not taskDef.taskClass.canMultiprocess:
                raise MPGraphExecutorError(
                    f"Task {taskDef.taskName} does not support multiprocessing; use single process"
                )

        finishedCount, failedCount = 0, 0
        while jobs.pending or jobs.running:

            _LOG.debug("#pendingJobs: %s", len(jobs.pending))
            _LOG.debug("#runningJobs: %s", len(jobs.running))

            # See if any jobs have finished. Iterate over a snapshot because
            # setJobState() removes the job from jobs.running in place;
            # mutating the list while iterating it would silently skip the
            # next running job until a later pass of the outer loop.
            for job in list(jobs.running):
                assert job.process is not None, "Process cannot be None"
                if not job.process.is_alive():
                    _LOG.debug("finished: %s", job)
                    # finished
                    exitcode = job.process.exitcode
                    quantum_report = job.report()
                    report.quantaReports.append(quantum_report)
                    if exitcode == 0:
                        jobs.setJobState(job, JobState.FINISHED)
                        job.cleanup()
                        _LOG.debug("success: %s took %.3f seconds", job, time.time() - job.started)
                    else:
                        if job.terminated:
                            # Was killed due to timeout.
                            if report.status == ExecutionStatus.SUCCESS:
                                # Do not override global FAILURE status
                                report.status = ExecutionStatus.TIMEOUT
                            message = f"Timeout ({self.timeout} sec) for task {job}, task is killed"
                            jobs.setJobState(job, JobState.TIMED_OUT)
                        else:
                            report.status = ExecutionStatus.FAILURE
                            # failMessage() has to be called before cleanup()
                            message = job.failMessage()
                            jobs.setJobState(job, JobState.FAILED)

                        job.cleanup()
                        _LOG.debug("failed: %s", job)
                        if self.failFast or exitcode == InvalidQuantumError.EXIT_CODE:
                            # stop all running jobs
                            for stopJob in jobs.running:
                                if stopJob is not job:
                                    stopJob.stop()
                            if job.state is JobState.TIMED_OUT:
                                raise MPTimeoutError(f"Timeout ({self.timeout} sec) for task {job}.")
                            else:
                                raise MPGraphExecutorError(message)
                        else:
                            _LOG.error("%s; processing will continue for remaining tasks.", message)
                else:
                    # check for timeout
                    now = time.time()
                    if now - job.started > self.timeout:
                        # Try to kill it, and there is a chance that it
                        # finishes successfully before it gets killed. Exit
                        # status is handled by the code above on next
                        # iteration.
                        _LOG.debug("Terminating job %s due to timeout", job)
                        job.stop()

            # Fail jobs whose inputs failed, this may need several iterations
            # if the order is not right, will be done in the next loop.
            if jobs.failedNodes:
                # Iterate over a copy: setJobState() removes the job from
                # jobs.pending, and removing from the list being iterated
                # would skip entries.
                for job in list(jobs.pending):
                    jobInputNodes = graph.determineInputsToQuantumNode(job.qnode)
                    assert job.qnode.quantum.dataId is not None, "Quantum DataId cannot be None"
                    if jobInputNodes & jobs.failedNodes:
                        quantum_report = QuantumReport(
                            status=ExecutionStatus.SKIPPED,
                            dataId=job.qnode.quantum.dataId,
                            taskLabel=job.qnode.taskDef.label,
                        )
                        report.quantaReports.append(quantum_report)
                        jobs.setJobState(job, JobState.FAILED_DEP)
                        _LOG.error("Upstream job failed for task %s, skipping this task.", job)

            # see if we can start more jobs
            if len(jobs.running) < self.numProc:
                # Iterate over a copy because jobs.submit() removes the job
                # from jobs.pending while we iterate.
                for job in list(jobs.pending):
                    jobInputNodes = graph.determineInputsToQuantumNode(job.qnode)
                    if jobInputNodes <= jobs.finishedNodes:
                        # all dependencies have completed, can start new job
                        if len(jobs.running) < self.numProc:
                            _LOG.debug("Submitting %s", job)
                            jobs.submit(job, butler, self.quantumExecutor, self.startMethod)
                        if len(jobs.running) >= self.numProc:
                            # Cannot start any more jobs, wait until something
                            # finishes.
                            break

            # Do cleanup for timed out jobs if necessary.
            jobs.cleanup()

            # Print progress message if something changed.
            newFinished, newFailed = len(jobs.finishedNodes), len(jobs.failedNodes)
            if (finishedCount, failedCount) != (newFinished, newFailed):
                finishedCount, failedCount = newFinished, newFailed
                totalCount = len(jobs.jobs)
                _LOG.info(
                    "Executed %d quanta successfully, %d failed and %d remain out of total %d quanta.",
                    finishedCount,
                    failedCount,
                    totalCount - finishedCount - failedCount,
                    totalCount,
                )

            # Here we want to wait until one of the running jobs completes
            # but multiprocessing does not provide an API for that, for now
            # just sleep a little bit and go back to the loop.
            if jobs.running:
                time.sleep(0.1)

        if jobs.failedNodes:
            # print list of failed jobs
            _LOG.error("Failed jobs:")
            for job in jobs.jobs:
                if job.state != JobState.FINISHED:
                    _LOG.error("  - %s: %s", job.state.name, job)

            # if any job failed raise an exception
            if jobs.failedNodes == jobs.timedOutNodes:
                raise MPTimeoutError("One or more tasks timed out during execution.")
            else:
                raise MPGraphExecutorError("One or more tasks failed or timed out during execution.")

    def getReport(self) -> Optional[Report]:
        # Docstring inherited from base class
        if self.report is None:
            raise RuntimeError("getReport() called before execute()")
        return self.report