Coverage for python/lsst/ctrl/mpexec/mpGraphExecutor.py: 12%
305 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-10-26 15:42 +0000
1# This file is part of ctrl_mpexec.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ["MPGraphExecutor", "MPGraphExecutorError", "MPTimeoutError"]
26import gc
27import importlib
28import logging
29import multiprocessing
30import pickle
31import signal
32import sys
33import time
34from enum import Enum
35from typing import TYPE_CHECKING, Iterable, Literal, Optional
37from lsst.daf.butler.cli.cliLog import CliLog
38from lsst.pipe.base import InvalidQuantumError, TaskDef
39from lsst.pipe.base.graph.graph import QuantumGraph, QuantumNode
40from lsst.utils.threads import disable_implicit_threading
42from .executionGraphFixup import ExecutionGraphFixup
43from .quantumGraphExecutor import QuantumExecutor, QuantumGraphExecutor
44from .reports import ExecutionStatus, QuantumReport, Report
46if TYPE_CHECKING:
47 from lsst.daf.butler import Butler
49_LOG = logging.getLogger(__name__)
class JobState(Enum):
    """Possible states for the executing task.

    - PENDING: job has not started yet
    - RUNNING: job is currently executing
    - FINISHED: job finished successfully
    - FAILED: job execution failed (process returned non-zero status)
    - TIMED_OUT: job is killed due to too long execution time
    - FAILED_DEP: one of the dependencies of this job has failed/timed out
    """

    PENDING = 1
    RUNNING = 2
    FINISHED = 3
    FAILED = 4
    TIMED_OUT = 5
    FAILED_DEP = 6
