Coverage for python / lsst / pipe / base / quantum_graph / aggregator / _workers.py: 63%
119 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-17 08:59 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-17 08:59 +0000
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
28from __future__ import annotations
30__all__ = ("Event", "Queue", "SpawnWorkerFactory", "ThreadWorkerFactory", "Worker", "WorkerFactory")
32import multiprocessing.context
33import multiprocessing.synchronize
34import queue
35import threading
36from abc import ABC, abstractmethod
37from collections.abc import Callable
38from typing import Any, Literal, overload
# A very short timeout (presumably in seconds, like the other `timeout`
# parameters in this module). NOTE(review): not referenced by the code
# visible here; presumably imported by callers of this module -- confirm.
_TINY_TIMEOUT = 0.01

# Event type accepted/returned by the worker factories: thread-based
# factories (`ThreadWorkerFactory`) produce `threading.Event`, while
# process-based ones (`SpawnWorkerFactory`) produce a multiprocessing event.
type Event = threading.Event | multiprocessing.synchronize.Event
class Worker(ABC):
    """Common interface for a unit of concurrent execution.

    Concrete subclasses wrap either a `threading.Thread` or a
    `multiprocessing.Process`. Every worker also carries a ``successful``
    flag that records whether it reported successful completion.
    """

    def __init__(self) -> None:
        # Initially False; nothing in this class sets it to True, so it is
        # presumably updated by whoever receives the worker's completion
        # report.
        self.successful = False

    @property
    @abstractmethod
    def name(self) -> str:
        """The name given to the worker when it was created."""
        raise NotImplementedError

    @abstractmethod
    def join(self, timeout: float | None = None) -> None:
        """Block until the worker exits or the timeout elapses.

        Parameters
        ----------
        timeout : `float`, optional
            Maximum number of seconds to wait. This method does not report
            whether the timeout was reached; call `is_alive` afterwards to
            find out whether the worker actually finished.
        """
        raise NotImplementedError

    @abstractmethod
    def is_alive(self) -> bool:
        """Report whether the worker is still executing."""
        raise NotImplementedError

    def kill(self) -> None:
        """Forcibly terminate the worker, if the backend supports it.

        The base implementation does nothing.
        """
81class Queue[T](ABC):
82 """A thin abstraction over `queue.Queue` and `multiprocessing.Queue` that
83 provides better control over disorderly shutdowns.
84 """
86 @overload
87 def get(self, *, block: Literal[True]) -> T: ... 87 ↛ exitline 87 didn't return from function 'get' because
89 @overload
90 def get(self, *, timeout: float | None = None, block: bool = False) -> T | None: ... 90 ↛ exitline 90 didn't return from function 'get' because
92 @abstractmethod
93 def get(self, *, timeout: float | None = None, block: bool = False) -> T | None:
94 """Get an object or return `None` if the queue is empty.
96 Parameters
97 ----------
98 timeout : `float` or `None`, optional
99 Maximum number of seconds to wait while blocking.
100 block : `bool`, optional
101 Whether to block until an object is available.
103 Returns
104 -------
105 obj : `object` or `None`
106 Object from the queue, or `None` if it was empty. Note that this
107 is different from the behavior of the built-in Python queues,
108 which raise `queue.Empty` instead.
109 """
110 raise NotImplementedError()
112 @abstractmethod
113 def put(self, item: T) -> None:
114 """Add an object to the queue.
116 Parameters
117 ----------
118 item : `object`
119 Item to add.
120 """
121 raise NotImplementedError()
123 def clear(self) -> bool:
124 """Clear out all objects currently on the queue.
126 This does not guarantee that more objects will not be added later.
127 """
128 found_anything: bool = False
129 while self.get() is not None:
130 found_anything = True
131 return found_anything
133 def kill(self) -> None:
134 """Prepare a queue for a disorderly shutdown, without assuming that
135 any other workers using it are still alive and functioning.
136 """
class WorkerFactory(ABC):
    """Backend-neutral source of workers and the primitives they share.

    Implementations exist for both `threading` and `multiprocessing`; the
    queues, events and workers produced by one factory are meant to be used
    together.
    """

    @abstractmethod
    def make_queue(self) -> Queue[Any]:
        """Create an empty queue for passing objects between this factory's
        workers.
        """
        raise NotImplementedError

    @abstractmethod
    def make_event(self) -> Event:
        """Create an event for signalling a boolean state change to this
        factory's workers.
        """
        raise NotImplementedError

    @abstractmethod
    def make_worker(
        self, target: Callable[..., None], args: tuple[Any, ...], name: str | None = None
    ) -> Worker:
        """Start a new worker that runs ``target(*args)``.

        Parameters
        ----------
        target : `~collections.abc.Callable`
            The callable the worker executes.
        args : `tuple`
            Positional arguments forwarded to ``target``.
        name : `str`, optional
            Human-readable label for the worker.

        Returns
        -------
        worker : `Worker`
            A thread or process that has already been started.
        """
        raise NotImplementedError
