Coverage for python/lsst/ctrl/mpexec/mpGraphExecutor.py: 14%
# This file is part of ctrl_mpexec.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

__all__ = ["MPGraphExecutor", "MPGraphExecutorError", "MPTimeoutError"]

import gc
import importlib
import logging
import multiprocessing
import pickle
import signal
import sys
import time
from enum import Enum
from typing import TYPE_CHECKING, Iterable, Optional

from lsst.daf.butler.cli.cliLog import CliLog
from lsst.pipe.base import InvalidQuantumError, TaskDef
from lsst.pipe.base.graph.graph import QuantumGraph, QuantumNode
from lsst.utils.threads import disable_implicit_threading

from .executionGraphFixup import ExecutionGraphFixup
from .quantumGraphExecutor import QuantumExecutor, QuantumGraphExecutor
from .reports import ExecutionStatus, QuantumReport, Report

if TYPE_CHECKING:
    from lsst.daf.butler import Butler

_LOG = logging.getLogger(__name__)


# Possible states for the executing task:
# - PENDING: job has not started yet
# - RUNNING: job is currently executing
# - FINISHED: job finished successfully
# - FAILED: job execution failed (process returned non-zero status)
# - TIMED_OUT: job was killed because it exceeded the execution timeout
# - FAILED_DEP: one of the dependencies of this job has failed/timed out
JobState = Enum("JobState", "PENDING RUNNING FINISHED FAILED TIMED_OUT FAILED_DEP")
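
# Typical transitions, as driven by _JobList.setJobState below:
#   PENDING -> RUNNING -> FINISHED | FAILED | TIMED_OUT
#   PENDING -> FAILED_DEP  (an upstream quantum failed or timed out)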


class _Job:
    """Class representing a job running a single task.

    Parameters
    ----------
    qnode : `~lsst.pipe.base.QuantumNode`
        Quantum and some associated information.
    """

    def __init__(self, qnode: QuantumNode):
        self.qnode = qnode
        self.process: Optional[multiprocessing.process.BaseProcess] = None
        self._state = JobState.PENDING
        self.started: float = 0.0
        self._rcv_conn: Optional[multiprocessing.connection.Connection] = None
        self._terminated = False

    @property
    def state(self) -> JobState:
        """Job processing state (`JobState`)."""
        return self._state

    @property
    def terminated(self) -> bool:
        """Return `True` if the job was killed by the `stop` method and the
        child process returned a negative exit code (`bool`).
        """
        if self._terminated:
            assert self.process is not None, "Process must be started"
            if self.process.exitcode is not None:
                return self.process.exitcode < 0
        return False
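
    # For reference: multiprocessing reports a negative exitcode when the
    # child is killed by a signal. stop() first sends SIGTERM via
    # Process.terminate() and falls back to SIGKILL via Process.kill(), so a
    # stopped job normally ends with exitcode == -signal.SIGTERM or
    # exitcode == -signal.SIGKILL.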

    def start(
        self, butler: Butler, quantumExecutor: QuantumExecutor, startMethod: Optional[str] = None
    ) -> None:
        """Start a process which runs the task.

        Parameters
        ----------
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        quantumExecutor : `QuantumExecutor`
            Executor for a single quantum.
        startMethod : `str`, optional
            Start method from the `multiprocessing` module.
        """
        # Unpickling of the quantum has to happen after the butler; this is
        # why it is pickled manually here.
        quantum_pickle = pickle.dumps(self.qnode.quantum)
        taskDef = self.qnode.taskDef
        self._rcv_conn, snd_conn = multiprocessing.Pipe(False)
        logConfigState = CliLog.configState
        mp_ctx = multiprocessing.get_context(startMethod)
        self.process = mp_ctx.Process(
            target=_Job._executeJob,
            args=(quantumExecutor, taskDef, quantum_pickle, butler, logConfigState, snd_conn),
            name=f"task-{self.qnode.quantum.dataId}",
        )
        self.process.start()
        self.started = time.time()
        self._state = JobState.RUNNING

    @staticmethod
    def _executeJob(
        quantumExecutor: QuantumExecutor,
        taskDef: TaskDef,
        quantum_pickle: bytes,
        butler: Butler,
        logConfigState: list,
        snd_conn: multiprocessing.connection.Connection,
    ) -> None:
        """Execute a job with arguments.

        Parameters
        ----------
        quantumExecutor : `QuantumExecutor`
            Executor for a single quantum.
        taskDef : `~lsst.pipe.base.TaskDef`
            Task definition structure.
        quantum_pickle : `bytes`
            Quantum for this task execution in pickled form.
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        logConfigState : `list`
            Logging configuration state from the parent process; replayed to
            re-initialize logging in a newly spawned process.
        snd_conn : `multiprocessing.connection.Connection`
            Connection to send the job report to the parent process.
        """
        if logConfigState and not CliLog.configState:
            # Means that we are in a newly spawned Python process and we have
            # to re-initialize logging.
            CliLog.replayConfigState(logConfigState)

        # Have to reset the connection pool to avoid sharing connections.
        if butler is not None:
            butler.registry.resetConnectionPool()

        quantum = pickle.loads(quantum_pickle)
        try:
            quantumExecutor.execute(taskDef, quantum, butler)
        finally:
            # If sending fails we do not want this new exception to be
            # exposed.
            try:
                report = quantumExecutor.getReport()
                snd_conn.send(report)
            except Exception:
                pass

    def stop(self) -> None:
        """Stop the process."""
        assert self.process is not None, "Process must be started"
        self.process.terminate()
        # give it 1 second to finish or KILL
        for i in range(10):
            time.sleep(0.1)
            if not self.process.is_alive():
                break
        else:
            _LOG.debug("Killing process %s", self.process.name)
            self.process.kill()
        self._terminated = True

    def cleanup(self) -> None:
        """Release process resources; has to be called for each finished
        process.
        """
        if self.process and not self.process.is_alive():
            self.process.close()
            self.process = None
            self._rcv_conn = None

    def report(self) -> QuantumReport:
        """Return the task report; should be called after the process
        finishes and before cleanup().
        """
        assert self.process is not None, "Process must be started"
        assert self._rcv_conn is not None, "Process must be started"
        try:
            report = self._rcv_conn.recv()
            report.exitCode = self.process.exitcode
        except Exception:
            # Likely due to the process being killed, but there may be other
            # reasons. Exit code should not be None; this is to keep mypy
            # happy.
            exitcode = self.process.exitcode if self.process.exitcode is not None else -1
            assert self.qnode.quantum.dataId is not None, "Quantum DataId cannot be None"
            report = QuantumReport.from_exit_code(
                exitCode=exitcode,
                dataId=self.qnode.quantum.dataId,
                taskLabel=self.qnode.taskDef.label,
            )
        if self.terminated:
            # Means it was killed; assume it's due to timeout.
            report.status = ExecutionStatus.TIMEOUT
        return report

    def failMessage(self) -> str:
        """Return a message describing the task failure."""
        assert self.process is not None, "Process must be started"
        assert self.process.exitcode is not None, "Process has to finish"
        exitcode = self.process.exitcode
        if exitcode < 0:
            # A negative exit code means it was killed by a signal.
            signum = -exitcode
            msg = f"Task {self} failed, killed by signal {signum}"
            # Just in case this is some very odd signal, expect ValueError.
            try:
                strsignal = signal.strsignal(signum)
                msg = f"{msg} ({strsignal})"
            except ValueError:
                pass
        elif exitcode > 0:
            msg = f"Task {self} failed, exit code={exitcode}"
        else:
            msg = ""
        return msg

    def __str__(self) -> str:
        return f"<{self.qnode.taskDef} dataId={self.qnode.quantum.dataId}>"
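
    # A minimal sketch of how a _Job is driven (this mirrors _executeQuantaMP
    # below; `butler`, `executor`, and `qnode` are hypothetical stand-ins):
    #
    #     job = _Job(qnode)
    #     job.start(butler, executor)      # spawn/fork the child process
    #     while job.process.is_alive():    # poll until the child exits
    #         time.sleep(0.1)
    #     quantum_report = job.report()    # must be called before cleanup()
    #     job.cleanup()                    # close the process and the pipe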


class _JobList:
    """Simple list of `_Job` instances with a few convenience methods.

    Parameters
    ----------
    iterable : iterable of `~lsst.pipe.base.QuantumNode`
        Sequence of Quanta to execute. This has to be ordered according to
        task dependencies.
    """

    def __init__(self, iterable: Iterable[QuantumNode]):
        self.jobs = [_Job(qnode) for qnode in iterable]
        self.pending = self.jobs[:]
        self.running: list[_Job] = []
        self.finishedNodes: set[QuantumNode] = set()
        self.failedNodes: set[QuantumNode] = set()
        self.timedOutNodes: set[QuantumNode] = set()

    def submit(
        self, job: _Job, butler: Butler, quantumExecutor: QuantumExecutor, startMethod: Optional[str] = None
    ) -> None:
        """Submit one more job for execution.

        Parameters
        ----------
        job : `_Job`
            Job to submit.
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        quantumExecutor : `QuantumExecutor`
            Executor for a single quantum.
        startMethod : `str`, optional
            Start method from the `multiprocessing` module.
        """
        # this will raise if job is not in pending list
        self.pending.remove(job)
        job.start(butler, quantumExecutor, startMethod)
        self.running.append(job)

    def setJobState(self, job: _Job, state: JobState) -> None:
        """Update job state.

        Parameters
        ----------
        job : `_Job`
            Job whose state is to be updated.
        state : `JobState`
            New job state; note that only FINISHED, FAILED, TIMED_OUT, or
            FAILED_DEP state is acceptable.
        """
        allowedStates = (JobState.FINISHED, JobState.FAILED, JobState.TIMED_OUT, JobState.FAILED_DEP)
        assert state in allowedStates, f"State {state} not allowed here"

        # remove job from pending/running lists
        if job.state == JobState.PENDING:
            self.pending.remove(job)
        elif job.state == JobState.RUNNING:
            self.running.remove(job)

        qnode = job.qnode
        # it should not be in any of these, but just in case
        self.finishedNodes.discard(qnode)
        self.failedNodes.discard(qnode)
        self.timedOutNodes.discard(qnode)

        job._state = state
        if state == JobState.FINISHED:
            self.finishedNodes.add(qnode)
        elif state == JobState.FAILED:
            self.failedNodes.add(qnode)
        elif state == JobState.FAILED_DEP:
            self.failedNodes.add(qnode)
        elif state == JobState.TIMED_OUT:
            self.failedNodes.add(qnode)
            self.timedOutNodes.add(qnode)
        else:
            raise ValueError(f"Unexpected state value: {state}")

    def cleanup(self) -> None:
        """Do periodic cleanup for jobs that did not finish correctly.

        If timed out jobs are killed but take too long to stop then regular
        cleanup will not work for them. Here we check all timed out jobs
        periodically and do cleanup if they managed to die by this time.
        """
        for job in self.jobs:
            if job.state == JobState.TIMED_OUT and job.process is not None:
                job.cleanup()


class MPGraphExecutorError(Exception):
    """Exception class for errors raised by MPGraphExecutor."""

    pass


class MPTimeoutError(MPGraphExecutorError):
    """Exception raised when task execution times out."""

    pass


class MPGraphExecutor(QuantumGraphExecutor):
    """Implementation of QuantumGraphExecutor using same-host multiprocess
    execution of Quanta.

    Parameters
    ----------
    numProc : `int`
        Number of processes to use for executing tasks.
    timeout : `float`
        Time in seconds to wait for tasks to finish.
    quantumExecutor : `QuantumExecutor`
        Executor for a single quantum. For multiprocess-style execution when
        ``numProc`` is greater than one this instance must support pickle.
    startMethod : `str`, optional
        Start method from the `multiprocessing` module; `None` selects the
        best one for the current platform.
    failFast : `bool`, optional
        If set to ``True`` then stop processing on the first error from any
        task.
    pdb : `str`, optional
        Debugger to import and use (via the ``post_mortem`` function) in the
        event of an exception.
    executionGraphFixup : `ExecutionGraphFixup`, optional
        Instance used for modification of the execution graph.
    """

    def __init__(
        self,
        numProc: int,
        timeout: float,
        quantumExecutor: QuantumExecutor,
        *,
        startMethod: Optional[str] = None,
        failFast: bool = False,
        pdb: Optional[str] = None,
        executionGraphFixup: Optional[ExecutionGraphFixup] = None,
    ):
        self.numProc = numProc
        self.timeout = timeout
        self.quantumExecutor = quantumExecutor
        self.failFast = failFast
        self.pdb = pdb
        self.executionGraphFixup = executionGraphFixup
        self.report: Optional[Report] = None

        # Set the default start method to "spawn" for macOS and "fork" for
        # Linux; None for all other platforms, which uses the multiprocessing
        # default.
        if startMethod is None:
            methods = dict(linux="fork", darwin="spawn")
            startMethod = methods.get(sys.platform)
        self.startMethod = startMethod
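
        # For example, on sys.platform == "darwin" this selects "spawn",
        # while an unlisted platform such as "win32" leaves startMethod as
        # None, deferring to the multiprocessing default.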

    def execute(self, graph: QuantumGraph, butler: Butler) -> None:
        # Docstring inherited from QuantumGraphExecutor.execute
        graph = self._fixupQuanta(graph)
        self.report = Report()
        try:
            if self.numProc > 1:
                self._executeQuantaMP(graph, butler, self.report)
            else:
                self._executeQuantaInProcess(graph, butler, self.report)
        except Exception as exc:
            self.report.set_exception(exc)
            raise

    def _fixupQuanta(self, graph: QuantumGraph) -> QuantumGraph:
        """Call fixup code to modify the execution graph.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` to modify.

        Returns
        -------
        graph : `QuantumGraph`
            Modified `QuantumGraph`.

        Raises
        ------
        MPGraphExecutorError
            Raised if the execution graph cannot be ordered after
            modification, i.e. it has dependency cycles.
        """
        if not self.executionGraphFixup:
            return graph

        _LOG.debug("Call execution graph fixup method")
        graph = self.executionGraphFixup.fixupQuanta(graph)

        # Detect if there is now a cycle created within the graph.
        if graph.findCycle():
            raise MPGraphExecutorError("Updated execution graph has dependency cycle.")

        return graph

    def _executeQuantaInProcess(self, graph: QuantumGraph, butler: Butler, report: Report) -> None:
        """Execute all Quanta in the current process.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` that is to be executed.
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        report : `Report`
            Object for reporting execution status.
        """
        successCount, totalCount = 0, len(graph)
        failedNodes: set[QuantumNode] = set()
        for qnode in graph:

            assert qnode.quantum.dataId is not None, "Quantum DataId cannot be None"

            # Any failed inputs mean that the quantum has to be skipped.
            inputNodes = graph.determineInputsToQuantumNode(qnode)
            if inputNodes & failedNodes:
                _LOG.error(
                    "Upstream job failed for task <%s dataId=%s>, skipping this task.",
                    qnode.taskDef,
                    qnode.quantum.dataId,
                )
                failedNodes.add(qnode)
                failed_quantum_report = QuantumReport(
                    status=ExecutionStatus.SKIPPED, dataId=qnode.quantum.dataId, taskLabel=qnode.taskDef.label
                )
                report.quantaReports.append(failed_quantum_report)
                continue

            _LOG.debug("Executing %s", qnode)
            try:
                self.quantumExecutor.execute(qnode.taskDef, qnode.quantum, butler)
                successCount += 1
            except Exception as exc:
                if self.pdb and sys.stdin.isatty() and sys.stdout.isatty():
                    _LOG.error(
                        "Task <%s dataId=%s> failed; dropping into pdb.",
                        qnode.taskDef,
                        qnode.quantum.dataId,
                        exc_info=exc,
                    )
                    try:
                        pdb = importlib.import_module(self.pdb)
                    except ImportError as imp_exc:
                        raise MPGraphExecutorError(
                            f"Unable to import specified debugger module ({self.pdb}): {imp_exc}"
                        ) from exc
                    if not hasattr(pdb, "post_mortem"):
                        raise MPGraphExecutorError(
                            f"Specified debugger module ({self.pdb}) can't debug with post_mortem",
                        ) from exc
                    pdb.post_mortem(exc.__traceback__)
                failedNodes.add(qnode)
                report.status = ExecutionStatus.FAILURE
                if self.failFast:
                    raise MPGraphExecutorError(
                        f"Task <{qnode.taskDef} dataId={qnode.quantum.dataId}> failed."
                    ) from exc
                else:
                    # Note that there could be exception safety issues, which
                    # we presently ignore.
                    _LOG.error(
                        "Task <%s dataId=%s> failed; processing will continue for remaining tasks.",
                        qnode.taskDef,
                        qnode.quantum.dataId,
                        exc_info=exc,
                    )
            finally:
                # sqlalchemy has some objects that can last until a garbage
                # collection cycle is run, which can happen at unpredictable
                # times; run a collection loop here explicitly.
                gc.collect()

            quantum_report = self.quantumExecutor.getReport()
            if quantum_report:
                report.quantaReports.append(quantum_report)

        _LOG.info(
            "Executed %d quanta successfully, %d failed and %d remain out of total %d quanta.",
            successCount,
            len(failedNodes),
            totalCount - successCount - len(failedNodes),
            totalCount,
        )

        # Raise an exception if there were any failures.
        if failedNodes:
            raise MPGraphExecutorError("One or more tasks failed during execution.")

    def _executeQuantaMP(self, graph: QuantumGraph, butler: Butler, report: Report) -> None:
        """Execute all Quanta in separate processes.

        Parameters
        ----------
        graph : `QuantumGraph`
            `QuantumGraph` that is to be executed.
        butler : `lsst.daf.butler.Butler`
            Data butler instance.
        report : `Report`
            Object for reporting execution status.
        """
        disable_implicit_threading()  # To prevent thread contention

        _LOG.debug("Using %r for multiprocessing start method", self.startMethod)

        # re-pack input quantum data into jobs list
        jobs = _JobList(graph)

        # check that all tasks can run in a sub-process
        for job in jobs.jobs:
            taskDef = job.qnode.taskDef
            if not taskDef.taskClass.canMultiprocess:
                raise MPGraphExecutorError(
                    f"Task {taskDef.taskName} does not support multiprocessing; use single process"
                )

        finishedCount, failedCount = 0, 0
        while jobs.pending or jobs.running:

            _LOG.debug("#pendingJobs: %s", len(jobs.pending))
            _LOG.debug("#runningJobs: %s", len(jobs.running))

            # See if any jobs have finished
            for job in jobs.running:
                assert job.process is not None, "Process cannot be None"
                if not job.process.is_alive():
                    _LOG.debug("finished: %s", job)
                    # finished
                    exitcode = job.process.exitcode
                    quantum_report = job.report()
                    report.quantaReports.append(quantum_report)
                    if exitcode == 0:
                        jobs.setJobState(job, JobState.FINISHED)
                        job.cleanup()
                        _LOG.debug("success: %s took %.3f seconds", job, time.time() - job.started)
                    else:
                        if job.terminated:
                            # Was killed due to timeout.
                            if report.status == ExecutionStatus.SUCCESS:
                                # Do not override global FAILURE status
                                report.status = ExecutionStatus.TIMEOUT
                            message = f"Timeout ({self.timeout} sec) for task {job}, task is killed"
                            jobs.setJobState(job, JobState.TIMED_OUT)
                        else:
                            report.status = ExecutionStatus.FAILURE
                            # failMessage() has to be called before cleanup()
                            message = job.failMessage()
                            jobs.setJobState(job, JobState.FAILED)

                        job.cleanup()
                        _LOG.debug("failed: %s", job)
                        if self.failFast or exitcode == InvalidQuantumError.EXIT_CODE:
                            # stop all running jobs
                            for stopJob in jobs.running:
                                if stopJob is not job:
                                    stopJob.stop()
                            if job.state is JobState.TIMED_OUT:
                                raise MPTimeoutError(f"Timeout ({self.timeout} sec) for task {job}.")
                            else:
                                raise MPGraphExecutorError(message)
                        else:
                            _LOG.error("%s; processing will continue for remaining tasks.", message)
                else:
                    # check for timeout
                    now = time.time()
                    if now - job.started > self.timeout:
                        # Try to kill it, and there is a chance that it
                        # finishes successfully before it gets killed. Exit
                        # status is handled by the code above on next
                        # iteration.
                        _LOG.debug("Terminating job %s due to timeout", job)
                        job.stop()

            # Fail jobs whose inputs failed; this may need several iterations
            # if the ordering is not right, in which case it is handled on
            # the next loop iteration.
            if jobs.failedNodes:
                for job in jobs.pending:
                    jobInputNodes = graph.determineInputsToQuantumNode(job.qnode)
                    assert job.qnode.quantum.dataId is not None, "Quantum DataId cannot be None"
                    if jobInputNodes & jobs.failedNodes:
                        quantum_report = QuantumReport(
                            status=ExecutionStatus.SKIPPED,
                            dataId=job.qnode.quantum.dataId,
                            taskLabel=job.qnode.taskDef.label,
                        )
                        report.quantaReports.append(quantum_report)
                        jobs.setJobState(job, JobState.FAILED_DEP)
                        _LOG.error("Upstream job failed for task %s, skipping this task.", job)

            # see if we can start more jobs
            if len(jobs.running) < self.numProc:
                for job in jobs.pending:
                    jobInputNodes = graph.determineInputsToQuantumNode(job.qnode)
                    if jobInputNodes <= jobs.finishedNodes:
                        # all dependencies have completed, can start new job
                        if len(jobs.running) < self.numProc:
                            _LOG.debug("Submitting %s", job)
                            jobs.submit(job, butler, self.quantumExecutor, self.startMethod)
                        if len(jobs.running) >= self.numProc:
                            # Cannot start any more jobs, wait until
                            # something finishes.
                            break

            # Do cleanup for timed out jobs if necessary.
            jobs.cleanup()

            # Print progress message if something changed.
            newFinished, newFailed = len(jobs.finishedNodes), len(jobs.failedNodes)
            if (finishedCount, failedCount) != (newFinished, newFailed):
                finishedCount, failedCount = newFinished, newFailed
                totalCount = len(jobs.jobs)
                _LOG.info(
                    "Executed %d quanta successfully, %d failed and %d remain out of total %d quanta.",
                    finishedCount,
                    failedCount,
                    totalCount - finishedCount - failedCount,
                    totalCount,
                )

            # Here we want to wait until one of the running jobs completes,
            # but multiprocessing does not provide an API for that; for now
            # just sleep a little bit and go back to the loop.
            if jobs.running:
                time.sleep(0.1)

        if jobs.failedNodes:
            # print list of failed jobs
            _LOG.error("Failed jobs:")
            for job in jobs.jobs:
                if job.state != JobState.FINISHED:
                    _LOG.error(" - %s: %s", job.state.name, job)

            # if any job failed raise an exception
            if jobs.failedNodes == jobs.timedOutNodes:
                raise MPTimeoutError("One or more tasks timed out during execution.")
            else:
                raise MPGraphExecutorError("One or more tasks failed or timed out during execution.")

    def getReport(self) -> Optional[Report]:
        # Docstring inherited from base class
        if self.report is None:
            raise RuntimeError("getReport() called before execute()")
        return self.report