class _Job:
    """Class representing a job running single task.

    Parameters
    ----------
    qnode: `~lsst.pipe.base.QuantumNode`
        Quantum and some associated information.
    """

    def __init__(self, qnode: QuantumNode):
        self.qnode = qnode
        # Child process executing the quantum; created by start().
        self.process: Optional[multiprocessing.process.BaseProcess] = None
        self._state = JobState.PENDING
        # Wall-clock time when start() launched the process.
        self.started: float = 0.0
        # Receiving end of the pipe used by the child to return its report.
        self._rcv_conn: Optional[multiprocessing.connection.Connection] = None
        # Set by stop() once the child was asked to terminate.
        self._terminated = False

    @property
    def state(self) -> JobState:
        """Job processing state (JobState)"""
        return self._state

    @property
    def terminated(self) -> bool:
        """Return True if job was killed by stop() method and negative exit
        code is returned from child process. (`bool`)"""
        if not self._terminated:
            return False
        assert self.process is not None, "Process must be started"
        exitcode = self.process.exitcode
        # exitcode is None while the process is still dying.
        return exitcode is not None and exitcode < 0

    def start(
        self,
        butler: Butler,
        quantumExecutor: QuantumExecutor,
        startMethod: Literal["spawn"] | Literal["fork"] | Literal["forkserver"] | None = None,
    ) -> None:
        """Start process which runs the task.

        Parameters
        ----------
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        quantumExecutor : `QuantumExecutor`
            Executor for single quantum.
        startMethod : `str`, optional
            Start method from `multiprocessing` module.
        """
        # The quantum must be unpickled in the child after the butler is set
        # up there, so it is shipped as an explicit pickle blob instead of a
        # plain argument.
        pickled_quantum = pickle.dumps(self.qnode.quantum)
        task_def = self.qnode.taskDef
        self._rcv_conn, send_conn = multiprocessing.Pipe(False)
        log_config_state = CliLog.configState
        context = multiprocessing.get_context(startMethod)
        self.process = context.Process(
            target=_Job._executeJob,
            args=(quantumExecutor, task_def, pickled_quantum, butler, log_config_state, send_conn),
            name=f"task-{self.qnode.quantum.dataId}",
        )
        self.process.start()
        self.started = time.time()
        self._state = JobState.RUNNING

    @staticmethod
    def _executeJob(
        quantumExecutor: QuantumExecutor,
        taskDef: TaskDef,
        quantum_pickle: bytes,
        butler: Butler,
        logConfigState: list,
        snd_conn: multiprocessing.connection.Connection,
    ) -> None:
        """Execute a job with arguments; runs in the child process.

        Parameters
        ----------
        quantumExecutor : `QuantumExecutor`
            Executor for single quantum.
        taskDef : `~lsst.pipe.base.TaskDef`
            Task definition structure.
        quantum_pickle : `bytes`
            Quantum for this task execution in pickled form.
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        logConfigState : `list`
            Logging configuration captured in the parent process.
        snd_conn : `multiprocessing.connection.Connection`
            Connection to send job report to parent process.
        """
        if logConfigState and not CliLog.configState:
            # We are in a freshly spawned Python interpreter; replay the
            # parent's logging configuration.
            CliLog.replayConfigState(logConfigState)

        # Reset connection pool so the child does not share database
        # connections with the parent process.
        if butler is not None:
            butler.registry.resetConnectionPool()

        quantum = pickle.loads(quantum_pickle)
        try:
            quantumExecutor.execute(taskDef, quantum, butler)
        finally:
            # A failure while sending the report must not mask an exception
            # raised by execute() above.
            try:
                snd_conn.send(quantumExecutor.getReport())
            except Exception:
                pass

    def stop(self) -> None:
        """Stop the process."""
        assert self.process is not None, "Process must be started"
        self.process.terminate()
        # Give it up to one second to exit gracefully, then KILL.
        for _ in range(10):
            time.sleep(0.1)
            if not self.process.is_alive():
                break
        else:
            _LOG.debug("Killing process %s", self.process.name)
            self.process.kill()
        self._terminated = True

    def cleanup(self) -> None:
        """Release processes resources, has to be called for each finished
        process.
        """
        if self.process is not None and not self.process.is_alive():
            self.process.close()
            self.process = None
            self._rcv_conn = None

    def report(self) -> QuantumReport:
        """Return task report, should be called after process finishes and
        before cleanup().
        """
        assert self.process is not None, "Process must be started"
        assert self._rcv_conn is not None, "Process must be started"
        try:
            report = self._rcv_conn.recv()
            report.exitCode = self.process.exitcode
        except Exception:
            # Most likely the child was killed before it could send a report;
            # synthesize one from the exit code. Use -1 when the exit code is
            # unknown (also keeps mypy happy about Optional).
            exitcode = -1 if self.process.exitcode is None else self.process.exitcode
            assert self.qnode.quantum.dataId is not None, "Quantum DataId cannot be None"
            report = QuantumReport.from_exit_code(
                exitCode=exitcode,
                dataId=self.qnode.quantum.dataId,
                taskLabel=self.qnode.taskDef.label,
            )
        if self.terminated:
            # Killed by stop(); assume that was due to a timeout.
            report.status = ExecutionStatus.TIMEOUT
        return report

    def failMessage(self) -> str:
        """Return a message describing task failure"""
        assert self.process is not None, "Process must be started"
        assert self.process.exitcode is not None, "Process has to finish"
        exitcode = self.process.exitcode
        if exitcode == 0:
            return ""
        if exitcode > 0:
            return f"Task {self} failed, exit code={exitcode}"
        # Negative exit code means it is killed by signal
        signum = -exitcode
        msg = f"Task {self} failed, killed by signal {signum}"
        # Just in case this is some very odd signal, expect ValueError
        try:
            strsignal = signal.strsignal(signum)
            msg = f"{msg} ({strsignal})"
        except ValueError:
            pass
        return msg

    def __str__(self) -> str:
        return f"<{self.qnode.taskDef} dataId={self.qnode.quantum.dataId}>"
