Coverage for python/lsst/ctrl/mpexec/mpGraphExecutor.py: 12% (307 statements)
coverage.py v6.5.0, created at 2023-01-10 10:58 +0000

# This file is part of ctrl_mpexec.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

__all__ = ["MPGraphExecutor", "MPGraphExecutorError", "MPTimeoutError"]

import gc
import importlib
import logging
import multiprocessing
import pickle
import signal
import sys
import time
from enum import Enum
from typing import TYPE_CHECKING, Iterable, Literal, Optional

from lsst.daf.butler.cli.cliLog import CliLog
from lsst.pipe.base import InvalidQuantumError, TaskDef
from lsst.pipe.base.graph.graph import QuantumGraph, QuantumNode
from lsst.utils.threads import disable_implicit_threading

from .executionGraphFixup import ExecutionGraphFixup
from .quantumGraphExecutor import QuantumExecutor, QuantumGraphExecutor
from .reports import ExecutionStatus, QuantumReport, Report

if TYPE_CHECKING:
    from lsst.daf.butler import Butler

_LOG = logging.getLogger(__name__)


# Possible states for the executing task:
# - PENDING: job has not started yet
# - RUNNING: job is currently executing
# - FINISHED: job finished successfully
# - FAILED: job execution failed (process returned non-zero status)
# - TIMED_OUT: job was killed because it ran longer than the timeout
# - FAILED_DEP: one of the dependencies of this job has failed/timed out
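# In practice the transitions are PENDING -> RUNNING (via _Job.start()) and
# RUNNING -> FINISHED/FAILED/TIMED_OUT (via _JobList.setJobState()); a
# PENDING job can also go straight to FAILED_DEP when an upstream job fails.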
JobState = Enum("JobState", "PENDING RUNNING FINISHED FAILED TIMED_OUT FAILED_DEP")


class _Job:
    """Class representing a job running a single task.

    Parameters
    ----------
    qnode : `~lsst.pipe.base.QuantumNode`
        Quantum and some associated information.
    """

    def __init__(self, qnode: QuantumNode):
        self.qnode = qnode
        self.process: Optional[multiprocessing.process.BaseProcess] = None
        self._state = JobState.PENDING
        self.started: float = 0.0
        self._rcv_conn: Optional[multiprocessing.connection.Connection] = None
        self._terminated = False

    @property
    def state(self) -> JobState:
        """Job processing state (`JobState`)."""
        return self._state

    @property
    def terminated(self) -> bool:
        """Return `True` if the job was killed by the `stop()` method and the
        child process returned a negative exit code (`bool`)."""
        if self._terminated:
            assert self.process is not None, "Process must be started"
            if self.process.exitcode is not None:
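                # Illustration (POSIX semantics): after stop() sends SIGTERM
                # the exit code is -15 (i.e. -signal.SIGTERM); escalation to
                # SIGKILL yields -9.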
                return self.process.exitcode < 0
        return False

    def start(
        self,
        butler: Butler,
        quantumExecutor: QuantumExecutor,
        startMethod: Literal["spawn"] | Literal["fork"] | Literal["forkserver"] | None = None,
    ) -> None:
        """Start a process which runs the task.

        Parameters
        ----------
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        quantumExecutor : `QuantumExecutor`
            Executor for a single quantum.
        startMethod : `str`, optional
            Start method from the `multiprocessing` module.
        """
        # Unpickling of the quantum has to happen after the butler is
        # unpickled, which is why the quantum is pickled manually here.
        quantum_pickle = pickle.dumps(self.qnode.quantum)
        taskDef = self.qnode.taskDef
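        # Pipe(duplex=False) gives a one-way channel: the first connection is
        # receive-only (kept by the parent), the second is send-only (handed
        # to the child).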
        self._rcv_conn, snd_conn = multiprocessing.Pipe(False)
        logConfigState = CliLog.configState
        mp_ctx = multiprocessing.get_context(startMethod)
        self.process = mp_ctx.Process(
            target=_Job._executeJob,
            args=(quantumExecutor, taskDef, quantum_pickle, butler, logConfigState, snd_conn),
            name=f"task-{self.qnode.quantum.dataId}",
        )
        self.process.start()
        self.started = time.time()
        self._state = JobState.RUNNING

    @staticmethod
    def _executeJob(
        quantumExecutor: QuantumExecutor,
        taskDef: TaskDef,
        quantum_pickle: bytes,
        butler: Butler,
        logConfigState: list,
        snd_conn: multiprocessing.connection.Connection,
    ) -> None:
        """Execute a job with arguments.

        Parameters
        ----------
        quantumExecutor : `QuantumExecutor`
            Executor for a single quantum.
        taskDef : `~lsst.pipe.base.TaskDef`
            Task definition structure.
        quantum_pickle : `bytes`
            Quantum for this task execution in pickled form.
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        logConfigState : `list`
            Logging configuration state from the parent process.
        snd_conn : `multiprocessing.connection.Connection`
            Connection to send the job report to the parent process.
        """
        if logConfigState and not CliLog.configState:
            # This means we are in a newly spawned Python process and have to
            # re-initialize logging.
            CliLog.replayConfigState(logConfigState)

        # Have to reset the connection pool to avoid sharing connections with
        # the parent process.
        if butler is not None:
            butler.registry.resetConnectionPool()

        quantum = pickle.loads(quantum_pickle)
        try:
            if butler is not None:
                # None is apparently possible in some tests, despite the
                # annotation.
                butler.registry.refresh()
            quantumExecutor.execute(taskDef, quantum, butler)
        finally:
            # If sending fails we do not want this new exception to be
            # exposed.
            try:
                report = quantumExecutor.getReport()
                snd_conn.send(report)
            except Exception:
                pass

    def stop(self) -> None:
        """Stop the process."""
        assert self.process is not None, "Process must be started"
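        # On POSIX terminate() sends SIGTERM and kill() sends SIGKILL; the
        # latter cannot be caught or ignored by the child process.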
        self.process.terminate()
        # Give it one second to finish; otherwise kill it.
        for i in range(10):
            time.sleep(0.1)
            if not self.process.is_alive():
                break
        else:
            _LOG.debug("Killing process %s", self.process.name)
            self.process.kill()
        self._terminated = True

    def cleanup(self) -> None:
        """Release process resources; has to be called for each finished
        process.
        """
        if self.process and not self.process.is_alive():
            self.process.close()
            self.process = None
            self._rcv_conn = None

    def report(self) -> QuantumReport:
        """Return the task report; should be called after the process
        finishes and before `cleanup()`.
        """
        assert self.process is not None, "Process must be started"
        assert self._rcv_conn is not None, "Process must be started"
        try:
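            # recv() raises EOFError if the child exited before sending its
            # report through the pipe.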
            report = self._rcv_conn.recv()
            report.exitCode = self.process.exitcode
        except Exception:
            # Likely due to the process being killed, but there may be other
            # reasons. Exit code should not be None; this is to keep mypy
            # happy.
            exitcode = self.process.exitcode if self.process.exitcode is not None else -1
            assert self.qnode.quantum.dataId is not None, "Quantum DataId cannot be None"
            report = QuantumReport.from_exit_code(
                exitCode=exitcode,
                dataId=self.qnode.quantum.dataId,
                taskLabel=self.qnode.taskDef.label,
            )
        if self.terminated:
            # Means it was killed; assume it is due to a timeout.
            report.status = ExecutionStatus.TIMEOUT
        return report

    def failMessage(self) -> str:
        """Return a message describing the task failure."""
        assert self.process is not None, "Process must be started"
        assert self.process.exitcode is not None, "Process has to finish"
        exitcode = self.process.exitcode
        if exitcode < 0:
            # A negative exit code means the process was killed by a signal.
            signum = -exitcode
            msg = f"Task {self} failed, killed by signal {signum}"
            # Just in case this is some very odd signal, expect ValueError.
            try:
                strsignal = signal.strsignal(signum)
                msg = f"{msg} ({strsignal})"
            except ValueError:
                pass
        elif exitcode > 0:
            msg = f"Task {self} failed, exit code={exitcode}"
        else:
            msg = ""
        return msg

    def __str__(self) -> str:
        return f"<{self.qnode.taskDef} dataId={self.qnode.quantum.dataId}>"


class _JobList:
    """Simple list of `_Job` instances with a few convenience methods.

    Parameters
    ----------
    iterable : iterable of `~lsst.pipe.base.QuantumNode`
        Sequence of quanta to execute. This has to be ordered according to
        task dependencies.
    """

    def __init__(self, iterable: Iterable[QuantumNode]):
        self.jobs = [_Job(qnode) for qnode in iterable]
        self.pending = self.jobs[:]
        self.running: list[_Job] = []
        self.finishedNodes: set[QuantumNode] = set()
        self.failedNodes: set[QuantumNode] = set()
        self.timedOutNodes: set[QuantumNode] = set()

    def submit(
        self,
        job: _Job,
        butler: Butler,
        quantumExecutor: QuantumExecutor,
        startMethod: Literal["spawn"] | Literal["fork"] | Literal["forkserver"] | None = None,
    ) -> None:
        """Submit one more job for execution.

        Parameters
        ----------
        job : `_Job`
            Job to submit.
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        quantumExecutor : `QuantumExecutor`
            Executor for a single quantum.
        startMethod : `str`, optional
            Start method from the `multiprocessing` module.
        """
        # This will raise if the job is not in the pending list.
        self.pending.remove(job)
        job.start(butler, quantumExecutor, startMethod)
        self.running.append(job)

    def setJobState(self, job: _Job, state: JobState) -> None:
        """Update job state.

        Parameters
        ----------
        job : `_Job`
            Job to update.
        state : `JobState`
            New job state; note that only FINISHED, FAILED, TIMED_OUT, or
            FAILED_DEP states are acceptable.
        """
        allowedStates = (JobState.FINISHED, JobState.FAILED, JobState.TIMED_OUT, JobState.FAILED_DEP)
        assert state in allowedStates, f"State {state} not allowed here"

        # Remove the job from the pending/running lists.
        if job.state == JobState.PENDING:
            self.pending.remove(job)
        elif job.state == JobState.RUNNING:
            self.running.remove(job)

        qnode = job.qnode
        # It should not be in any of these, but just in case.
        self.finishedNodes.discard(qnode)
        self.failedNodes.discard(qnode)
        self.timedOutNodes.discard(qnode)

        job._state = state
        if state == JobState.FINISHED:
            self.finishedNodes.add(qnode)
        elif state == JobState.FAILED:
            self.failedNodes.add(qnode)
        elif state == JobState.FAILED_DEP:
            self.failedNodes.add(qnode)
        elif state == JobState.TIMED_OUT:
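            # Timed-out jobs are also recorded as failed so that jobs
            # depending on them are skipped.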
            self.failedNodes.add(qnode)
            self.timedOutNodes.add(qnode)
        else:
            raise ValueError(f"Unexpected state value: {state}")

    def cleanup(self) -> None:
        """Do periodic cleanup for jobs that did not finish correctly.

        If timed-out jobs are killed but take too long to stop, the regular
        cleanup will not work for them. Here we check all timed-out jobs
        periodically and clean up those that have managed to die by this
        time.
        """
        for job in self.jobs:
            if job.state == JobState.TIMED_OUT and job.process is not None:
                job.cleanup()


class MPGraphExecutorError(Exception):
    """Exception class for errors raised by MPGraphExecutor."""

    pass


class MPTimeoutError(MPGraphExecutorError):
    """Exception raised when task execution times out."""

    pass


class MPGraphExecutor(QuantumGraphExecutor):
    """Implementation of QuantumGraphExecutor using same-host multiprocess
    execution of quanta.

    Parameters
    ----------
    numProc : `int`
        Number of processes to use for executing tasks.
    timeout : `float`
        Time in seconds to wait for tasks to finish.
    quantumExecutor : `QuantumExecutor`
        Executor for a single quantum. For multiprocess-style execution when
        ``numProc`` is greater than one this instance must support pickle.
    startMethod : `str`, optional
        Start method from the `multiprocessing` module; `None` selects the
        best one for the current platform.
    failFast : `bool`, optional
        If set to ``True`` then stop processing on the first error from any
        task.
    pdb : `str`, optional
        Debugger to import and use (via the ``post_mortem`` function) in the
        event of an exception.
    executionGraphFixup : `ExecutionGraphFixup`, optional
        Instance used for modification of the execution graph.
    """

    def __init__(
        self,
        numProc: int,
        timeout: float,
        quantumExecutor: QuantumExecutor,
        *,
        startMethod: Literal["spawn"] | Literal["fork"] | Literal["forkserver"] | None = None,
        failFast: bool = False,
        pdb: Optional[str] = None,
        executionGraphFixup: Optional[ExecutionGraphFixup] = None,
    ):
        self.numProc = numProc
        self.timeout = timeout
        self.quantumExecutor = quantumExecutor
        self.failFast = failFast
        self.pdb = pdb
        self.executionGraphFixup = executionGraphFixup
        self.report: Optional[Report] = None

        # We default the start method to spawn on macOS and fork on Linux;
        # None on all other platforms, which uses the multiprocessing
        # default.
        if startMethod is None:
            methods = dict(linux="fork", darwin="spawn")
            startMethod = methods.get(sys.platform)  # type: ignore
        self.startMethod = startMethod
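
    # A minimal usage sketch (`executor`, `qgraph`, and `butler` are
    # hypothetical names; any QuantumExecutor implementation works):
    #
    #     mp_exec = MPGraphExecutor(numProc=4, timeout=3600.0,
    #                               quantumExecutor=executor)
    #     mp_exec.execute(qgraph, butler)
    #     report = mp_exec.getReport()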

    def execute(self, graph: QuantumGraph, butler: Butler) -> None:
        # Docstring inherited from QuantumGraphExecutor.execute
        graph = self._fixupQuanta(graph)
        self.report = Report()
        try:
            if self.numProc > 1:
                self._executeQuantaMP(graph, butler, self.report)
            else:
                self._executeQuantaInProcess(graph, butler, self.report)
        except Exception as exc:
            self.report.set_exception(exc)
            raise

    def _fixupQuanta(self, graph: QuantumGraph) -> QuantumGraph:
        """Call fixup code to modify the execution graph.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` to modify.

        Returns
        -------
        graph : `QuantumGraph`
            Modified `QuantumGraph`.

        Raises
        ------
        MPGraphExecutorError
            Raised if the execution graph cannot be ordered after
            modification, i.e. it has dependency cycles.
        """
        if not self.executionGraphFixup:
            return graph

        _LOG.debug("Call execution graph fixup method")
        graph = self.executionGraphFixup.fixupQuanta(graph)

        # Detect if there is now a cycle created within the graph.
        if graph.findCycle():
            raise MPGraphExecutorError("Updated execution graph has dependency cycle.")

        return graph

    def _executeQuantaInProcess(self, graph: QuantumGraph, butler: Butler, report: Report) -> None:
        """Execute all quanta in the current process.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` that is to be executed.
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        report : `Report`
            Object for reporting execution status.
        """
        successCount, totalCount = 0, len(graph)
        failedNodes: set[QuantumNode] = set()
        for qnode in graph:

            assert qnode.quantum.dataId is not None, "Quantum DataId cannot be None"

            # Any failed inputs mean that the quantum has to be skipped.
            inputNodes = graph.determineInputsToQuantumNode(qnode)
            if inputNodes & failedNodes:
                _LOG.error(
                    "Upstream job failed for task <%s dataId=%s>, skipping this task.",
                    qnode.taskDef,
                    qnode.quantum.dataId,
                )
                failedNodes.add(qnode)
                failed_quantum_report = QuantumReport(
                    status=ExecutionStatus.SKIPPED, dataId=qnode.quantum.dataId, taskLabel=qnode.taskDef.label
                )
                report.quantaReports.append(failed_quantum_report)
                continue

            _LOG.debug("Executing %s", qnode)
            try:
                self.quantumExecutor.execute(qnode.taskDef, qnode.quantum, butler)
                successCount += 1
            except Exception as exc:
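                # Only drop into the debugger when both stdin and stdout are
                # attached to an interactive terminal.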
                if self.pdb and sys.stdin.isatty() and sys.stdout.isatty():
                    _LOG.error(
                        "Task <%s dataId=%s> failed; dropping into pdb.",
                        qnode.taskDef,
                        qnode.quantum.dataId,
                        exc_info=exc,
                    )
                    try:
                        pdb = importlib.import_module(self.pdb)
                    except ImportError as imp_exc:
                        raise MPGraphExecutorError(
                            f"Unable to import specified debugger module ({self.pdb}): {imp_exc}"
                        ) from exc
                    if not hasattr(pdb, "post_mortem"):
                        raise MPGraphExecutorError(
                            f"Specified debugger module ({self.pdb}) can't debug with post_mortem",
                        ) from exc
                    pdb.post_mortem(exc.__traceback__)
                failedNodes.add(qnode)
                report.status = ExecutionStatus.FAILURE
                if self.failFast:
                    raise MPGraphExecutorError(
                        f"Task <{qnode.taskDef} dataId={qnode.quantum.dataId}> failed."
                    ) from exc
                else:
                    # Note that there could be exception-safety issues, which
                    # we presently ignore.
                    _LOG.error(
                        "Task <%s dataId=%s> failed; processing will continue for remaining tasks.",
                        qnode.taskDef,
                        qnode.quantum.dataId,
                        exc_info=exc,
                    )
            finally:
                # SQLAlchemy has some objects that can last until a garbage
                # collection cycle is run, which can happen at unpredictable
                # times; run a collection explicitly here.
                gc.collect()

            quantum_report = self.quantumExecutor.getReport()
            if quantum_report:
                report.quantaReports.append(quantum_report)

        _LOG.info(
            "Executed %d quanta successfully, %d failed and %d remain out of total %d quanta.",
            successCount,
            len(failedNodes),
            totalCount - successCount - len(failedNodes),
            totalCount,
        )

        # Raise an exception if there were any failures.
        if failedNodes:
            raise MPGraphExecutorError("One or more tasks failed during execution.")

    def _executeQuantaMP(self, graph: QuantumGraph, butler: Butler, report: Report) -> None:
        """Execute all quanta in separate processes.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` that is to be executed.
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        report : `Report`
            Object for reporting execution status.
        """

        disable_implicit_threading()  # To prevent thread contention

        _LOG.debug("Using %r for multiprocessing start method", self.startMethod)

        # Re-pack input quantum data into the jobs list.
        jobs = _JobList(graph)

        # Check that all tasks can run in a sub-process.
        for job in jobs.jobs:
            taskDef = job.qnode.taskDef
            if not taskDef.taskClass.canMultiprocess:
                raise MPGraphExecutorError(
                    f"Task {taskDef.taskName} does not support multiprocessing; use single process"
                )

        finishedCount, failedCount = 0, 0
        while jobs.pending or jobs.running:

            _LOG.debug("#pendingJobs: %s", len(jobs.pending))
            _LOG.debug("#runningJobs: %s", len(jobs.running))

            # See if any jobs have finished.
            for job in jobs.running:
                assert job.process is not None, "Process cannot be None"
                if not job.process.is_alive():
                    _LOG.debug("finished: %s", job)
                    # finished
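                    # Read the exit code and report now; cleanup() calls
                    # process.close(), after which they are no longer
                    # available.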
                    exitcode = job.process.exitcode
                    quantum_report = job.report()
                    report.quantaReports.append(quantum_report)
                    if exitcode == 0:
                        jobs.setJobState(job, JobState.FINISHED)
                        job.cleanup()
                        _LOG.debug("success: %s took %.3f seconds", job, time.time() - job.started)
                    else:
                        if job.terminated:
                            # Was killed due to timeout.
                            if report.status == ExecutionStatus.SUCCESS:
                                # Do not override global FAILURE status.
                                report.status = ExecutionStatus.TIMEOUT
                            message = f"Timeout ({self.timeout} sec) for task {job}, task is killed"
                            jobs.setJobState(job, JobState.TIMED_OUT)
                        else:
                            report.status = ExecutionStatus.FAILURE
                            # failMessage() has to be called before cleanup().
                            message = job.failMessage()
                            jobs.setJobState(job, JobState.FAILED)

                        job.cleanup()
                        _LOG.debug("failed: %s", job)
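                        # InvalidQuantumError.EXIT_CODE marks a failure that
                        # retries will not fix, so stop everything even when
                        # failFast is not set.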
                        if self.failFast or exitcode == InvalidQuantumError.EXIT_CODE:
                            # Stop all running jobs.
                            for stopJob in jobs.running:
                                if stopJob is not job:
                                    stopJob.stop()
                            if job.state is JobState.TIMED_OUT:
                                raise MPTimeoutError(f"Timeout ({self.timeout} sec) for task {job}.")
                            else:
                                raise MPGraphExecutorError(message)
                        else:
                            _LOG.error("%s; processing will continue for remaining tasks.", message)
                else:
                    # Check for timeout.
                    now = time.time()
                    if now - job.started > self.timeout:
                        # Try to kill it; there is a chance that it finishes
                        # successfully before it gets killed. The exit status
                        # is handled by the code above on the next iteration.
                        _LOG.debug("Terminating job %s due to timeout", job)
                        job.stop()

            # Fail jobs whose inputs failed; this may need several iterations
            # if the ordering is not right, in which case it finishes on a
            # later pass of the loop.
            if jobs.failedNodes:
                for job in jobs.pending:
                    jobInputNodes = graph.determineInputsToQuantumNode(job.qnode)
                    assert job.qnode.quantum.dataId is not None, "Quantum DataId cannot be None"
                    if jobInputNodes & jobs.failedNodes:
                        quantum_report = QuantumReport(
                            status=ExecutionStatus.SKIPPED,
                            dataId=job.qnode.quantum.dataId,
                            taskLabel=job.qnode.taskDef.label,
                        )
                        report.quantaReports.append(quantum_report)
                        jobs.setJobState(job, JobState.FAILED_DEP)
                        _LOG.error("Upstream job failed for task %s, skipping this task.", job)

            # See if we can start more jobs.
            if len(jobs.running) < self.numProc:
                for job in jobs.pending:
                    jobInputNodes = graph.determineInputsToQuantumNode(job.qnode)
                    if jobInputNodes <= jobs.finishedNodes:
                        # All dependencies have completed; can start a new
                        # job.
                        if len(jobs.running) < self.numProc:
                            _LOG.debug("Submitting %s", job)
                            jobs.submit(job, butler, self.quantumExecutor, self.startMethod)
                        if len(jobs.running) >= self.numProc:
                            # Cannot start any more jobs, wait until something
                            # finishes.
                            break

            # Do cleanup for timed-out jobs if necessary.
            jobs.cleanup()

            # Print a progress message if something changed.
            newFinished, newFailed = len(jobs.finishedNodes), len(jobs.failedNodes)
            if (finishedCount, failedCount) != (newFinished, newFailed):
                finishedCount, failedCount = newFinished, newFailed
                totalCount = len(jobs.jobs)
                _LOG.info(
                    "Executed %d quanta successfully, %d failed and %d remain out of total %d quanta.",
                    finishedCount,
                    failedCount,
                    totalCount - finishedCount - failedCount,
                    totalCount,
                )

            # Here we want to wait until one of the running jobs completes,
            # but multiprocessing does not provide an API for that; for now
            # just sleep a little bit and go back to the loop.
            if jobs.running:
                time.sleep(0.1)

        if jobs.failedNodes:
            # Print the list of failed jobs.
            _LOG.error("Failed jobs:")
            for job in jobs.jobs:
                if job.state != JobState.FINISHED:
                    _LOG.error(" - %s: %s", job.state.name, job)

            # If any job failed, raise an exception.
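            # timedOutNodes is always a subset of failedNodes (see
            # _JobList.setJobState), so equality below means every failure
            # was a timeout.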
            if jobs.failedNodes == jobs.timedOutNodes:
                raise MPTimeoutError("One or more tasks timed out during execution.")
            else:
                raise MPGraphExecutorError("One or more tasks failed or timed out during execution.")

    def getReport(self) -> Optional[Report]:
        # Docstring inherited from base class.
        if self.report is None:
            raise RuntimeError("getReport() called before execute()")
        return self.report