# This file is part of ctrl_mpexec.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations
__all__ = ["MPGraphExecutor", "MPGraphExecutorError", "MPTimeoutError"]

import gc
import importlib
import logging
import multiprocessing
import pickle
import signal
import sys
import time
from enum import Enum
from typing import TYPE_CHECKING, Iterable, Optional

from lsst.daf.butler.cli.cliLog import CliLog
from lsst.pipe.base import InvalidQuantumError, TaskDef
from lsst.pipe.base.graph.graph import QuantumGraph, QuantumNode
from lsst.utils.threads import disable_implicit_threading

from .executionGraphFixup import ExecutionGraphFixup
from .quantumGraphExecutor import QuantumExecutor, QuantumGraphExecutor
from .reports import ExecutionStatus, QuantumReport, Report

if TYPE_CHECKING:
    from lsst.daf.butler import Butler

_LOG = logging.getLogger(__name__)


# Possible states for the executing task:
# - PENDING: job has not started yet
# - RUNNING: job is currently executing
# - FINISHED: job finished successfully
# - FAILED: job execution failed (process returned non-zero status)
# - TIMED_OUT: job was killed because it exceeded the execution timeout
# - FAILED_DEP: one of the dependencies of this job has failed/timed out
JobState = Enum("JobState", "PENDING RUNNING FINISHED FAILED TIMED_OUT FAILED_DEP")
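
# Job state transitions are driven by the _JobList class below: a job moves
# from PENDING to RUNNING when its process is started, and the terminal
# states (FINISHED, FAILED, TIMED_OUT, FAILED_DEP) are assigned through
# _JobList.setJobState().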


class _Job:
    """Class representing a job running a single task.

    Parameters
    ----------
    qnode : `~lsst.pipe.base.QuantumNode`
        Quantum and some associated information.
    """

    def __init__(self, qnode: QuantumNode):
        self.qnode = qnode
        self.process: Optional[multiprocessing.process.BaseProcess] = None
        self._state = JobState.PENDING
        self.started: float = 0.0
        self._rcv_conn: Optional[multiprocessing.connection.Connection] = None
        self._terminated = False

    @property
    def state(self) -> JobState:
        """Job processing state (`JobState`)."""
        return self._state

    @property
    def terminated(self) -> bool:
        """Return `True` if the job was killed by the `stop()` method and the
        child process returned a negative exit code. (`bool`)"""
        if self._terminated:
            assert self.process is not None, "Process must be started"
            if self.process.exitcode is not None:
                return self.process.exitcode < 0
        return False

    def start(
        self, butler: Butler, quantumExecutor: QuantumExecutor, startMethod: Optional[str] = None
    ) -> None:
        """Start a process which runs the task.

        Parameters
        ----------
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        quantumExecutor : `QuantumExecutor`
            Executor for a single quantum.
        startMethod : `str`, optional
            Start method from the `multiprocessing` module.
        """
        # Unpickling of the quantum has to happen after the butler is
        # unpickled, which is why the quantum is pickled manually here.
        quantum_pickle = pickle.dumps(self.qnode.quantum)
        taskDef = self.qnode.taskDef
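        # Pipe(False) gives a one-way channel: the parent keeps the receiving
        # end and hands the sending end to the child so the child can report
        # results back.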
        self._rcv_conn, snd_conn = multiprocessing.Pipe(False)
        logConfigState = CliLog.configState
        mp_ctx = multiprocessing.get_context(startMethod)
        self.process = mp_ctx.Process(
            target=_Job._executeJob,
            args=(quantumExecutor, taskDef, quantum_pickle, butler, logConfigState, snd_conn),
            name=f"task-{self.qnode.quantum.dataId}",
        )
        self.process.start()
        self.started = time.time()
        self._state = JobState.RUNNING

    @staticmethod
    def _executeJob(
        quantumExecutor: QuantumExecutor,
        taskDef: TaskDef,
        quantum_pickle: bytes,
        butler: Butler,
        logConfigState: list,
        snd_conn: multiprocessing.connection.Connection,
    ) -> None:
        """Execute a job with arguments.

        Parameters
        ----------
        quantumExecutor : `QuantumExecutor`
            Executor for a single quantum.
        taskDef : `~lsst.pipe.base.TaskDef`
            Task definition structure.
        quantum_pickle : `bytes`
            Quantum for this task execution in pickled form.
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        logConfigState : `list`
            Logging configuration state from the parent process, replayed in
            the child process if needed.
        snd_conn : `multiprocessing.connection.Connection`
            Connection to send job report to parent process.
        """
        if logConfigState and not CliLog.configState:
            # This means we are in a newly spawned Python process and have to
            # re-initialize logging.
            CliLog.replayConfigState(logConfigState)

        # Have to reset the connection pool to avoid sharing connections.
        if butler is not None:
            butler.registry.resetConnectionPool()

        quantum = pickle.loads(quantum_pickle)
        try:
            quantumExecutor.execute(taskDef, quantum, butler)
        finally:
            # If sending fails we do not want this new exception to be
            # exposed.
            try:
                report = quantumExecutor.getReport()
                snd_conn.send(report)
            except Exception:
                pass

    def stop(self) -> None:
        """Stop the process."""
        assert self.process is not None, "Process must be started"
        self.process.terminate()
        # Give it one second to finish; if it is still alive, kill it.
        for i in range(10):
            time.sleep(0.1)
            if not self.process.is_alive():
                break
        else:
            _LOG.debug("Killing process %s", self.process.name)
            self.process.kill()
        self._terminated = True

    def cleanup(self) -> None:
        """Release process resources; has to be called for each finished
        process.
        """
        if self.process and not self.process.is_alive():
            self.process.close()
            self.process = None
            self._rcv_conn = None

    def report(self) -> QuantumReport:
        """Return the task report; should be called after the process
        finishes and before cleanup().
        """
        assert self.process is not None, "Process must be started"
        assert self._rcv_conn is not None, "Process must be started"
        try:
            report = self._rcv_conn.recv()
            report.exitCode = self.process.exitcode
        except Exception:
            # Likely due to the process being killed, but there may be other
            # reasons. Exit code should not be None; this is to keep mypy
            # happy.
            exitcode = self.process.exitcode if self.process.exitcode is not None else -1
            assert self.qnode.quantum.dataId is not None, "Quantum DataId cannot be None"
            report = QuantumReport.from_exit_code(
                exitCode=exitcode,
                dataId=self.qnode.quantum.dataId,
                taskLabel=self.qnode.taskDef.label,
            )
        if self.terminated:
            # Means it was killed; assume it's due to timeout.
            report.status = ExecutionStatus.TIMEOUT
        return report

    def failMessage(self) -> str:
        """Return a message describing the task failure."""
        assert self.process is not None, "Process must be started"
        assert self.process.exitcode is not None, "Process has to finish"
        exitcode = self.process.exitcode
        if exitcode < 0:
            # Negative exit code means the process was killed by a signal.
            signum = -exitcode
            msg = f"Task {self} failed, killed by signal {signum}"
            # Just in case this is some very odd signal, expect ValueError.
            try:
                strsignal = signal.strsignal(signum)
                msg = f"{msg} ({strsignal})"
            except ValueError:
                pass
        elif exitcode > 0:
            msg = f"Task {self} failed, exit code={exitcode}"
        else:
            msg = ""
        return msg

    def __str__(self) -> str:
        return f"<{self.qnode.taskDef} dataId={self.qnode.quantum.dataId}>"


class _JobList:
    """Simple list of _Job instances with a few convenience methods.

    Parameters
    ----------
    iterable : iterable of `~lsst.pipe.base.QuantumNode`
        Sequence of Quanta to execute. This has to be ordered according to
        task dependencies.
    """

    def __init__(self, iterable: Iterable[QuantumNode]):
        self.jobs = [_Job(qnode) for qnode in iterable]
        self.pending = self.jobs[:]
        self.running: list[_Job] = []
        self.finishedNodes: set[QuantumNode] = set()
        self.failedNodes: set[QuantumNode] = set()
        self.timedOutNodes: set[QuantumNode] = set()

    def submit(
        self, job: _Job, butler: Butler, quantumExecutor: QuantumExecutor, startMethod: Optional[str] = None
    ) -> None:
        """Submit one more job for execution.

        Parameters
        ----------
        job : `_Job`
            Job to submit.
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        quantumExecutor : `QuantumExecutor`
            Executor for a single quantum.
        startMethod : `str`, optional
            Start method from the `multiprocessing` module.
        """
        # This will raise if the job is not in the pending list.
        self.pending.remove(job)
        job.start(butler, quantumExecutor, startMethod)
        self.running.append(job)

    def setJobState(self, job: _Job, state: JobState) -> None:
        """Update job state.

        Parameters
        ----------
        job : `_Job`
            Job whose state is to be updated.
        state : `JobState`
            New job state; note that only FINISHED, FAILED, TIMED_OUT, or
            FAILED_DEP state is acceptable.
        """
        allowedStates = (JobState.FINISHED, JobState.FAILED, JobState.TIMED_OUT, JobState.FAILED_DEP)
        assert state in allowedStates, f"State {state} not allowed here"

        # Remove the job from the pending/running lists.
        if job.state == JobState.PENDING:
            self.pending.remove(job)
        elif job.state == JobState.RUNNING:
            self.running.remove(job)

        qnode = job.qnode
        # It should not be in any of these, but just in case.
        self.finishedNodes.discard(qnode)
        self.failedNodes.discard(qnode)
        self.timedOutNodes.discard(qnode)

        job._state = state
        if state == JobState.FINISHED:
            self.finishedNodes.add(qnode)
        elif state == JobState.FAILED:
            self.failedNodes.add(qnode)
        elif state == JobState.FAILED_DEP:
            self.failedNodes.add(qnode)
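        # Timed-out jobs are recorded as failed as well; tracking them in a
        # separate set lets the executor tell pure timeouts apart from other
        # failures at the end of the run.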
        elif state == JobState.TIMED_OUT:
            self.failedNodes.add(qnode)
            self.timedOutNodes.add(qnode)
        else:
            raise ValueError(f"Unexpected state value: {state}")

    def cleanup(self) -> None:
        """Do periodic cleanup for jobs that did not finish correctly.

        If timed-out jobs are killed but take too long to stop, regular
        cleanup will not work for them. Here we check all timed-out jobs
        periodically and do cleanup if they have managed to die by this time.
        """
        for job in self.jobs:
            if job.state == JobState.TIMED_OUT and job.process is not None:
                job.cleanup()


class MPGraphExecutorError(Exception):
    """Exception class for errors raised by MPGraphExecutor."""

    pass


class MPTimeoutError(MPGraphExecutorError):
    """Exception raised when task execution times out."""

    pass


class MPGraphExecutor(QuantumGraphExecutor):
    """Implementation of QuantumGraphExecutor using same-host multiprocess
    execution of Quanta.

    Parameters
    ----------
    numProc : `int`
        Number of processes to use for executing tasks.
    timeout : `float`
        Time in seconds to wait for tasks to finish.
    quantumExecutor : `QuantumExecutor`
        Executor for a single quantum. For multiprocess-style execution when
        ``numProc`` is greater than one this instance must support pickle.
    startMethod : `str`, optional
        Start method from the `multiprocessing` module; `None` selects the
        best one for the current platform.
    failFast : `bool`, optional
        If set to ``True`` then stop processing on the first error from any
        task.
    pdb : `str`, optional
        Debugger to import and use (via the ``post_mortem`` function) in the
        event of an exception.
    executionGraphFixup : `ExecutionGraphFixup`, optional
        Instance used for modification of the execution graph.
    """

    def __init__(
        self,
        numProc: int,
        timeout: float,
        quantumExecutor: QuantumExecutor,
        *,
        startMethod: Optional[str] = None,
        failFast: bool = False,
        pdb: Optional[str] = None,
        executionGraphFixup: Optional[ExecutionGraphFixup] = None,
    ):
        self.numProc = numProc
        self.timeout = timeout
        self.quantumExecutor = quantumExecutor
        self.failFast = failFast
        self.pdb = pdb
        self.executionGraphFixup = executionGraphFixup
        self.report: Optional[Report] = None

        # Default the start method to "spawn" on macOS and "fork" on Linux;
        # on all other platforms leave it as None to use the multiprocessing
        # default.
        if startMethod is None:
            methods = dict(linux="fork", darwin="spawn")
            startMethod = methods.get(sys.platform)
        self.startMethod = startMethod
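
    # Illustrative usage sketch (hypothetical setup; the quantum executor,
    # graph, and butler are constructed elsewhere, e.g. by the pipetask
    # driver):
    #
    #     executor = MPGraphExecutor(numProc=4, timeout=3600.0,
    #                                quantumExecutor=quantum_executor)
    #     executor.execute(quantum_graph, butler)
    #     report = executor.getReport()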

    def execute(self, graph: QuantumGraph, butler: Butler) -> None:
        # Docstring inherited from QuantumGraphExecutor.execute
        graph = self._fixupQuanta(graph)
        self.report = Report()
        try:
            if self.numProc > 1:
                self._executeQuantaMP(graph, butler, self.report)
            else:
                self._executeQuantaInProcess(graph, butler, self.report)
        except Exception as exc:
            self.report.set_exception(exc)
            raise

    def _fixupQuanta(self, graph: QuantumGraph) -> QuantumGraph:
        """Call fixup code to modify the execution graph.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` to modify.

        Returns
        -------
        graph : `QuantumGraph`
            Modified `QuantumGraph`.

        Raises
        ------
        MPGraphExecutorError
            Raised if the execution graph cannot be ordered after
            modification, i.e. it has dependency cycles.
        """
        if not self.executionGraphFixup:
            return graph

        _LOG.debug("Call execution graph fixup method")
        graph = self.executionGraphFixup.fixupQuanta(graph)

        # Detect if there is now a cycle created within the graph.
        if graph.findCycle():
            raise MPGraphExecutorError("Updated execution graph has dependency cycle.")

        return graph

    def _executeQuantaInProcess(self, graph: QuantumGraph, butler: Butler, report: Report) -> None:
        """Execute all Quanta in the current process.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` that is to be executed.
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        report : `Report`
            Object for reporting execution status.
        """
        successCount, totalCount = 0, len(graph)
        failedNodes: set[QuantumNode] = set()
        for qnode in graph:

            # Any failed inputs mean that the quantum has to be skipped.
            inputNodes = graph.determineInputsToQuantumNode(qnode)
            if inputNodes & failedNodes:
                _LOG.error(
                    "Upstream job failed for task <%s dataId=%s>, skipping this task.",
                    qnode.taskDef,
                    qnode.quantum.dataId,
                )
                failedNodes.add(qnode)
                failed_quantum_report = QuantumReport(
                    status=ExecutionStatus.SKIPPED, dataId=qnode.quantum.dataId, taskLabel=qnode.taskDef.label
                )
                report.quantaReports.append(failed_quantum_report)
                continue

            _LOG.debug("Executing %s", qnode)
            try:
                self.quantumExecutor.execute(qnode.taskDef, qnode.quantum, butler)
                successCount += 1
            except Exception as exc:
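                # Drop into the debugger only when attached to an interactive
                # terminal; otherwise fall through to the normal failure
                # handling below.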
                if self.pdb and sys.stdin.isatty() and sys.stdout.isatty():
                    _LOG.error(
                        "Task <%s dataId=%s> failed; dropping into pdb.",
                        qnode.taskDef,
                        qnode.quantum.dataId,
                        exc_info=exc,
                    )
                    try:
                        pdb = importlib.import_module(self.pdb)
                    except ImportError as imp_exc:
                        raise MPGraphExecutorError(
                            f"Unable to import specified debugger module ({self.pdb}): {imp_exc}"
                        ) from exc
                    if not hasattr(pdb, "post_mortem"):
                        raise MPGraphExecutorError(
                            f"Specified debugger module ({self.pdb}) can't debug with post_mortem",
                        ) from exc
                    pdb.post_mortem(exc.__traceback__)
                failedNodes.add(qnode)
                report.status = ExecutionStatus.FAILURE
                if self.failFast:
                    raise MPGraphExecutorError(
                        f"Task <{qnode.taskDef} dataId={qnode.quantum.dataId}> failed."
                    ) from exc
                else:
                    # Note that there could be exception safety issues, which
                    # we presently ignore.
                    _LOG.error(
                        "Task <%s dataId=%s> failed; processing will continue for remaining tasks.",
                        qnode.taskDef,
                        qnode.quantum.dataId,
                        exc_info=exc,
                    )
            finally:
                # sqlalchemy keeps some objects alive until a garbage
                # collection cycle runs, which can happen at unpredictable
                # times, so run a collection explicitly here.
                gc.collect()

            quantum_report = self.quantumExecutor.getReport()
            if quantum_report:
                report.quantaReports.append(quantum_report)

        _LOG.info(
            "Executed %d quanta successfully, %d failed and %d remain out of total %d quanta.",
            successCount,
            len(failedNodes),
            totalCount - successCount - len(failedNodes),
            totalCount,
        )

        # Raise an exception if there were any failures.
        if failedNodes:
            raise MPGraphExecutorError("One or more tasks failed during execution.")

    def _executeQuantaMP(self, graph: QuantumGraph, butler: Butler, report: Report) -> None:
        """Execute all Quanta in separate processes.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` that is to be executed.
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        report : `Report`
            Object for reporting execution status.
        """

        disable_implicit_threading()  # To prevent thread contention

        _LOG.debug("Using %r for multiprocessing start method", self.startMethod)

        # Re-pack input quantum data into the jobs list.
        jobs = _JobList(graph)

        # Check that all tasks can run in a sub-process.
        for job in jobs.jobs:
            taskDef = job.qnode.taskDef
            if not taskDef.taskClass.canMultiprocess:
                raise MPGraphExecutorError(
                    f"Task {taskDef.taskName} does not support multiprocessing; use single process"
                )

        finishedCount, failedCount = 0, 0
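        # Main scheduling loop: poll running jobs for completion, fail
        # pending jobs whose inputs failed, start new jobs as slots free up,
        # and sleep briefly between iterations.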
        while jobs.pending or jobs.running:

            _LOG.debug("#pendingJobs: %s", len(jobs.pending))
            _LOG.debug("#runningJobs: %s", len(jobs.running))

            # See if any jobs have finished.
            for job in jobs.running:
                assert job.process is not None, "Process cannot be None"
                if not job.process.is_alive():
                    _LOG.debug("finished: %s", job)
                    # finished
                    exitcode = job.process.exitcode
                    quantum_report = job.report()
                    report.quantaReports.append(quantum_report)
                    if exitcode == 0:
                        jobs.setJobState(job, JobState.FINISHED)
                        job.cleanup()
                        _LOG.debug("success: %s took %.3f seconds", job, time.time() - job.started)
                    else:
                        if job.terminated:
                            # Was killed due to timeout.
                            if report.status == ExecutionStatus.SUCCESS:
                                # Do not override global FAILURE status.
                                report.status = ExecutionStatus.TIMEOUT
                            message = f"Timeout ({self.timeout} sec) for task {job}, task is killed"
                            jobs.setJobState(job, JobState.TIMED_OUT)
                        else:
                            report.status = ExecutionStatus.FAILURE
                            # failMessage() has to be called before cleanup().
                            message = job.failMessage()
                            jobs.setJobState(job, JobState.FAILED)

                        job.cleanup()
                        _LOG.debug("failed: %s", job)
                        if self.failFast or exitcode == InvalidQuantumError.EXIT_CODE:
                            # Stop all running jobs.
                            for stopJob in jobs.running:
                                if stopJob is not job:
                                    stopJob.stop()
                            if job.state is JobState.TIMED_OUT:
                                raise MPTimeoutError(f"Timeout ({self.timeout} sec) for task {job}.")
                            else:
                                raise MPGraphExecutorError(message)
                        else:
                            _LOG.error("%s; processing will continue for remaining tasks.", message)
                else:
                    # Check for timeout.
                    now = time.time()
                    if now - job.started > self.timeout:
                        # Try to kill it, and there is a chance that it
                        # finishes successfully before it gets killed. Exit
                        # status is handled by the code above on the next
                        # iteration.
                        _LOG.debug("Terminating job %s due to timeout", job)
                        job.stop()

            # Fail jobs whose inputs failed; this may need several iterations
            # if the ordering is not right, and will be completed on
            # subsequent passes through the loop.
            if jobs.failedNodes:
                for job in jobs.pending:
                    jobInputNodes = graph.determineInputsToQuantumNode(job.qnode)
                    if jobInputNodes & jobs.failedNodes:
                        quantum_report = QuantumReport(
                            status=ExecutionStatus.SKIPPED,
                            dataId=job.qnode.quantum.dataId,
                            taskLabel=job.qnode.taskDef.label,
                        )
                        report.quantaReports.append(quantum_report)
                        jobs.setJobState(job, JobState.FAILED_DEP)
                        _LOG.error("Upstream job failed for task %s, skipping this task.", job)

            # See if we can start more jobs.
            if len(jobs.running) < self.numProc:
                for job in jobs.pending:
                    jobInputNodes = graph.determineInputsToQuantumNode(job.qnode)
                    if jobInputNodes <= jobs.finishedNodes:
                        # All dependencies have completed, can start new job.
                        if len(jobs.running) < self.numProc:
                            _LOG.debug("Submitting %s", job)
                            jobs.submit(job, butler, self.quantumExecutor, self.startMethod)
                        if len(jobs.running) >= self.numProc:
                            # Cannot start any more jobs, wait until
                            # something finishes.
                            break

            # Do cleanup for timed-out jobs if necessary.
            jobs.cleanup()

            # Print a progress message if something changed.
            newFinished, newFailed = len(jobs.finishedNodes), len(jobs.failedNodes)
            if (finishedCount, failedCount) != (newFinished, newFailed):
                finishedCount, failedCount = newFinished, newFailed
                totalCount = len(jobs.jobs)
                _LOG.info(
                    "Executed %d quanta successfully, %d failed and %d remain out of total %d quanta.",
                    finishedCount,
                    failedCount,
                    totalCount - finishedCount - failedCount,
                    totalCount,
                )

            # Here we want to wait until one of the running jobs completes,
            # but multiprocessing does not provide an API for that; for now
            # just sleep a little bit and go back to the loop.
            if jobs.running:
                time.sleep(0.1)

        if jobs.failedNodes:
            # Print the list of failed jobs.
            _LOG.error("Failed jobs:")
            for job in jobs.jobs:
                if job.state != JobState.FINISHED:
                    _LOG.error(" - %s: %s", job.state.name, job)

            # If all failed jobs are timeouts, raise the timeout-specific
            # exception; otherwise raise the generic execution error.
            if jobs.failedNodes == jobs.timedOutNodes:
                raise MPTimeoutError("One or more tasks timed out during execution.")
            else:
                raise MPGraphExecutorError("One or more tasks failed or timed out during execution.")

    def getReport(self) -> Optional[Report]:
        # Docstring inherited from base class
        if self.report is None:
            raise RuntimeError("getReport() called before execute()")
        return self.report