class _JobList:
    """Simple list of _Job instances with few convenience methods.

    Parameters
    ----------
    iterable : iterable of `~lsst.pipe.base.QuantumNode`
        Sequence of Quanta to execute. This has to be ordered according to
        task dependencies.
    """

    def __init__(self, iterable: Iterable[QuantumNode]):
        self.jobs = [_Job(qnode) for qnode in iterable]
        # Every job starts out pending; it moves to running via submit() and
        # leaves running via setJobState().
        self.pending = list(self.jobs)
        self.running: list[_Job] = []
        self.finishedNodes: set[QuantumNode] = set()
        self.failedNodes: set[QuantumNode] = set()
        self.timedOutNodes: set[QuantumNode] = set()

    def submit(
        self,
        job: _Job,
        butler: Butler,
        quantumExecutor: QuantumExecutor,
        startMethod: Literal["spawn"] | Literal["fork"] | Literal["forkserver"] | None = None,
    ) -> None:
        """Submit one more job for execution

        Parameters
        ----------
        job : `_Job`
            Job to submit.
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        quantumExecutor : `QuantumExecutor`
            Executor for single quantum.
        startMethod : `str`, optional
            Start method from `multiprocessing` module.
        """
        # Raises ValueError if job is not actually in the pending list.
        self.pending.remove(job)
        job.start(butler, quantumExecutor, startMethod)
        self.running.append(job)

    def setJobState(self, job: _Job, state: JobState) -> None:
        """Update job state.

        Parameters
        ----------
        job : `_Job`
            Job to update.
        state : `JobState`
            New job state, note that only FINISHED, FAILED, TIMED_OUT, or
            FAILED_DEP state is acceptable.
        """
        allowedStates = (JobState.FINISHED, JobState.FAILED, JobState.TIMED_OUT, JobState.FAILED_DEP)
        assert state in allowedStates, f"State {state} not allowed here"

        # Drop the job from whichever queue currently holds it.
        if job.state == JobState.PENDING:
            self.pending.remove(job)
        elif job.state == JobState.RUNNING:
            self.running.remove(job)

        qnode = job.qnode
        # It should not be in any of these sets yet, but make sure anyway.
        for nodes in (self.finishedNodes, self.failedNodes, self.timedOutNodes):
            nodes.discard(qnode)

        job._state = state
        if state == JobState.FINISHED:
            self.finishedNodes.add(qnode)
        elif state in (JobState.FAILED, JobState.FAILED_DEP):
            self.failedNodes.add(qnode)
        elif state == JobState.TIMED_OUT:
            # A timed-out job counts as failed as well.
            self.failedNodes.add(qnode)
            self.timedOutNodes.add(qnode)
        else:
            raise ValueError(f"Unexpected state value: {state}")

    def cleanup(self) -> None:
        """Do periodic cleanup for jobs that did not finish correctly.

        If timed out jobs are killed but take too long to stop then regular
        cleanup will not work for them. Here we check all timed out jobs
        periodically and do cleanup if they managed to die by this time.
        """
        for job in self.jobs:
            if job.state is JobState.TIMED_OUT and job.process is not None:
                job.cleanup()
class MPGraphExecutorError(Exception):
    """Exception class for errors raised by MPGraphExecutor."""
class MPTimeoutError(MPGraphExecutorError):
    """Exception raised when task execution times out."""
