Coverage for python / lsst / pipe / base / quantum_graph / aggregator / _workers.py: 63%

119 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-17 08:59 +0000

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27 

28from __future__ import annotations 

29 

30__all__ = ("Event", "Queue", "SpawnWorkerFactory", "ThreadWorkerFactory", "Worker", "WorkerFactory") 

31 

32import multiprocessing.context 

33import multiprocessing.synchronize 

34import queue 

35import threading 

36from abc import ABC, abstractmethod 

37from collections.abc import Callable 

38from typing import Any, Literal, overload 

39 

# A very small timeout value.  Nothing in this module reads it; presumably it
# is in seconds and is imported by callers that poll queues or join workers
# -- TODO confirm against the call sites.
_TINY_TIMEOUT = 0.01

# Either flavor of event object: the thread-based one from `threading` or the
# process-based one from `multiprocessing.synchronize`.
type Event = threading.Event | multiprocessing.synchronize.Event

43 

44 

class Worker(ABC):
    """Common interface for one unit of concurrent execution.

    Implementations wrap either a `threading.Thread` or a
    `multiprocessing.Process`.  On top of the usual join/is-alive operations,
    each worker carries a ``successful`` flag that records whether it
    reported successful completion.
    """

    def __init__(self) -> None:
        # Starts out False.  This class never changes it, so it has to be
        # set from outside once the worker reports success.
        self.successful = False

    @property
    @abstractmethod
    def name(self) -> str:
        """The name given to the worker when it was created."""
        raise NotImplementedError

    @abstractmethod
    def join(self, timeout: float | None = None) -> None:
        """Block until the worker has finished.

        Parameters
        ----------
        timeout : `float`, optional
            Maximum time to wait, in seconds.  When the wait times out, call
            `is_alive` to find out whether the worker actually finished.
        """
        raise NotImplementedError

    @abstractmethod
    def is_alive(self) -> bool:
        """Report whether the worker is still running."""
        raise NotImplementedError

    def kill(self) -> None:
        """Kill the worker, if possible.

        The base implementation does nothing; subclasses that are able to
        forcibly stop their worker override it.
        """
        return None

79 

80 

81class Queue[T](ABC): 

82 """A thin abstraction over `queue.Queue` and `multiprocessing.Queue` that 

83 provides better control over disorderly shutdowns. 

84 """ 

85 

86 @overload 

87 def get(self, *, block: Literal[True]) -> T: ... 87 ↛ exitline 87 didn't return from function 'get' because

88 

89 @overload 

90 def get(self, *, timeout: float | None = None, block: bool = False) -> T | None: ... 90 ↛ exitline 90 didn't return from function 'get' because

91 

92 @abstractmethod 

93 def get(self, *, timeout: float | None = None, block: bool = False) -> T | None: 

94 """Get an object or return `None` if the queue is empty. 

95 

96 Parameters 

97 ---------- 

98 timeout : `float` or `None`, optional 

99 Maximum number of seconds to wait while blocking. 

100 block : `bool`, optional 

101 Whether to block until an object is available. 

102 

103 Returns 

104 ------- 

105 obj : `object` or `None` 

106 Object from the queue, or `None` if it was empty. Note that this 

107 is different from the behavior of the built-in Python queues, 

108 which raise `queue.Empty` instead. 

109 """ 

110 raise NotImplementedError() 

111 

112 @abstractmethod 

113 def put(self, item: T) -> None: 

114 """Add an object to the queue. 

115 

116 Parameters 

117 ---------- 

118 item : `object` 

119 Item to add. 

120 """ 

121 raise NotImplementedError() 

122 

123 def clear(self) -> bool: 

124 """Clear out all objects currently on the queue. 

125 

126 This does not guarantee that more objects will not be added later. 

127 """ 

128 found_anything: bool = False 

129 while self.get() is not None: 

130 found_anything = True 

131 return found_anything 

132 

133 def kill(self) -> None: 

134 """Prepare a queue for a disorderly shutdown, without assuming that 

135 any other workers using it are still alive and functioning. 

136 """ 

137 

138 

class WorkerFactory(ABC):
    """Abstract factory for workers and the objects they communicate with.

    The interface is deliberately small so that it can be implemented with
    both threading and multiprocessing.
    """

    @abstractmethod
    def make_queue(self) -> Queue[Any]:
        """Create a new, empty queue for passing objects between workers
        made by this factory.
        """
        raise NotImplementedError

    @abstractmethod
    def make_event(self) -> Event:
        """Create an event for signalling a boolean state change to workers
        made by this factory.
        """
        raise NotImplementedError

    @abstractmethod
    def make_worker(
        self, target: Callable[..., None], args: tuple[Any, ...], name: str | None = None
    ) -> Worker:
        """Create a worker that executes the given callable.

        Parameters
        ----------
        target : `~collections.abc.Callable`
            Callable to run on the worker.
        args : `tuple`
            Positional arguments forwarded to ``target``.
        name : `str`, optional
            Human-readable name to give the worker.

        Returns
        -------
        worker : `Worker`
            A thread or process that has already been started and is running
            ``target``.
        """
        raise NotImplementedError

