Coverage for python/lsst/ctrl/pool/pool.py: 18%

597 statements  


# MPI process pool
# Copyright 2013 Paul A. Price
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/copyleft/gpl.html>
#

import os
import sys
import time
import types
import copyreg
import threading
from functools import wraps, partial
from contextlib import contextmanager

import mpi4py.MPI as mpi

from lsst.pipe.base import Struct


__all__ = ["Comm", "Pool", "startPool", "setBatchType", "getBatchType", "abortOnError", "NODE"]

NODE = "%s:%d" % (os.uname()[1], os.getpid())  # Name of node


def unpickleInstanceMethod(obj, name):
    """Unpickle an instance method

    This has to be a named function rather than a lambda because
    pickle needs to find it.
    """
    return getattr(obj, name)


def pickleInstanceMethod(method):
    """Pickle an instance method

    The instance method is divided into the object and the
    method name.
    """
    obj = method.__self__
    name = method.__name__
    return unpickleInstanceMethod, (obj, name)


copyreg.pickle(types.MethodType, pickleInstanceMethod)


def unpickleFunction(moduleName, funcName):
    """Unpickle a function

    This has to be a named function rather than a lambda because
    pickle needs to find it.
    """
    import importlib
    module = importlib.import_module(moduleName)
    return getattr(module, funcName)


def pickleFunction(function):
    """Pickle a function

    This assumes that we can recreate the function object by grabbing
    it from the proper module.  This may be violated if the function
    is a lambda or in __main__.  In that case, I recommend recasting
    the function as an object with a __call__ method.

    Another problematic case may be a wrapped (e.g., decorated) method
    in a class: the 'method' is then a function, and recreating it is
    not as easy as we assume here.
    """
    moduleName = function.__module__
    funcName = function.__name__
    return unpickleFunction, (moduleName, funcName)


copyreg.pickle(types.FunctionType, pickleFunction)


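# Illustrative sketch (not part of the original module): as the docstring for
# pickleFunction recommends, a lambda or __main__ function that would not
# survive pickling can be recast as an object with a __call__ method.  The
# name '_ExampleScaler' is hypothetical.
class _ExampleScaler:
    """Picklable replacement for 'lambda x: x*factor'"""

    def __init__(self, factor):
        self.factor = factor

    def __call__(self, x):
        return x*self.factor

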
try:
    _batchType
except NameError:
    _batchType = "unknown"


def getBatchType():
    """Return a string giving the type of batch system in use"""
    return _batchType


def setBatchType(batchType):
    """Set the type of batch system in use"""
    global _batchType
    _batchType = batchType


def abortOnError(func):
    """Function decorator to throw an MPI abort on an unhandled exception"""
    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            sys.stderr.write("%s on %s in %s: %s\n" % (type(e).__name__, NODE, func.__name__, e))
            import traceback
            traceback.print_exc(file=sys.stderr)
            sys.stdout.flush()
            sys.stderr.flush()
            if getBatchType() is not None:
                mpi.COMM_WORLD.Abort(1)
            else:
                raise
    return wrapper


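# Illustrative sketch (not part of the original module): a typical entry point
# wraps its top-level MPI code in @abortOnError so that one failed rank brings
# the whole job down rather than leaving the others deadlocked.  The function
# name and the "smp" batch label are hypothetical.
@abortOnError
def _exampleEntryPoint():
    setBatchType("smp")  # tell abortOnError we're running under a batch system
    assert getBatchType() == "smp"
    # ... top-level MPI work would go here ...

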
class PickleHolder:
    """Singleton to hold what's about to be pickled.

    We hold onto the object in case there's trouble pickling,
    so we can figure out what class in particular is causing
    the trouble.

    The held object is in the 'obj' attribute.

    Here we use the __new__-style singleton pattern, because
    we specifically want __init__ to be called each time.
    """

    _instance = None

    def __new__(cls, hold=None):
        if cls._instance is None:
            cls._instance = super(PickleHolder, cls).__new__(cls)
            cls._instance.__init__(hold)
            cls._instance.obj = None
        return cls._instance

    def __init__(self, hold=None):
        """Hold onto new object"""
        if hold is not None:
            self.obj = hold

    def __enter__(self):
        pass

    def __exit__(self, excType, excVal, tb):
        """Drop held object if there were no problems"""
        if excType is None:
            self.obj = None


def guessPickleObj():
    """Try to guess what's not pickling after an exception

    This tends to work if the problem is coming from the
    regular pickle module.  If it's coming from the bowels
    of mpi4py, there's not much that can be done.
    """
    import sys
    excType, excValue, tb = sys.exc_info()
    # Build a stack of traceback elements
    stack = []
    while tb:
        stack.append(tb)
        tb = tb.tb_next

    try:
        # This is the code version of my way of finding what's not pickling in pdb.
        # This should work if it's coming from the regular pickle module, and they
        # haven't changed the variable names since python 2.7.3.
        return stack[-2].tb_frame.f_locals["obj"]
    except Exception:
        return None


@contextmanager
def pickleSniffer(abort=False):
    """Context manager to sniff out pickle problems

    If there's a pickle error, you're normally told what the problem
    class is.  However, all SWIG objects are reported as "SwigPyObject".
    In order to figure out which actual SWIG-ed class is causing
    problems, we need to go digging.

    Use like this:

        with pickleSniffer():
            someOperationInvolvingPickle()

    If 'abort' is True, will call MPI abort in the event of problems.
    """
    try:
        yield
    except Exception as e:
        if "SwigPyObject" not in str(e) or "pickle" not in str(e):
            raise
        import sys
        import traceback

        sys.stderr.write("Pickling error detected: %s\n" % e)
        traceback.print_exc(file=sys.stderr)
        obj = guessPickleObj()
        heldObj = PickleHolder().obj
        if obj is None and heldObj is not None:
            # Try to reproduce the problem by pickling the held object with the
            # regular pickle module, so we've got a chance of figuring out what it is.
            import pickle
            try:
                pickle.dumps(heldObj)
                sys.stderr.write("Hmmm, that's strange: no problem with pickling held object?!?!\n")
            except Exception:
                obj = guessPickleObj()
        if obj is None:
            sys.stderr.write("Unable to determine class causing pickle problems.\n")
        else:
            sys.stderr.write("Object that could not be pickled: %s\n" % obj)
        if abort:
            if getBatchType() is not None:
                mpi.COMM_WORLD.Abort(1)
            else:
                sys.exit(1)


def catchPicklingError(func):
    """Function decorator to catch errors in pickling and print something useful"""
    @wraps(func)
    def wrapper(*args, **kwargs):
        with pickleSniffer(True):
            return func(*args, **kwargs)
    return wrapper


class Comm(mpi.Intracomm):
    """Wrapper to mpi4py's MPI.Intracomm class to avoid busy-waiting.

    As suggested by Lisandro Dalcin at:
    * http://code.google.com/p/mpi4py/issues/detail?id=4 and
    * https://groups.google.com/forum/?fromgroups=#!topic/mpi4py/nArVuMXyyZI
    """

    def __new__(cls, comm=mpi.COMM_WORLD, recvSleep=0.1, barrierSleep=0.1):
        """!Construct an MPI.Comm wrapper

        @param cls           Class
        @param comm          MPI.Intracomm to wrap a duplicate of
        @param recvSleep     Sleep time (seconds) for recv()
        @param barrierSleep  Sleep time (seconds) for Barrier()
        """
        self = super(Comm, cls).__new__(cls, comm.Dup())
        self._barrierComm = None  # Duplicate communicator used for Barrier point-to-point checking
        self._recvSleep = recvSleep
        self._barrierSleep = barrierSleep
        return self

    def recv(self, obj=None, source=0, tag=0, status=None):
        """Version of comm.recv() that doesn't busy-wait"""
        sts = mpi.Status()
        while not self.Iprobe(source=source, tag=tag, status=sts):
            time.sleep(self._recvSleep)
        return super(Comm, self).recv(buf=obj, source=sts.source, tag=sts.tag, status=status)

    def send(self, obj=None, *args, **kwargs):
        with PickleHolder(obj):
            return super(Comm, self).send(obj, *args, **kwargs)

    def _checkBarrierComm(self):
        """Ensure the duplicate communicator is available"""
        if self._barrierComm is None:
            self._barrierComm = self.Dup()

    def Barrier(self, tag=0):
        """Version of comm.Barrier() that doesn't busy-wait

        A duplicate communicator is used so as not to interfere with the user's own communications.
        """
        self._checkBarrierComm()
        size = self._barrierComm.Get_size()
        if size == 1:
            return
        rank = self._barrierComm.Get_rank()
        mask = 1
        while mask < size:
            dst = (rank + mask) % size
            src = (rank - mask + size) % size
            req = self._barrierComm.isend(None, dst, tag)
            while not self._barrierComm.Iprobe(src, tag):
                time.sleep(self._barrierSleep)
            self._barrierComm.recv(None, src, tag)
            req.Wait()
            mask <<= 1

    def broadcast(self, value, root=0):
        with PickleHolder(value):
            return super(Comm, self).bcast(value, root=root)

    def scatter(self, dataList, root=0, tag=0):
        """Scatter data across the nodes

        The default version apparently pickles the entire 'dataList',
        which can cause errors if the pickle size grows over 2^31 bytes
        due to fundamental problems with pickle in python 2.  Instead,
        we send the data to each slave node in turn; this reduces the
        pickle size.

        @param dataList  List of data to distribute; one per node
                         (including root)
        @param root      Index of root node
        @param tag       Message tag (integer)
        @return  Data for this node
        """
        if self.Get_rank() == root:
            for rank, data in enumerate(dataList):
                if rank == root:
                    continue
                self.send(data, rank, tag=tag)
            return dataList[root]
        else:
            return self.recv(source=root, tag=tag)

    def Free(self):
        if self._barrierComm is not None:
            self._barrierComm.Free()
        super(Comm, self).Free()


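# Illustrative usage sketch (not part of the original module): under
# 'mpiexec -n 4', every rank wraps COMM_WORLD in a Comm and can then exchange
# messages and synchronise without busy-waiting.
def _exampleCommUsage():
    comm = Comm()  # duplicates mpi.COMM_WORLD, with sleepy recv()/Barrier()
    if comm.rank == 0:
        for target in range(1, comm.size):
            comm.send("hello", target, tag=1)
    else:
        print(comm.recv(source=0, tag=1))
    comm.Barrier()  # point-to-point barrier that sleeps instead of spinning

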
class NoOp:
    """Object to signal no operation"""
    pass


class Tags:
    """Provides tag numbers by symbolic name in attributes"""

    def __init__(self, *nameList):
        self._nameList = nameList
        for i, name in enumerate(nameList, 1):
            setattr(self, name, i)

    def __repr__(self):
        return self.__class__.__name__ + repr(self._nameList)

    def __reduce__(self):
        return self.__class__, tuple(self._nameList)


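# Illustrative sketch (not part of the original module): Tags simply numbers
# the given names from 1, providing symbolic MPI message tags.
def _exampleTags():
    tags = Tags("request", "work")
    assert tags.request == 1 and tags.work == 2
    return tags

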
class Cache(Struct):
    """An object to hold stuff between different scatter calls

    Includes a communicator by default, to allow intercommunication
    between nodes.
    """

    def __init__(self, comm):
        super(Cache, self).__init__(comm=comm)


class SingletonMeta(type):
    """!Metaclass to produce a singleton

    Doing a singleton mixin without a metaclass (via __new__) is
    annoying because the user has to name his __init__ something else
    (otherwise it's called every time, which undoes any changes).
    Using this metaclass, the class's __init__ is called exactly once.

    Because this is a metaclass, note that:
    * "self" here is the class
    * "__init__" is making the class (it's like the body of the
      class definition).
    * "__call__" is making an instance of the class (it's like
      "__new__" in the class).
    """

    def __init__(cls, name, bases, dict_):
        super(SingletonMeta, cls).__init__(name, bases, dict_)
        cls._instance = None

    def __call__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super(SingletonMeta, cls).__call__(*args, **kwargs)
        return cls._instance


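# Illustrative sketch (not part of the original module): a class built with
# SingletonMeta hands back the same instance on every call, and __init__ runs
# only for the first call.
def _exampleSingleton():
    class _Counter(metaclass=SingletonMeta):
        def __init__(self):
            self.count = 0

    first = _Counter()
    first.count += 1
    second = _Counter()
    assert second is first and second.count == 1

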
class Debugger(metaclass=SingletonMeta):
    """Debug logger singleton

    Disabled by default; to enable, do: 'Debugger().enabled = True'
    You can also redirect the output by changing the 'out' attribute.
    """

    def __init__(self):
        self.enabled = False
        self.out = sys.stderr

    def log(self, source, msg, *args):
        """!Log message

        The 'args' are only stringified if we're enabled.

        @param source: name of source
        @param msg: message to write
        @param args: additional outputs to append to message
        """
        if self.enabled:
            self.out.write("%s: %s" % (source, msg))
            for arg in args:
                self.out.write(" %s" % arg)
            self.out.write("\n")


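# Illustrative sketch (not part of the original module): enable the shared
# debug logger; subsequent log() calls on every pool object then write to
# stderr (or wherever 'out' points).
def _exampleEnableDebugging():
    Debugger().enabled = True
    Debugger().log("example", "pool debugging is now on")

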
class ReductionThread(threading.Thread):
    """Thread to do reduction of results

    "A thread?", you say.  "What about the python GIL?"
    Well, because we 'sleep' when there's no immediate response from the
    slaves, that gives the thread a chance to fire; and threads are easier
    to manage (e.g., shared memory) than a process.
    """

    def __init__(self, reducer, initial=None, sleep=0.1):
        """!Constructor

        The 'reducer' should take two values and return a single
        (reduced) value.

        @param reducer  Function that does the reducing
        @param initial  Initial value for reduction, or None
        @param sleep    Time to sleep when there's nothing to do (sec)
        """
        threading.Thread.__init__(self, name="reducer")
        self._queue = []  # Queue of stuff to be reduced
        self._lock = threading.Lock()  # Lock for the queue
        self._reducer = reducer
        self._sleep = sleep
        self._result = initial  # Final result
        self._done = threading.Event()  # Signal that everything is done

    def _doReduce(self):
        """Do the actual work

        We pull the data out of the queue and release the lock before
        operating on it.  This stops us from blocking the addition of
        new data to the queue.
        """
        with self._lock:
            queue = self._queue
            self._queue = []
        for data in queue:
            self._result = self._reducer(self._result, data) if self._result is not None else data

    def run(self):
        """Do the work

        Thread entry point, called by Thread.start
        """
        while True:
            self._doReduce()
            if self._done.wait(self._sleep):
                self._doReduce()
                return

    def add(self, data):
        """Add data to the queue to be reduced"""
        with self._lock:
            self._queue.append(data)

    def join(self):
        """Complete the thread

        Unlike Thread.join (which always returns 'None'), we return the result
        we calculated.
        """
        self._done.set()
        threading.Thread.join(self)
        return self._result


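# Illustrative sketch (not part of the original module): values added from the
# main thread are folded together in the background, and join() hands back the
# final reduction.
def _exampleReductionThread():
    import operator
    thread = ReductionThread(operator.add)
    thread.start()
    for value in range(10):
        thread.add(value)
    return thread.join()  # 45

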
class PoolNode(metaclass=SingletonMeta):
    """Node in MPI process pool

    WARNING: You should not let a pool instance hang around at program
    termination, as the garbage collection behaves differently, and may
    cause a segmentation fault (signal 11).
    """

    def __init__(self, comm=None, root=0):
        if comm is None:
            comm = Comm()
        self.comm = comm
        self.rank = self.comm.rank
        self.root = root
        self.size = self.comm.size
        self._cache = {}
        self._store = {}
        self.debugger = Debugger()
        self.node = NODE

    def _getCache(self, context, index):
        """Retrieve cache for particular data

        The cache is updated with the contents of the store.
        """
        if context not in self._cache:
            self._cache[context] = {}
        if context not in self._store:
            self._store[context] = {}
        cache = self._cache[context]
        store = self._store[context]
        if index not in cache:
            cache[index] = Cache(self.comm)
        cache[index].__dict__.update(store)
        return cache[index]

    def log(self, msg, *args):
        """Log a debugging message"""
        self.debugger.log("Node %d" % self.rank, msg, *args)

    def isMaster(self):
        return self.rank == self.root

    def _processQueue(self, context, func, queue, *args, **kwargs):
        """!Process a queue of data

        The queue consists of a list of (index, data) tuples,
        where the index maps to the cache, and the data is
        passed to the 'func'.

        The 'func' signature should be func(cache, data, *args, **kwargs)
        if 'context' is non-None; otherwise func(data, *args, **kwargs).

        @param context: Namespace for cache; None to not use cache
        @param func: function for slaves to run
        @param queue: List of (index,data) tuples to process
        @param args: Constant arguments
        @param kwargs: Keyword arguments
        @return list of results from applying 'func' to dataList
        """
        return self._reduceQueue(context, None, func, queue, *args, **kwargs)

    def _reduceQueue(self, context, reducer, func, queue, *args, **kwargs):
        """!Reduce a queue of data

        The queue consists of a list of (index, data) tuples,
        where the index maps to the cache, and the data is
        passed to the 'func', the output of which is reduced
        using the 'reducer' (if non-None).

        The 'func' signature should be func(cache, data, *args, **kwargs)
        if 'context' is non-None; otherwise func(data, *args, **kwargs).

        The 'reducer' signature should be reducer(old, new).  If the 'reducer'
        is None, then we will return the full list of results.

        @param context: Namespace for cache; None to not use cache
        @param reducer: function for master to run to reduce slave results; or None
        @param func: function for slaves to run
        @param queue: List of (index,data) tuples to process
        @param args: Constant arguments
        @param kwargs: Keyword arguments
        @return reduced result (if reducer is non-None) or list of results
            from applying 'func' to dataList
        """
        if context is not None:
            resultList = [func(self._getCache(context, i), data, *args, **kwargs) for i, data in queue]
        else:
            resultList = [func(data, *args, **kwargs) for i, data in queue]
        if reducer is None:
            return resultList
        if len(resultList) == 0:
            return None
        output = resultList.pop(0)
        for result in resultList:
            output = reducer(output, result)
        return output

    def storeSet(self, context, **kwargs):
        """Set values in store for a particular context"""
        self.log("storing", context, kwargs)
        if context not in self._store:
            self._store[context] = {}
        for name, value in kwargs.items():
            self._store[context][name] = value

    def storeDel(self, context, *nameList):
        """Delete value in store for a particular context"""
        self.log("deleting from store", context, nameList)
        if context not in self._store:
            raise KeyError("No such context: %s" % context)
        for name in nameList:
            del self._store[context][name]

    def storeClear(self, context):
        """Clear stored data for a particular context"""
        self.log("clearing store", context)
        if context not in self._store:
            raise KeyError("No such context: %s" % context)
        self._store[context] = {}

    def cacheClear(self, context):
        """Reset cache for a particular context"""
        self.log("clearing cache", context)
        if context not in self._cache:
            return
        self._cache[context] = {}

    def cacheList(self, context):
        """List contents of cache"""
        cache = self._cache[context] if context in self._cache else {}
        sys.stderr.write("Cache on %s (%s): %s\n" % (self.node, context, cache))

    def storeList(self, context):
        """List contents of store for a particular context"""
        if context not in self._store:
            raise KeyError("No such context: %s" % context)
        sys.stderr.write("Store on %s (%s): %s\n" % (self.node, context, self._store[context]))


class PoolMaster(PoolNode):
    """Master node instance of MPI process pool

    Only the master node should instantiate this.

    WARNING: You should not let a pool instance hang around at program
    termination, as the garbage collection behaves differently, and may
    cause a segmentation fault (signal 11).
    """

    def __init__(self, *args, **kwargs):
        super(PoolMaster, self).__init__(*args, **kwargs)
        assert self.root == self.rank, "This is the master node"

    def __del__(self):
        """Ensure slaves exit when we're done"""
        self.exit()

    def log(self, msg, *args):
        """Log a debugging message"""
        self.debugger.log("Master", msg, *args)

    def command(self, cmd):
        """Send command to slaves

        A command is the name of the PoolSlave method they should run.
        """
        self.log("command", cmd)
        self.comm.broadcast(cmd, root=self.root)

    def map(self, context, func, dataList, *args, **kwargs):
        """!Scatter work to slaves and gather the results

        Work is distributed dynamically, so that slaves that finish
        quickly will receive more work.

        Each slave applies the function to the data they're provided.
        The slaves may optionally be passed a cache instance, which
        they can use to store data for subsequent executions (to ensure
        subsequent data is distributed in the same pattern as before,
        use the 'mapToPrevious' method).  The cache also contains
        data that has been stored on the slaves.

        The 'func' signature should be func(cache, data, *args, **kwargs)
        if 'context' is non-None; otherwise func(data, *args, **kwargs).

        @param context: Namespace for cache
        @param func: function for slaves to run; must be picklable
        @param dataList: List of data to distribute to slaves; must be picklable
        @param args: List of constant arguments
        @param kwargs: Dict of constant arguments
        @return list of results from applying 'func' to dataList
        """
        return self.reduce(context, None, func, dataList, *args, **kwargs)

    @abortOnError
    @catchPicklingError
    def reduce(self, context, reducer, func, dataList, *args, **kwargs):
        """!Scatter work to slaves and reduce the results

        Work is distributed dynamically, so that slaves that finish
        quickly will receive more work.

        Each slave applies the function to the data they're provided.
        The slaves may optionally be passed a cache instance, which
        they can use to store data for subsequent executions (to ensure
        subsequent data is distributed in the same pattern as before,
        use the 'mapToPrevious' method).  The cache also contains
        data that has been stored on the slaves.

        The 'func' signature should be func(cache, data, *args, **kwargs)
        if 'context' is non-None; otherwise func(data, *args, **kwargs).

        The 'reducer' signature should be reducer(old, new).  If the 'reducer'
        is None, then we will return the full list of results.

        @param context: Namespace for cache
        @param reducer: function for master to run to reduce slave results; or None
        @param func: function for slaves to run; must be picklable
        @param dataList: List of data to distribute to slaves; must be picklable
        @param args: List of constant arguments
        @param kwargs: Dict of constant arguments
        @return reduced result (if reducer is non-None) or list of results
            from applying 'func' to dataList
        """
        tags = Tags("request", "work")
        num = len(dataList)
        if self.size == 1 or num <= 1:
            return self._reduceQueue(context, reducer, func, list(zip(range(num), dataList)),
                                     *args, **kwargs)
        if self.size == num:
            # We're shooting ourselves in the foot using dynamic distribution
            return self.reduceNoBalance(context, reducer, func, dataList, *args, **kwargs)

        self.command("reduce")

        # Send function
        self.log("instruct")
        self.comm.broadcast((tags, func, reducer, args, kwargs, context), root=self.root)

        # Parcel out first set of data
        queue = list(zip(range(num), dataList))  # index, data
        output = [None]*num if reducer is None else None
        initial = [None if i == self.rank else queue.pop(0) if queue else NoOp()
                   for i in range(self.size)]
        pending = min(num, self.size - 1)
        self.log("scatter initial jobs")
        self.comm.scatter(initial, root=self.rank)

        while queue or pending > 0:
            status = mpi.Status()
            report = self.comm.recv(status=status, tag=tags.request, source=mpi.ANY_SOURCE)
            source = status.source
            self.log("gather from slave", source)
            if reducer is None:
                index, result = report
                output[index] = result

            if queue:
                job = queue.pop(0)
                self.log("send job to slave", job[0], source)
            else:
                job = NoOp()
                pending -= 1
            self.comm.send(job, source, tag=tags.work)

        if reducer is not None:
            results = self.comm.gather(None, root=self.root)
            output = None
            for rank in range(self.size):
                if rank == self.root:
                    continue
                output = reducer(output, results[rank]) if output is not None else results[rank]

        self.log("done")
        return output

    def mapNoBalance(self, context, func, dataList, *args, **kwargs):
        """!Scatter work to slaves and gather the results

        Work is distributed statically, so there is no load balancing.

        Each slave applies the function to the data they're provided.
        The slaves may optionally be passed a cache instance, which
        they can store data in for subsequent executions (to ensure
        subsequent data is distributed in the same pattern as before,
        use the 'mapToPrevious' method).  The cache also contains
        data that has been stored on the slaves.

        The 'func' signature should be func(cache, data, *args, **kwargs)
        if 'context' is true; otherwise func(data, *args, **kwargs).

        @param context: Namespace for cache
        @param func: function for slaves to run; must be picklable
        @param dataList: List of data to distribute to slaves; must be picklable
        @param args: List of constant arguments
        @param kwargs: Dict of constant arguments
        @return list of results from applying 'func' to dataList
        """
        return self.reduceNoBalance(context, None, func, dataList, *args, **kwargs)

    @abortOnError
    @catchPicklingError
    def reduceNoBalance(self, context, reducer, func, dataList, *args, **kwargs):
        """!Scatter work to slaves and reduce the results

        Work is distributed statically, so there is no load balancing.

        Each slave applies the function to the data they're provided.
        The slaves may optionally be passed a cache instance, which
        they can store data in for subsequent executions (to ensure
        subsequent data is distributed in the same pattern as before,
        use the 'mapToPrevious' method).  The cache also contains
        data that has been stored on the slaves.

        The 'func' signature should be func(cache, data, *args, **kwargs)
        if 'context' is true; otherwise func(data, *args, **kwargs).

        The 'reducer' signature should be reducer(old, new).  If the 'reducer'
        is None, then we will return the full list of results.

        @param context: Namespace for cache
        @param reducer: function for master to run to reduce slave results; or None
        @param func: function for slaves to run; must be picklable
        @param dataList: List of data to distribute to slaves; must be picklable
        @param args: List of constant arguments
        @param kwargs: Dict of constant arguments
        @return reduced result (if reducer is non-None) or list of results
            from applying 'func' to dataList
        """
        tags = Tags("result", "work")
        num = len(dataList)
        if self.size == 1 or num <= 1:
            return self._reduceQueue(context, reducer, func, list(zip(range(num), dataList)),
                                     *args, **kwargs)

        self.command("mapNoBalance")

        # Send function
        self.log("instruct")
        self.comm.broadcast((tags, func, args, kwargs, context), root=self.root)

        # Divide up the jobs
        # Try to give root the least to do, so it also has time to manage
        queue = list(zip(range(num), dataList))  # index, data
        if num < self.size:
            distribution = [[queue[i]] for i in range(num)]
            distribution.insert(self.rank, [])
            for i in range(num, self.size - 1):
                distribution.append([])
        elif num % self.size == 0:
            numEach = num//self.size
            distribution = [queue[i*numEach:(i+1)*numEach] for i in range(self.size)]
        else:
            # Round-robin distribution, offset so the root gets the least work
            distribution = [[] for i in range(self.size)]
            for i, job in enumerate(queue, self.rank + 1):
                distribution[i % self.size].append(job)

        # Distribute jobs
        for source in range(self.size):
            if source == self.rank:
                continue
            self.log("send jobs to", source)
            self.comm.send(distribution[source], source, tag=tags.work)

        # Execute our own jobs
        output = [None]*num if reducer is None else None

        def ingestResults(output, nodeResults, distList):
            if reducer is None:
                for i, result in enumerate(nodeResults):
                    index = distList[i][0]
                    output[index] = result
                return output
            if output is None:
                output = nodeResults.pop(0)
            for result in nodeResults:
                output = reducer(output, result)
            return output

        ourResults = self._processQueue(context, func, distribution[self.rank], *args, **kwargs)
        output = ingestResults(output, ourResults, distribution[self.rank])

        # Collect results
        pending = self.size - 1
        while pending > 0:
            status = mpi.Status()
            slaveResults = self.comm.recv(status=status, tag=tags.result, source=mpi.ANY_SOURCE)
            source = status.source
            self.log("gather from slave", source)
            output = ingestResults(output, slaveResults, distribution[source])
            pending -= 1

        self.log("done")
        return output

    def mapToPrevious(self, context, func, dataList, *args, **kwargs):
        """!Scatter work to the same target as before

        Work is distributed so that each slave handles the same
        indices in the dataList as when 'map' was called.
        This allows the right data to go to the right cache.

        It is assumed that the dataList is the same length as when it was
        passed to 'map'.

        The 'func' signature should be func(cache, data, *args, **kwargs).

        @param context: Namespace for cache
        @param func: function for slaves to run; must be picklable
        @param dataList: List of data to distribute to slaves; must be picklable
        @param args: List of constant arguments
        @param kwargs: Dict of constant arguments
        @return list of results from applying 'func' to dataList
        """
        return self.reduceToPrevious(context, None, func, dataList, *args, **kwargs)

    @abortOnError
    @catchPicklingError
    def reduceToPrevious(self, context, reducer, func, dataList, *args, **kwargs):
        """!Reduction where work goes to the same target as before

        Work is distributed so that each slave handles the same
        indices in the dataList as when 'map' was called.
        This allows the right data to go to the right cache.

        It is assumed that the dataList is the same length as when it was
        passed to 'map'.

        The 'func' signature should be func(cache, data, *args, **kwargs).

        The 'reducer' signature should be reducer(old, new).  If the 'reducer'
        is None, then we will return the full list of results.

        @param context: Namespace for cache
        @param reducer: function for master to run to reduce slave results; or None
        @param func: function for slaves to run; must be picklable
        @param dataList: List of data to distribute to slaves; must be picklable
        @param args: List of constant arguments
        @param kwargs: Dict of constant arguments
        @return reduced result (if reducer is non-None) or list of results
            from applying 'func' to dataList
        """
        if context is None:
            raise ValueError("context must be set to map to same nodes as previous context")
        tags = Tags("result", "work")
        num = len(dataList)
        if self.size == 1 or num <= 1:
            # Can do everything here
            return self._reduceQueue(context, reducer, func, list(zip(range(num), dataList)),
                                     *args, **kwargs)
        if self.size == num:
            # We're shooting ourselves in the foot using dynamic distribution
            return self.reduceNoBalance(context, reducer, func, dataList, *args, **kwargs)

        self.command("mapToPrevious")

        # Send function
        self.log("instruct")
        self.comm.broadcast((tags, func, args, kwargs, context), root=self.root)

        requestList = self.comm.gather(None, root=self.root)
        self.log("listen", requestList)
        initial = [dataList[index] if (index is not None and index >= 0) else None for index in requestList]
        self.log("scatter jobs", initial)
        self.comm.scatter(initial, root=self.root)
        pending = min(num, self.size - 1)

        if reducer is None:
            output = [None]*num
        else:
            thread = ReductionThread(reducer)
            thread.start()

        while pending > 0:
            status = mpi.Status()
            index, result, nextIndex = self.comm.recv(status=status, tag=tags.result, source=mpi.ANY_SOURCE)
            source = status.source
            self.log("gather from slave", source)
            if reducer is None:
                output[index] = result
            else:
                thread.add(result)

            if nextIndex >= 0:
                job = dataList[nextIndex]
                self.log("send job to slave", source)
                self.comm.send(job, source, tag=tags.work)
            else:
                pending -= 1

            self.log("waiting on", pending)

        if reducer is not None:
            output = thread.join()

        self.log("done")
        return output

    @abortOnError
    @catchPicklingError
    def storeSet(self, context, **kwargs):
        """!Store data on slaves for a particular context

        The data is made available to functions through the cache.  The
        stored data differs from the cache in that it is identical for
        all operations, whereas the cache is specific to the data being
        operated upon.

        @param context: namespace for store
        @param kwargs: dict of name=value pairs
        """
        super(PoolMaster, self).storeSet(context, **kwargs)
        self.command("storeSet")
        self.log("give data")
        self.comm.broadcast((context, kwargs), root=self.root)
        self.log("done")

    @abortOnError
    def storeDel(self, context, *nameList):
        """Delete stored data on slaves for a particular context"""
        super(PoolMaster, self).storeDel(context, *nameList)
        self.command("storeDel")
        self.log("tell names")
        self.comm.broadcast((context, nameList), root=self.root)
        self.log("done")

    @abortOnError
    def storeClear(self, context):
        """Reset data store for a particular context on master and slaves"""
        super(PoolMaster, self).storeClear(context)
        self.command("storeClear")
        self.comm.broadcast(context, root=self.root)

    @abortOnError
    def cacheClear(self, context):
        """Reset cache for a particular context on master and slaves"""
        super(PoolMaster, self).cacheClear(context)
        self.command("cacheClear")
        self.comm.broadcast(context, root=self.root)

    @abortOnError
    def cacheList(self, context):
        """List cache contents for a particular context on master and slaves"""
        super(PoolMaster, self).cacheList(context)
        self.command("cacheList")
        self.comm.broadcast(context, root=self.root)

    @abortOnError
    def storeList(self, context):
        """List store contents for a particular context on master and slaves"""
        super(PoolMaster, self).storeList(context)
        self.command("storeList")
        self.comm.broadcast(context, root=self.root)

    def exit(self):
        """Command slaves to exit"""
        self.command("exit")


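# Illustrative sketch (not part of the original module): the shapes of the
# 'func' and 'reducer' arguments to PoolMaster.reduce.  With a non-None
# context, 'func' receives the per-datum Cache as its first argument; the
# reducer must be picklable, so operator.add is used instead of a lambda.
def _exampleSquare(cache, data):
    # 'cache' carries 'comm' plus anything previously set via storeSet()
    return data**2


def _exampleReduceUsage(pool):
    # 'pool' is assumed to be the PoolMaster returned by startPool()
    import operator
    return pool.reduce("default", operator.add, _exampleSquare, list(range(10)))  # 285

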
class PoolSlave(PoolNode):
    """Slave node instance of MPI process pool"""

    def log(self, msg, *args):
        """Log a debugging message"""
        assert self.rank != self.root, "This is not the master node."
        self.debugger.log("Slave %d" % self.rank, msg, *args)

    @abortOnError
    def run(self):
        """Serve commands of master node

        Slave accepts commands, which are the names of methods to execute.
        This exits when a command returns a true value.
        """
        menu = dict((cmd, getattr(self, cmd)) for cmd in ("reduce", "mapNoBalance", "mapToPrevious",
                                                          "storeSet", "storeDel", "storeClear",
                                                          "storeList", "cacheList", "cacheClear",
                                                          "exit",))
        self.log("waiting for command from", self.root)
        command = self.comm.broadcast(None, root=self.root)
        self.log("command", command)
        while not menu[command]():
            self.log("waiting for command from", self.root)
            command = self.comm.broadcast(None, root=self.root)
            self.log("command", command)
        self.log("exiting")

    @catchPicklingError
    def reduce(self):
        """Reduce scattered data and return results"""
        self.log("waiting for instruction")
        tags, func, reducer, args, kwargs, context = self.comm.broadcast(None, root=self.root)
        self.log("waiting for job")
        job = self.comm.scatter(None, root=self.root)

        out = None  # Reduction result
        while not isinstance(job, NoOp):
            index, data = job
            self.log("running job")
            result = self._processQueue(context, func, [(index, data)], *args, **kwargs)[0]
            if reducer is None:
                report = (index, result)
            else:
                report = None
                out = reducer(out, result) if out is not None else result
            self.comm.send(report, self.root, tag=tags.request)
            self.log("waiting for job")
            job = self.comm.recv(tag=tags.work, source=self.root)

        if reducer is not None:
            self.comm.gather(out, root=self.root)
        self.log("done")

    @catchPicklingError
    def mapNoBalance(self):
        """Process bulk scattered data and return results"""
        self.log("waiting for instruction")
        tags, func, args, kwargs, context = self.comm.broadcast(None, root=self.root)
        self.log("waiting for job")
        queue = self.comm.recv(tag=tags.work, source=self.root)

        resultList = []
        for index, data in queue:
            self.log("running job", index)
            result = self._processQueue(context, func, [(index, data)], *args, **kwargs)[0]
            resultList.append(result)

        self.comm.send(resultList, self.root, tag=tags.result)
        self.log("done")

    @catchPicklingError
    def mapToPrevious(self):
        """Process the same scattered data processed previously"""
        self.log("waiting for instruction")
        tags, func, args, kwargs, context = self.comm.broadcast(None, root=self.root)
        queue = list(self._cache[context].keys()) if context in self._cache else None
        index = queue.pop(0) if queue else -1
        self.log("request job", index)
        self.comm.gather(index, root=self.root)
        self.log("waiting for job")
        data = self.comm.scatter(None, root=self.root)

        while index >= 0:
            self.log("running job")
            result = func(self._getCache(context, index), data, *args, **kwargs)
            self.log("pending", queue)
            nextIndex = queue.pop(0) if queue else -1
            self.comm.send((index, result, nextIndex), self.root, tag=tags.result)
            index = nextIndex
            if index >= 0:
                data = self.comm.recv(tag=tags.work, source=self.root)

        self.log("done")

    def storeSet(self):
        """Set value in store"""
        context, kwargs = self.comm.broadcast(None, root=self.root)
        super(PoolSlave, self).storeSet(context, **kwargs)

    def storeDel(self):
        """Delete value in store"""
        context, nameList = self.comm.broadcast(None, root=self.root)
        super(PoolSlave, self).storeDel(context, *nameList)

    def storeClear(self):
        """Reset data store"""
        context = self.comm.broadcast(None, root=self.root)
        super(PoolSlave, self).storeClear(context)

    def cacheClear(self):
        """Reset cache"""
        context = self.comm.broadcast(None, root=self.root)
        super(PoolSlave, self).cacheClear(context)

    def cacheList(self):
        """List cache contents"""
        context = self.comm.broadcast(None, root=self.root)
        super(PoolSlave, self).cacheList(context)

    def storeList(self):
        """List store contents"""
        context = self.comm.broadcast(None, root=self.root)
        super(PoolSlave, self).storeList(context)

    def exit(self):
        """Allow exit from loop in 'run'"""
        return True


class PoolWrapperMeta(type):
    """Metaclass for PoolWrapper to add methods pointing to PoolMaster

    The 'context' is automatically supplied to these methods as the first argument.
    """

    def __call__(cls, context="default"):
        instance = super(PoolWrapperMeta, cls).__call__(context)
        pool = PoolMaster()
        for name in ("map", "mapNoBalance", "mapToPrevious",
                     "reduce", "reduceNoBalance", "reduceToPrevious",
                     "storeSet", "storeDel", "storeClear", "storeList",
                     "cacheList", "cacheClear",):
            setattr(instance, name, partial(getattr(pool, name), context))
        return instance


class PoolWrapper(metaclass=PoolWrapperMeta):
    """Wrap PoolMaster to automatically provide context"""

    def __init__(self, context="default"):
        self._pool = PoolMaster._instance
        self._context = context

    def __getattr__(self, name):
        return getattr(self._pool, name)


class Pool(PoolWrapper):  # Just gives PoolWrapper a nicer name for the user
    """Process Pool

    Use this class to automatically provide 'context' to
    the PoolMaster class.  If you want to call functions
    that don't take a 'cache' object, use the PoolMaster
    class directly, and specify context=None.
    """
    pass


def startPool(comm=None, root=0, killSlaves=True):
    """!Start a process pool.

    Returns a PoolMaster object for the master node.
    Slave nodes are run and then optionally killed.

    If you elect not to kill the slaves, note that they
    will emerge at the point this function was called,
    which is likely very different from the point the
    master is at, so it will likely be necessary to put
    in some rank-dependent code (e.g., look at the 'rank'
    attribute of the returned pools).

    Note that the pool objects should be deleted (either
    by going out of scope or explicit 'del') before program
    termination to avoid a segmentation fault.

    @param comm: MPI communicator
    @param root: Rank of root/master node
    @param killSlaves: Kill slaves on completion?
    """
    if comm is None:
        comm = Comm()
    if comm.rank == root:
        return PoolMaster(comm, root=root)
    slave = PoolSlave(comm, root=root)
    slave.run()
    if killSlaves:
        del slave  # Required to prevent segmentation fault on exit
        sys.exit()
    return slave


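# Illustrative end-to-end sketch (not part of the original module), for a
# script run as e.g. 'mpiexec -n 4 python script.py'.  Slave ranks never
# return from startPool(): they serve commands until told to exit.
def _examplePoolUsage():
    pool = startPool()  # master gets a PoolMaster; slave ranks block here
    squares = Pool("default").map(_exampleSquare, list(range(100)))
    pool.exit()  # one way to release the slaves; they exit inside startPool()
    return squares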