class MPGraphExecutor(QuantumGraphExecutor):
    """Implementation of QuantumGraphExecutor using same-host multiprocess
    execution of Quanta.

    Parameters
    ----------
    numProc : `int`
        Number of processes to use for executing tasks.
    timeout : `float`
        Time in seconds to wait for tasks to finish.
    quantumExecutor : `QuantumExecutor`
        Executor for single quantum. For multiprocess-style execution when
        ``numProc`` is greater than one this instance must support pickle.
    startMethod : `str`, optional
        Start method from `multiprocessing` module, `None` selects the best
        one for current platform.
    failFast : `bool`, optional
        If set to ``True`` then stop processing on first error from any task.
    pdb : `str`, optional
        Debugger to import and use (via the ``post_mortem`` function) in the
        event of an exception.
    executionGraphFixup : `ExecutionGraphFixup`, optional
        Instance used for modification of execution graph.
    """

    def __init__(
        self,
        numProc: int,
        timeout: float,
        quantumExecutor: QuantumExecutor,
        *,
        startMethod: Literal["spawn"] | Literal["fork"] | Literal["forkserver"] | None = None,
        failFast: bool = False,
        pdb: Optional[str] = None,
        executionGraphFixup: Optional[ExecutionGraphFixup] = None,
    ):
        self.numProc = numProc
        self.timeout = timeout
        self.quantumExecutor = quantumExecutor
        self.failFast = failFast
        self.pdb = pdb
        self.executionGraphFixup = executionGraphFixup
        self.report: Optional[Report] = None

        # We set default start method as spawn for MacOS and fork for Linux;
        # None for all other platforms to use multiprocessing default.
        if startMethod is None:
            methods = dict(linux="fork", darwin="spawn")
            startMethod = methods.get(sys.platform)  # type: ignore
        self.startMethod = startMethod

    def execute(self, graph: QuantumGraph, butler: Butler) -> None:
        # Docstring inherited from QuantumGraphExecutor.execute
        graph = self._fixupQuanta(graph)
        self.report = Report()
        try:
            if self.numProc > 1:
                self._executeQuantaMP(graph, butler, self.report)
            else:
                self._executeQuantaInProcess(graph, butler, self.report)
        except Exception as exc:
            # Record the exception in the report before re-raising so callers
            # of getReport() can see what went wrong.
            self.report.set_exception(exc)
            raise

    def _fixupQuanta(self, graph: QuantumGraph) -> QuantumGraph:
        """Call fixup code to modify execution graph.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` to modify

        Returns
        -------
        graph : `QuantumGraph`
            Modified `QuantumGraph`.

        Raises
        ------
        MPGraphExecutorError
            Raised if execution graph cannot be ordered after modification,
            i.e. it has dependency cycles.
        """
        if not self.executionGraphFixup:
            return graph

        _LOG.debug("Call execution graph fixup method")
        graph = self.executionGraphFixup.fixupQuanta(graph)

        # Detect if there is now a cycle created within the graph
        if graph.findCycle():
            raise MPGraphExecutorError("Updated execution graph has dependency cycle.")

        return graph

    def _executeQuantaInProcess(self, graph: QuantumGraph, butler: Butler, report: Report) -> None:
        """Execute all Quanta in current process.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` that is to be executed
        butler : `lsst.daf.butler.Butler`
            Data butler instance
        report : `Report`
            Object for reporting execution status.
        """
        successCount, totalCount = 0, len(graph)
        failedNodes: set[QuantumNode] = set()
        for qnode in graph:
            assert qnode.quantum.dataId is not None, "Quantum DataId cannot be None"

            # Any failed inputs mean that the quantum has to be skipped.
            inputNodes = graph.determineInputsToQuantumNode(qnode)
            if inputNodes & failedNodes:
                _LOG.error(
                    "Upstream job failed for task <%s dataId=%s>, skipping this task.",
                    qnode.taskDef,
                    qnode.quantum.dataId,
                )
                failedNodes.add(qnode)
                failed_quantum_report = QuantumReport(
                    status=ExecutionStatus.SKIPPED, dataId=qnode.quantum.dataId, taskLabel=qnode.taskDef.label
                )
                report.quantaReports.append(failed_quantum_report)
                continue

            _LOG.debug("Executing %s", qnode)
            try:
                self.quantumExecutor.execute(qnode.taskDef, qnode.quantum, butler)
                successCount += 1
            except Exception as exc:
                if self.pdb and sys.stdin.isatty() and sys.stdout.isatty():
                    # Drop into the configured debugger only when attached to
                    # an interactive terminal.
                    _LOG.error(
                        "Task <%s dataId=%s> failed; dropping into pdb.",
                        qnode.taskDef,
                        qnode.quantum.dataId,
                        exc_info=exc,
                    )
                    try:
                        pdb = importlib.import_module(self.pdb)
                    except ImportError as imp_exc:
                        raise MPGraphExecutorError(
                            f"Unable to import specified debugger module ({self.pdb}): {imp_exc}"
                        ) from exc
                    if not hasattr(pdb, "post_mortem"):
                        raise MPGraphExecutorError(
                            f"Specified debugger module ({self.pdb}) can't debug with post_mortem",
                        ) from exc
                    pdb.post_mortem(exc.__traceback__)
                failedNodes.add(qnode)
                report.status = ExecutionStatus.FAILURE
                if self.failFast:
                    raise MPGraphExecutorError(
                        f"Task <{qnode.taskDef} dataId={qnode.quantum.dataId}> failed."
                    ) from exc
                else:
                    # Note that there could be exception safety issues, which
                    # we presently ignore.
                    _LOG.error(
                        "Task <%s dataId=%s> failed; processing will continue for remaining tasks.",
                        qnode.taskDef,
                        qnode.quantum.dataId,
                        exc_info=exc,
                    )
            finally:
                # sqlalchemy has some objects that can last until a garbage
                # collection cycle is run, which can happen at unpredictable
                # times, run a collection loop here explicitly.
                gc.collect()

            quantum_report = self.quantumExecutor.getReport()
            if quantum_report:
                report.quantaReports.append(quantum_report)

        _LOG.info(
            "Executed %d quanta successfully, %d failed and %d remain out of total %d quanta.",
            successCount,
            len(failedNodes),
            totalCount - successCount - len(failedNodes),
            totalCount,
        )

        # Raise an exception if there were any failures.
        if failedNodes:
            raise MPGraphExecutorError("One or more tasks failed during execution.")

    def _executeQuantaMP(self, graph: QuantumGraph, butler: Butler, report: Report) -> None:
        """Execute all Quanta in separate processes.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` that is to be executed.
        butler : `lsst.daf.butler.Butler`
            Data butler instance
        report : `Report`
            Object for reporting execution status.
        """
        disable_implicit_threading()  # To prevent thread contention

        _LOG.debug("Using %r for multiprocessing start method", self.startMethod)

        # re-pack input quantum data into jobs list
        jobs = _JobList(graph)

        # check that all tasks can run in sub-process
        for job in jobs.jobs:
            taskDef = job.qnode.taskDef
            if not taskDef.taskClass.canMultiprocess:
                raise MPGraphExecutorError(
                    f"Task {taskDef.taskName} does not support multiprocessing; use single process"
                )

        finishedCount, failedCount = 0, 0
        while jobs.pending or jobs.running:
            _LOG.debug("#pendingJobs: %s", len(jobs.pending))
            _LOG.debug("#runningJobs: %s", len(jobs.running))

            # See if any jobs have finished. Iterate over a snapshot because
            # setJobState() removes finished jobs from jobs.running, and
            # mutating the list during iteration would skip the next entry.
            for job in list(jobs.running):
                assert job.process is not None, "Process cannot be None"
                if not job.process.is_alive():
                    _LOG.debug("finished: %s", job)
                    # finished
                    exitcode = job.process.exitcode
                    quantum_report = job.report()
                    report.quantaReports.append(quantum_report)
                    if exitcode == 0:
                        jobs.setJobState(job, JobState.FINISHED)
                        job.cleanup()
                        _LOG.debug("success: %s took %.3f seconds", job, time.time() - job.started)
                    else:
                        if job.terminated:
                            # Was killed due to timeout.
                            if report.status == ExecutionStatus.SUCCESS:
                                # Do not override global FAILURE status
                                report.status = ExecutionStatus.TIMEOUT
                            message = f"Timeout ({self.timeout} sec) for task {job}, task is killed"
                            jobs.setJobState(job, JobState.TIMED_OUT)
                        else:
                            report.status = ExecutionStatus.FAILURE
                            # failMessage() has to be called before cleanup()
                            message = job.failMessage()
                            jobs.setJobState(job, JobState.FAILED)

                        job.cleanup()
                        _LOG.debug("failed: %s", job)
                        if self.failFast or exitcode == InvalidQuantumError.EXIT_CODE:
                            # stop all running jobs
                            for stopJob in jobs.running:
                                if stopJob is not job:
                                    stopJob.stop()
                            if job.state is JobState.TIMED_OUT:
                                raise MPTimeoutError(f"Timeout ({self.timeout} sec) for task {job}.")
                            else:
                                raise MPGraphExecutorError(message)
                        else:
                            _LOG.error("%s; processing will continue for remaining tasks.", message)
                else:
                    # check for timeout
                    now = time.time()
                    if now - job.started > self.timeout:
                        # Try to kill it, and there is a chance that it
                        # finishes successfully before it gets killed. Exit
                        # status is handled by the code above on next
                        # iteration.
                        _LOG.debug("Terminating job %s due to timeout", job)
                        job.stop()

            # Fail jobs whose inputs failed. Iterate over a snapshot because
            # setJobState(FAILED_DEP) removes the job from jobs.pending;
            # this may still need several iterations if the ordering is not
            # right, which is handled by the enclosing while loop.
            if jobs.failedNodes:
                for job in list(jobs.pending):
                    jobInputNodes = graph.determineInputsToQuantumNode(job.qnode)
                    assert job.qnode.quantum.dataId is not None, "Quantum DataId cannot be None"
                    if jobInputNodes & jobs.failedNodes:
                        quantum_report = QuantumReport(
                            status=ExecutionStatus.SKIPPED,
                            dataId=job.qnode.quantum.dataId,
                            taskLabel=job.qnode.taskDef.label,
                        )
                        report.quantaReports.append(quantum_report)
                        jobs.setJobState(job, JobState.FAILED_DEP)
                        _LOG.error("Upstream job failed for task %s, skipping this task.", job)

            # See if we can start more jobs. Iterate over a snapshot because
            # submit() moves the job from jobs.pending to jobs.running, and
            # mutating the list during iteration would skip candidates.
            if len(jobs.running) < self.numProc:
                for job in list(jobs.pending):
                    jobInputNodes = graph.determineInputsToQuantumNode(job.qnode)
                    if jobInputNodes <= jobs.finishedNodes:
                        # all dependencies have completed, can start new job
                        if len(jobs.running) < self.numProc:
                            _LOG.debug("Submitting %s", job)
                            jobs.submit(job, butler, self.quantumExecutor, self.startMethod)
                        if len(jobs.running) >= self.numProc:
                            # Cannot start any more jobs, wait until something
                            # finishes.
                            break

            # Do cleanup for timed out jobs if necessary.
            jobs.cleanup()

            # Print progress message if something changed.
            newFinished, newFailed = len(jobs.finishedNodes), len(jobs.failedNodes)
            if (finishedCount, failedCount) != (newFinished, newFailed):
                finishedCount, failedCount = newFinished, newFailed
                totalCount = len(jobs.jobs)
                _LOG.info(
                    "Executed %d quanta successfully, %d failed and %d remain out of total %d quanta.",
                    finishedCount,
                    failedCount,
                    totalCount - finishedCount - failedCount,
                    totalCount,
                )

            # Here we want to wait until one of the running jobs completes
            # but multiprocessing does not provide an API for that, for now
            # just sleep a little bit and go back to the loop.
            if jobs.running:
                time.sleep(0.1)

        if jobs.failedNodes:
            # print list of failed jobs
            _LOG.error("Failed jobs:")
            for job in jobs.jobs:
                if job.state != JobState.FINISHED:
                    _LOG.error(" - %s: %s", job.state.name, job)

            # if any job failed raise an exception
            if jobs.failedNodes == jobs.timedOutNodes:
                raise MPTimeoutError("One or more tasks timed out during execution.")
            else:
                raise MPGraphExecutorError("One or more tasks failed or timed out during execution.")

    def getReport(self) -> Optional[Report]:
        # Docstring inherited from base class
        if self.report is None:
            raise RuntimeError("getReport() called before execute()")
        return self.report