179 

180 

class _ThreadWorker(Worker):
    """A `Worker` that delegates to a `threading.Thread`.

    Parameters
    ----------
    thread : `threading.Thread`
        The thread to wrap.
    """

    def __init__(self, thread: threading.Thread):
        super().__init__()
        self._thread = thread

    @property
    def name(self) -> str:
        # The worker's name is simply the wrapped thread's name.
        return self._thread.name

    def join(self, timeout: float | None = None) -> None:
        # Forward straight to the thread; `None` waits indefinitely.
        self._thread.join(timeout)

    def is_alive(self) -> bool:
        thread = self._thread
        return thread.is_alive()

197 

198 

199class _ThreadQueue[T](Queue[T]): 

200 def __init__(self) -> None: 

201 self._impl = queue.Queue[T]() 

202 

203 @overload 

204 def get(self, *, block: Literal[True]) -> T: ... 204 ↛ exitline 204 didn't return from function 'get' because

205 

206 @overload 

207 def get(self, *, timeout: float | None = None, block: bool = False) -> T | None: ... 207 ↛ exitline 207 didn't return from function 'get' because

208 

209 def get(self, *, timeout: float | None = None, block: bool = False) -> T | None: 

210 try: 

211 return self._impl.get(block=block, timeout=timeout) 

212 except queue.Empty: 

213 return None 

214 

215 def put(self, item: T) -> None: 

216 self._impl.put(item, block=False) 

217 

218 

class ThreadWorkerFactory(WorkerFactory):
    """A `WorkerFactory` whose workers, queues, and events all come from the
    `threading` and `queue` modules.
    """

    def make_queue(self) -> Queue[Any]:
        thread_queue: Queue[Any] = _ThreadQueue()
        return thread_queue

    def make_event(self) -> Event:
        return threading.Event()

    def make_worker(
        self, target: Callable[..., None], args: tuple[Any, ...], name: str | None = None
    ) -> Worker:
        runner = threading.Thread(name=name, target=target, args=args)
        # The thread is started before it is handed back, so the returned
        # worker is already running.
        runner.start()
        return _ThreadWorker(runner)

236 

237 

class _ProcessWorker(Worker):
    """A `Worker` that delegates to a spawned `multiprocessing` process.

    Parameters
    ----------
    process : `multiprocessing.context.SpawnProcess`
        The process to wrap.
    """

    def __init__(self, process: multiprocessing.context.SpawnProcess):
        super().__init__()
        self._process = process

    @property
    def name(self) -> str:
        # The worker's name is simply the wrapped process's name.
        return self._process.name

    def join(self, timeout: float | None = None) -> None:
        # Forward straight to the process; `None` waits indefinitely.
        self._process.join(timeout)

    def is_alive(self) -> bool:
        process = self._process
        return process.is_alive()

    def kill(self) -> None:
        """Forcibly terminate the wrapped process."""
        self._process.kill()

258 

259 

260class _ProcessQueue[T](Queue[T]): 

261 def __init__(self, impl: multiprocessing.Queue): 

262 self._impl = impl 

263 

264 @overload 

265 def get(self, *, block: Literal[True]) -> T: ... 265 ↛ exitline 265 didn't return from function 'get' because

266 

267 @overload 

268 def get(self, *, timeout: float | None = None, block: bool = False) -> T | None: ... 268 ↛ exitline 268 didn't return from function 'get' because

269 

270 def get(self, *, timeout: float | None = None, block: bool = False) -> T | None: 

271 try: 

272 return self._impl.get(block=block, timeout=timeout) 

273 except queue.Empty: 

274 return None 

275 

276 def put(self, item: T) -> None: 

277 self._impl.put(item, block=False) 

278 

279 def kill(self) -> None: 

280 self._impl.cancel_join_thread() 

281 self._impl.close() 

282 

283 

class SpawnWorkerFactory(WorkerFactory):
    """A `WorkerFactory` built on `multiprocessing`, using the "spawn" start
    method for every new process.
    """

    def __init__(self) -> None:
        # Queues, events, and processes are all created from this one
        # context so that they share the "spawn" start method.
        self._ctx = multiprocessing.get_context("spawn")

    def make_queue(self) -> Queue[Any]:
        underlying = self._ctx.Queue()
        return _ProcessQueue(underlying)

    def make_event(self) -> Event:
        return self._ctx.Event()

    def make_worker(
        self, target: Callable[..., None], args: tuple[Any, ...], name: str | None = None
    ) -> Worker:
        child = self._ctx.Process(name=name, target=target, args=args)
        # The process is started before it is handed back, so the returned
        # worker is already running.
        child.start()
        return _ProcessWorker(child)

303 return _ProcessWorker(process)