class _ThreadWorker(Worker):
    """`Worker` implementation that delegates to a `threading.Thread`.

    Threads cannot be forcibly stopped, so `kill` keeps the inherited no-op.
    """

    def __init__(self, thread: threading.Thread):
        super().__init__()
        # `ThreadWorkerFactory.make_worker` starts the thread before wrapping it.
        self._thread = thread

    @property
    def name(self) -> str:
        return self._thread.name

    def join(self, timeout: float | None = None) -> None:
        self._thread.join(timeout)

    def is_alive(self) -> bool:
        return self._thread.is_alive()
199class _ThreadQueue[T](Queue[T]):
200 def __init__(self) -> None:
201 self._impl = queue.Queue[T]()
203 @overload
204 def get(self, *, block: Literal[True]) -> T: ... 204 ↛ exitline 204 didn't return from function 'get' because
206 @overload
207 def get(self, *, timeout: float | None = None, block: bool = False) -> T | None: ... 207 ↛ exitline 207 didn't return from function 'get' because
209 def get(self, *, timeout: float | None = None, block: bool = False) -> T | None:
210 try:
211 return self._impl.get(block=block, timeout=timeout)
212 except queue.Empty:
213 return None
215 def put(self, item: T) -> None:
216 self._impl.put(item, block=False)
class ThreadWorkerFactory(WorkerFactory):
    """`WorkerFactory` whose workers are threads in the current process."""

    def make_queue(self) -> Queue[Any]:
        return _ThreadQueue()

    def make_event(self) -> Event:
        return threading.Event()

    def make_worker(
        self, target: Callable[..., None], args: tuple[Any, ...], name: str | None = None
    ) -> Worker:
        # Start the thread first so the returned worker is already running.
        runner = threading.Thread(name=name, target=target, args=args)
        runner.start()
        return _ThreadWorker(runner)
class _ProcessWorker(Worker):
    """`Worker` implementation that delegates to a process created from a
    `multiprocessing` "spawn" context.
    """

    def __init__(self, process: multiprocessing.context.SpawnProcess):
        super().__init__()
        self._process = process

    @property
    def name(self) -> str:
        return self._process.name

    def join(self, timeout: float | None = None) -> None:
        self._process.join(timeout)

    def is_alive(self) -> bool:
        return self._process.is_alive()

    def kill(self) -> None:
        """Terminate the process immediately via `multiprocessing.Process.kill`
        (SIGKILL on POSIX).
        """
        self._process.kill()
260class _ProcessQueue[T](Queue[T]):
261 def __init__(self, impl: multiprocessing.Queue):
262 self._impl = impl
264 @overload
265 def get(self, *, block: Literal[True]) -> T: ... 265 ↛ exitline 265 didn't return from function 'get' because
267 @overload
268 def get(self, *, timeout: float | None = None, block: bool = False) -> T | None: ... 268 ↛ exitline 268 didn't return from function 'get' because
270 def get(self, *, timeout: float | None = None, block: bool = False) -> T | None:
271 try:
272 return self._impl.get(block=block, timeout=timeout)
273 except queue.Empty:
274 return None
276 def put(self, item: T) -> None:
277 self._impl.put(item, block=False)
279 def kill(self) -> None:
280 self._impl.cancel_join_thread()
281 self._impl.close()
class SpawnWorkerFactory(WorkerFactory):
    """`WorkerFactory` whose workers are separate processes started with the
    `multiprocessing` "spawn" start method.
    """

    def __init__(self) -> None:
        # Queues, events and processes are all taken from this one context;
        # multiprocessing objects from different contexts may not be
        # compatible with each other.
        self._ctx = multiprocessing.get_context("spawn")

    def make_queue(self) -> Queue[Any]:
        mp_queue = self._ctx.Queue()
        return _ProcessQueue(mp_queue)

    def make_event(self) -> Event:
        return self._ctx.Event()

    def make_worker(
        self, target: Callable[..., None], args: tuple[Any, ...], name: str | None = None
    ) -> Worker:
        # Start the process first so the returned worker is already running.
        proc = self._ctx.Process(name=name, target=target, args=args)
        proc.start()
        return _ProcessWorker(proc)