lsst.ctrl.pool  14.0-1-ga2912ff+12
pool.py
1 from future import standard_library
2 standard_library.install_aliases()
3 from builtins import zip
4 from builtins import range
5 from past.builtins import basestring
6 from builtins import object
7 # MPI process pool
8 # Copyright 2013 Paul A. Price
9 #
10 # This program is free software: you can redistribute it and/or modify
11 # it under the terms of the GNU General Public License as published by
12 # the Free Software Foundation, either version 3 of the License, or
13 # (at your option) any later version.
14 #
15 # This program is distributed in the hope that it will be useful,
16 # but WITHOUT ANY WARRANTY; without even the implied warranty of
17 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 # GNU General Public License for more details.
19 #
20 # You should have received a copy of the GNU General Public License
21 # along with this program. If not, see <http://www.gnu.org/copyleft/gpl.html>
22 #
23 
24 import os
25 import sys
26 import time
27 import types
28 import copyreg
29 import threading
30 from functools import wraps, partial
31 from contextlib import contextmanager
32 
33 import mpi4py.MPI as mpi
34 
35 from lsst.pipe.base import Struct
36 from future.utils import with_metaclass
37 
38 __all__ = ["Comm", "Pool", "startPool", "setBatchType", "getBatchType", "abortOnError", "NODE", ]
39 
40 NODE = "%s:%d" % (os.uname()[1], os.getpid()) # Name of node
41 
42 
43 def unpickleInstanceMethod(obj, name):
44  """Unpickle an instance method
45 
46  This has to be a named function rather than a lambda because
47  pickle needs to find it.
48  """
49  return getattr(obj, name)
50 
51 
53  """Pickle an instance method
54 
55  The instance method is divided into the object and the
56  method name.
57  """
58  obj = method.__self__
59  name = method.__name__
60  return unpickleInstanceMethod, (obj, name)
61 
62 copyreg.pickle(types.MethodType, pickleInstanceMethod)
63 
64 
65 def unpickleFunction(moduleName, funcName):
66  """Unpickle a function
67 
68  This has to be a named function rather than a lambda because
69  pickle needs to find it.
70  """
71  import importlib
72  module = importlib.import_module(moduleName)
73  return getattr(module, funcName)
74 
75 
76 def pickleFunction(function):
77  """Pickle a function
78 
79  This assumes that we can recreate the function object by grabbing
80  it from the proper module. This may be violated if the function
81  is a lambda or in __main__. In that case, I recommend recasting
82  the function as an object with a __call__ method.
83 
84  Another problematic case may be a wrapped (e.g., decorated) method
85  in a class: the 'method' is then a function, and recreating it is
86  not as easy as we assume here.
87  """
88  moduleName = function.__module__
89  funcName = function.__name__
90  return unpickleFunction, (moduleName, funcName)
91 
92 copyreg.pickle(types.FunctionType, pickleFunction)
93 
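The pickleFunction docstring above recommends recasting a lambda as an object with a __call__ method, so it pickles by reference to a module-level class. A minimal sketch of that pattern (the 'Scaler' class is hypothetical, not part of this module):

class Scaler(object):
    """Picklable replacement for 'lambda cache, data: factor*data'."""
    def __init__(self, factor):
        self.factor = factor

    def __call__(self, cache, data):
        return self.factor * data

func = Scaler(2.0)  # safe to send to slaves, unlike the equivalent lambda
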
94 try:
95  _batchType
96 except NameError:
97  _batchType = "unknown"
98 
100  """Return a string giving the type of batch system in use"""
101  return _batchType
102 
103 def setBatchType(batchType):
104  """Return a string giving the type of batch system in use"""
105  global _batchType
106  _batchType = batchType
107 
108 def abortOnError(func):
109  """Function decorator to throw an MPI abort on an unhandled exception"""
110  @wraps(func)
111  def wrapper(*args, **kwargs):
112  try:
113  return func(*args, **kwargs)
114  except Exception as e:
115  sys.stderr.write("%s on %s in %s: %s\n" % (type(e).__name__, NODE, func.__name__, e))
116  import traceback
117  traceback.print_exc(file=sys.stderr)
118  if getBatchType() is not None:
119  mpi.COMM_WORLD.Abort(1)
120  else:
121  raise
122  return wrapper
123 
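As a usage illustration (a hedged sketch, not taken from this module), the decorator is simply applied to any function that runs under MPI, so an unhandled exception on one rank brings the whole job down rather than leaving the other ranks deadlocked:

@abortOnError
def processOne(data):
    # Hypothetical worker: an exception raised here calls MPI Abort when
    # getBatchType() is not None, and re-raises otherwise.
    return data ** 2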
124 
125 class PickleHolder(object):
126  """Singleton to hold what's about to be pickled.
127 
128  We hold onto the object in case there's trouble pickling,
129  so we can figure out what class in particular is causing
130  the trouble.
131 
132  The held object is in the 'obj' attribute.
133 
134  Here we use the __new__-style singleton pattern, because
135  we specifically want __init__ to be called each time.
136  """
137 
138  _instance = None
139 
140  def __new__(cls, hold=None):
141  if cls._instance is None:
142  cls._instance = super(PickleHolder, cls).__new__(cls)
143  cls._instance.__init__(hold)
144  cls._instance.obj = None
145  return cls._instance
146 
147  def __init__(self, hold=None):
148  """Hold onto new object"""
149  if hold is not None:
150  self.obj = hold
151 
152  def __enter__(self):
153  pass
154 
155  def __exit__(self, excType, excVal, tb):
156  """Drop held object if there were no problems"""
157  if excType is None:
158  self.obj = None
159 
160 
162  """Try to guess what's not pickling after an exception
163 
164  This tends to work if the problem is coming from the
165  regular pickle module. If it's coming from the bowels
166  of mpi4py, there's not much that can be done.
167  """
168  import sys
169  excType, excValue, tb = sys.exc_info()
170  # Build a stack of traceback elements
171  stack = []
172  while tb:
173  stack.append(tb)
174  tb = tb.tb_next
175 
176  try:
177  # This is the code version of my way to find what's not pickling in pdb.
178  # This should work if it's coming from the regular pickle module, and they
179  # haven't changed the variable names since python 2.7.3.
180  return stack[-2].tb_frame.f_locals["obj"]
181  except:
182  return None
183 
184 
185 @contextmanager
186 def pickleSniffer(abort=False):
187  """Context manager to sniff out pickle problems
188 
189  If there's a pickle error, you're normally told what the problem
190  class is. However, all SWIG objects are reported as "SwigPyObject".
191  In order to figure out which actual SWIG-ed class is causing
192  problems, we need to go digging.
193 
194  Use like this:
195 
196  with pickleSniffer():
197  someOperationInvolvingPickle()
198 
199  If 'abort' is True, will call MPI abort in the event of problems.
200  """
201  try:
202  yield
203  except Exception as e:
204  if "SwigPyObject" not in str(e) or "pickle" not in str(e):
205  raise
206  import sys
207  import traceback
208 
209  sys.stderr.write("Pickling error detected: %s\n" % e)
210  traceback.print_exc(file=sys.stderr)
211  obj = guessPickleObj()
212  heldObj = PickleHolder().obj
213  if obj is None and heldObj is not None:
214  # Try to reproduce the problem with the regular pickle module on what was being pickled,
215  # so we've got a chance of figuring out what the problem is.
216  import pickle
217  try:
218  pickle.dumps(heldObj)
219  sys.stderr.write("Hmmm, that's strange: no problem with pickling held object?!?!\n")
220  except Exception:
221  obj = guessPickleObj()
222  if obj is None:
223  sys.stderr.write("Unable to determine class causing pickle problems.\n")
224  else:
225  sys.stderr.write("Object that could not be pickled: %s\n" % obj)
226  if abort:
227  if getBatchType() is not None:
228  mpi.COMM_WORLD.Abort(1)
229  else:
230  sys.exit(1)
231 
233  """Function decorator to catch errors in pickling and print something useful"""
234  @wraps(func)
235  def wrapper(*args, **kwargs):
236  with pickleSniffer(True):
237  return func(*args, **kwargs)
238  return wrapper
239 
240 
241 class Comm(mpi.Intracomm):
242  """Wrapper to mpi4py's MPI.Intracomm class to avoid busy-waiting.
243 
244  As suggested by Lisandro Dalcin at:
245  * http://code.google.com/p/mpi4py/issues/detail?id=4 and
246  * https://groups.google.com/forum/?fromgroups=#!topic/mpi4py/nArVuMXyyZI
247  """
248 
249  def __new__(cls, comm=mpi.COMM_WORLD, recvSleep=0.1, barrierSleep=0.1):
250  """!Construct an MPI.Comm wrapper
251 
252  @param cls Class
253  @param comm MPI.Intracomm to wrap a duplicate of
254  @param recvSleep Sleep time (seconds) for recv()
255  @param barrierSleep Sleep time (seconds) for Barrier()
256  """
257  self = super(Comm, cls).__new__(cls, comm.Dup())
258  self._barrierComm = None # Duplicate communicator used for Barrier point-to-point checking
259  self._recvSleep = recvSleep
260  self._barrierSleep = barrierSleep
261  return self
262 
263  def recv(self, obj=None, source=0, tag=0, status=None):
264  """Version of comm.recv() that doesn't busy-wait"""
265  sts = mpi.Status()
266  while not self.Iprobe(source=source, tag=tag, status=sts):
267  time.sleep(self._recvSleep)
268  return super(Comm, self).recv(buf=obj, source=sts.source, tag=sts.tag, status=status)
269 
270  def send(self, obj=None, *args, **kwargs):
271  with PickleHolder(obj):
272  return super(Comm, self).send(obj, *args, **kwargs)
273 
274  def _checkBarrierComm(self):
275  """Ensure the duplicate communicator is available"""
276  if self._barrierComm is None:
277  self._barrierComm = self.Dup()
278 
279  def Barrier(self, tag=0):
280  """Version of comm.Barrier() that doesn't busy-wait
281 
282  A duplicate communicator is used so as not to interfere with the user's own communications.
283  """
284  self._checkBarrierComm()
285  size = self._barrierComm.Get_size()
286  if size == 1:
287  return
288  rank = self._barrierComm.Get_rank()
289  mask = 1
290  while mask < size:
291  dst = (rank + mask) % size
292  src = (rank - mask + size) % size
293  req = self._barrierComm.isend(None, dst, tag)
294  while not self._barrierComm.Iprobe(src, tag):
295  time.sleep(self._barrierSleep)
296  self._barrierComm.recv(None, src, tag)
297  req.Wait()
298  mask <<= 1
299 
300  def broadcast(self, value, root=0):
301  with PickleHolder(value):
302  return super(Comm, self).bcast(value, root=root)
303 
304  def scatter(self, dataList, root=0, tag=0):
305  """Scatter data across the nodes
306 
307  The default version apparently pickles the entire 'dataList',
308  which can cause errors if the pickle size grows over 2^31 bytes
309  due to fundamental problems with pickle in python 2. Instead,
310  we send the data to each slave node in turn; this reduces the
311  pickle size.
312 
313  @param dataList List of data to distribute; one per node
314  (including root)
315  @param root Index of root node
316  @param tag Message tag (integer)
317  @return Data for this node
318  """
319  if self.Get_rank() == root:
320  for rank, data in enumerate(dataList):
321  if rank == root:
322  continue
323  self.send(data, rank, tag=tag)
324  return dataList[root]
325  else:
326  return self.recv(source=root, tag=tag)
327 
328  def Free(self):
329  if self._barrierComm is not None:
330  self._barrierComm.Free()
331  super(Comm, self).Free()
332 
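A usage sketch for the wrapper (hedged; it assumes the script is launched under mpiexec with at least two ranks):

comm = Comm()                      # duplicates MPI.COMM_WORLD
if comm.rank == 0:
    comm.send("hello", 1, tag=1)   # lowercase send/recv pickle arbitrary objects
elif comm.rank == 1:
    msg = comm.recv(source=0, tag=1)
comm.Barrier()                     # point-to-point barrier on a private duplicate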
333 
334 class NoOp(object):
335  """Object to signal no operation"""
336  pass
337 
338 
339 class Tags(object):
340  """Provides tag numbers by symbolic name in attributes"""
341 
342  def __init__(self, *nameList):
343  self._nameList = nameList
344  for i, name in enumerate(nameList, 1):
345  setattr(self, name, i)
346 
347  def __repr__(self):
348  return self.__class__.__name__ + repr(self._nameList)
349 
350  def __reduce__(self):
351  return self.__class__, tuple(self._nameList)
352 
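For illustration (hedged, not part of the module): Tags simply maps each symbolic name to a small positive integer suitable for use as an MPI message tag, and it pickles cleanly via __reduce__ so it can be broadcast to slaves:

tags = Tags("request", "work")
assert tags.request == 1
assert tags.work == 2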
353 
354 class Cache(Struct):
355  """An object to hold stuff between different scatter calls
356 
357  Includes a communicator by default, to allow intercommunication
358  between nodes.
359  """
360 
361  def __init__(self, comm):
362  super(Cache, self).__init__(comm=comm)
363 
364 
365 class SingletonMeta(type):
366  """!Metaclass to produce a singleton
367 
368  Doing a singleton mixin without a metaclass (via __new__) is
369  annoying because the user has to name his __init__ something else
370  (otherwise it's called every time, which undoes any changes).
371  Using this metaclass, the class's __init__ is called exactly once.
372 
373  Because this is a metaclass, note that:
374  * "self" here is the class
375  * "__init__" is making the class (it's like the body of the
376  class definition).
377  * "__call__" is making an instance of the class (it's like
378  "__new__" in the class).
379  """
380 
381  def __init__(self, name, bases, dict_):
382  super(SingletonMeta, self).__init__(name, bases, dict_)
383  self._instance = None
384 
385  def __call__(self, *args, **kwargs):
386  if self._instance is None:
387  self._instance = super(SingletonMeta, self).__call__(*args, **kwargs)
388  return self._instance
389 
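A minimal sketch of what the metaclass provides (the 'Registry' class is hypothetical, not in this module): every call returns the same instance, and __init__ runs exactly once:

class Registry(with_metaclass(SingletonMeta, object)):
    def __init__(self):
        self.items = []   # runs only on the first Registry() call

assert Registry() is Registry()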
390 
391 class Debugger(with_metaclass(SingletonMeta, object)):
392  """Debug logger singleton
393 
394  Disabled by default; to enable, do: 'Debugger().enabled = True'
395  You can also redirect the output by changing the 'out' attribute.
396  """
397 
398  def __init__(self):
399  self.enabled = False
400  self.out = sys.stderr
401 
402  def log(self, source, msg, *args):
403  """!Log message
404 
405  The 'args' are only stringified if we're enabled.
406 
407  @param source: name of source
408  @param msg: message to write
409  @param args: additional outputs to append to message
410  """
411  if self.enabled:
412  self.out.write("%s: %s" % (source, msg))
413  for arg in args:
414  self.out.write(" %s" % arg)
415  self.out.write("\n")
416 
417 
418 class ReductionThread(threading.Thread):
419  """Thread to do reduction of results
420 
421  "A thread?", you say. "What about the python GIL?"
422  Well, because we 'sleep' when there's no immediate response from the
423  slaves, that gives the thread a chance to fire; and threads are easier
424  to manage (e.g., shared memory) than a process.
425  """
426  def __init__(self, reducer, initial=None, sleep=0.1):
427  """!Constructor
428 
429  The 'reducer' should take two values and return a single
430  (reduced) value.
431 
432  @param reducer Function that does the reducing
433  @param initial Initial value for reduction, or None
434  @param sleep Time to sleep when there's nothing to do (sec)
435  """
436  threading.Thread.__init__(self, name="reducer")
437  self._queue = [] # Queue of stuff to be reduced
438  self._lock = threading.Lock() # Lock for the queue
439  self._reducer = reducer
440  self._sleep = sleep
441  self._result = initial # Final result
442  self._done = threading.Event() # Signal that everything is done
443 
444  def _doReduce(self):
445  """Do the actual work
446 
447  We pull the data out of the queue and release the lock before
448  operating on it. This stops us from blocking the addition of
449  new data to the queue.
450  """
451  with self._lock:
452  queue = self._queue
453  self._queue = []
454  for data in queue:
455  self._result = self._reducer(self._result, data) if self._result is not None else data
456 
457  def run(self):
458  """Do the work
459 
460  Thread entry point, called by Thread.start
461  """
462  while True:
463  self._doReduce()
464  if self._done.wait(self._sleep):
465  self._doReduce()
466  return
467 
468  def add(self, data):
469  """Add data to the queue to be reduced"""
470  with self._lock:
471  self._queue.append(data)
472 
473  def join(self):
474  """Complete the thread
475 
476  Unlike Thread.join (which always returns 'None'), we return the result
477  we calculated.
478  """
479  self._done.set()
480  threading.Thread.join(self)
481  return self._result
482 
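A hedged usage sketch (not taken from the module): the master starts the thread, feeds it partial results as they arrive, and collects the final reduction from join():

import operator

thread = ReductionThread(operator.add, initial=0)
thread.start()
for value in (1, 2, 3, 4):     # stand-in for results received from slaves
    thread.add(value)
total = thread.join()          # returns 10 and stops the thread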
483 
484 class PoolNode(with_metaclass(SingletonMeta, object)):
485  """Node in MPI process pool
486 
487  WARNING: You should not let a pool instance hang around at program
488  termination, as the garbage collection behaves differently, and may
489  cause a segmentation fault (signal 11).
490  """
491 
492  def __init__(self, comm=None, root=0):
493  if comm is None:
494  comm = Comm()
495  self.comm = comm
496  self.rank = self.comm.rank
497  self.root = root
498  self.size = self.comm.size
499  self._cache = {}
500  self._store = {}
501  self.debugger = Debugger()
502  self.node = NODE
503 
504  def _getCache(self, context, index):
505  """Retrieve cache for particular data
506 
507  The cache is updated with the contents of the store.
508  """
509  if not context in self._cache:
510  self._cache[context] = {}
511  if not context in self._store:
512  self._store[context] = {}
513  cache = self._cache[context]
514  store = self._store[context]
515  if index not in cache:
516  cache[index] = Cache(self.comm)
517  cache[index].__dict__.update(store)
518  return cache[index]
519 
520  def log(self, msg, *args):
521  """Log a debugging message"""
522  self.debugger.log("Node %d" % self.rank, msg, *args)
523 
524  def isMaster(self):
525  return self.rank == self.root
526 
527  def _processQueue(self, context, func, queue, *args, **kwargs):
528  """!Process a queue of data
529 
530  The queue consists of a list of (index, data) tuples,
531  where the index maps to the cache, and the data is
532  passed to the 'func'.
533 
534  The 'func' signature should be func(cache, data, *args, **kwargs)
535  if 'context' is non-None; otherwise func(data, *args, **kwargs).
536 
537  @param context: Namespace for cache; None to not use cache
538  @param func: function for slaves to run
539  @param queue: List of (index,data) tuples to process
540  @param args: Constant arguments
541  @param kwargs: Keyword arguments
542  @return list of results from applying 'func' to dataList
543  """
544  return self._reduceQueue(context, None, func, queue, *args, **kwargs)
545 
546  def _reduceQueue(self, context, reducer, func, queue, *args, **kwargs):
547  """!Reduce a queue of data
548 
549  The queue consists of a list of (index, data) tuples,
550  where the index maps to the cache, and the data is
551  passed to the 'func', the output of which is reduced
552  using the 'reducer' (if non-None).
553 
554  The 'func' signature should be func(cache, data, *args, **kwargs)
555  if 'context' is non-None; otherwise func(data, *args, **kwargs).
556 
557  The 'reducer' signature should be reducer(old, new). If the 'reducer'
558  is None, then we will return the full list of results
559 
560  @param context: Namespace for cache; None to not use cache
561  @param reducer: function for master to run to reduce slave results; or None
562  @param func: function for slaves to run
563  @param queue: List of (index,data) tuples to process
564  @param args: Constant arguments
565  @param kwargs: Keyword arguments
566  @return reduced result (if reducer is non-None) or list of results
567  from applying 'func' to dataList
568  """
569  if context is not None:
570  resultList = [func(self._getCache(context, i), data, *args, **kwargs) for i, data in queue]
571  else:
572  resultList = [func(data, *args, **kwargs) for i, data in queue]
573  if reducer is None:
574  return resultList
575  if len(resultList) == 0:
576  return None
577  output = resultList.pop(0)
578  for result in resultList:
579  output = reducer(output, result)
580  return output
581 
582  def storeSet(self, context, **kwargs):
583  """Set values in store for a particular context"""
584  self.log("storing", context, kwargs)
585  if not context in self._store:
586  self._store[context] = {}
587  for name, value in kwargs.items():
588  self._store[context][name] = value
589 
590  def storeDel(self, context, *nameList):
591  """Delete value in store for a particular context"""
592  self.log("deleting from store", context, nameList)
593  if not context in self._store:
594  raise KeyError("No such context: %s" % context)
595  for name in nameList:
596  del self._store[context][name]
597 
598  def storeClear(self, context):
599  """Clear stored data for a particular context"""
600  self.log("clearing store", context)
601  if not context in self._store:
602  raise KeyError("No such context: %s" % context)
603  self._store[context] = {}
604 
605  def cacheClear(self, context):
606  """Reset cache for a particular context"""
607  self.log("clearing cache", context)
608  if not context in self._cache:
609  return
610  self._cache[context] = {}
611 
612  def cacheList(self, context):
613  """List contents of cache"""
614  cache = self._cache[context] if context in self._cache else {}
615  sys.stderr.write("Cache on %s (%s): %s\n" % (self.node, context, cache))
616 
617  def storeList(self, context):
618  """List contents of store for a particular context"""
619  if not context in self._store:
620  raise KeyError("No such context: %s" % context)
621  sys.stderr.write("Store on %s (%s): %s\n" % (self.node, context, self._store[context]))
622 
623 
625  """Master node instance of MPI process pool
626 
627  Only the master node should instantiate this.
628 
629  WARNING: You should not let a pool instance hang around at program
630  termination, as the garbage collection behaves differently, and may
631  cause a segmentation fault (signal 11).
632  """
633 
634  def __init__(self, *args, **kwargs):
635  super(PoolMaster, self).__init__(*args, **kwargs)
636  assert self.root == self.rank, "This is the master node"
637 
638  def __del__(self):
639  """Ensure slaves exit when we're done"""
640  self.exit()
641 
642  def log(self, msg, *args):
643  """Log a debugging message"""
644  self.debugger.log("Master", msg, *args)
645 
646  def command(self, cmd):
647  """Send command to slaves
648 
649  A command is the name of the PoolSlave method they should run.
650  """
651  self.log("command", cmd)
652  self.comm.broadcast(cmd, root=self.root)
653 
654  def map(self, context, func, dataList, *args, **kwargs):
655  """!Scatter work to slaves and gather the results
656 
657  Work is distributed dynamically, so that slaves that finish
658  quickly will receive more work.
659 
660  Each slave applies the function to the data they're provided.
661  The slaves may optionally be passed a cache instance, which
662  they can use to store data for subsequent executions (to ensure
663  subsequent data is distributed in the same pattern as before,
664  use the 'mapToPrevious' method). The cache also contains
665  data that has been stored on the slaves.
666 
667  The 'func' signature should be func(cache, data, *args, **kwargs)
668  if 'context' is non-None; otherwise func(data, *args, **kwargs).
669 
670  @param context: Namespace for cache
671  @param func: function for slaves to run; must be picklable
672  @param dataList: List of data to distribute to slaves; must be picklable
673  @param args: List of constant arguments
674  @param kwargs: Dict of constant arguments
675  @return list of results from applying 'func' to dataList
676  """
677  return self.reduce(context, None, func, dataList, *args, **kwargs)
678 
679  @abortOnError
680  @catchPicklingError
681  def reduce(self, context, reducer, func, dataList, *args, **kwargs):
682  """!Scatter work to slaves and reduce the results
683 
684  Work is distributed dynamically, so that slaves that finish
685  quickly will receive more work.
686 
687  Each slave applies the function to the data they're provided.
688  The slaves may optionally be passed a cache instance, which
689  they can use to store data for subsequent executions (to ensure
690  subsequent data is distributed in the same pattern as before,
691  use the 'mapToPrevious' method). The cache also contains
692  data that has been stored on the slaves.
693 
694  The 'func' signature should be func(cache, data, *args, **kwargs)
695  if 'context' is non-None; otherwise func(data, *args, **kwargs).
696 
697  The 'reducer' signature should be reducer(old, new). If the 'reducer'
698  is None, then we will return the full list of results
699 
700  @param context: Namespace for cache
701  @param reducer: function for master to run to reduce slave results; or None
702  @param func: function for slaves to run; must be picklable
703  @param dataList: List of data to distribute to slaves; must be picklable
704  @param args: List of constant arguments
705  @param kwargs: Dict of constant arguments
706  @return reduced result (if reducer is non-None) or list of results
707  from applying 'func' to dataList
708  """
709  tags = Tags("request", "work")
710  num = len(dataList)
711  if self.size == 1 or num <= 1:
712  return self._reduceQueue(context, reducer, func, list(zip(list(range(num)), dataList)),
713  *args, **kwargs)
714  if self.size == num:
715  # We're shooting ourselves in the foot using dynamic distribution
716  return self.reduceNoBalance(context, reducer, func, dataList, *args, **kwargs)
717 
718  self.command("reduce")
719 
720  # Send function
721  self.log("instruct")
722  self.comm.broadcast((tags, func, reducer, args, kwargs, context), root=self.root)
723 
724  # Parcel out first set of data
725  queue = list(zip(range(num), dataList)) # index, data
726  output = [None]*num if reducer is None else None
727  initial = [None if i == self.rank else queue.pop(0) if queue else NoOp() for
728  i in range(self.size)]
729  pending = min(num, self.size - 1)
730  self.log("scatter initial jobs")
731  self.comm.scatter(initial, root=self.rank)
732 
733  while queue or pending > 0:
734  status = mpi.Status()
735  report = self.comm.recv(status=status, tag=tags.request, source=mpi.ANY_SOURCE)
736  source = status.source
737  self.log("gather from slave", source)
738  if reducer is None:
739  index, result = report
740  output[index] = result
741 
742  if queue:
743  job = queue.pop(0)
744  self.log("send job to slave", job[0], source)
745  else:
746  job = NoOp()
747  pending -= 1
748  self.comm.send(job, source, tag=tags.work)
749 
750  if reducer is not None:
751  results = self.comm.gather(None, root=self.root)
752  output = None
753  for rank in range(self.size):
754  if rank == self.root:
755  continue
756  output = reducer(output, results[rank]) if output is not None else results[rank]
757 
758  self.log("done")
759  return output
760 
761  def mapNoBalance(self, context, func, dataList, *args, **kwargs):
762  """!Scatter work to slaves and gather the results
763 
764  Work is distributed statically, so there is no load balancing.
765 
766  Each slave applies the function to the data they're provided.
767  The slaves may optionally be passed a cache instance, which
768  they can store data in for subsequent executions (to ensure
769  subsequent data is distributed in the same pattern as before,
770  use the 'mapToPrevious' method). The cache also contains
771  data that has been stored on the slaves.
772 
773  The 'func' signature should be func(cache, data, *args, **kwargs)
774  if 'context' is true; otherwise func(data, *args, **kwargs).
775 
776  @param context: Namespace for cache
777  @param func: function for slaves to run; must be picklable
778  @param dataList: List of data to distribute to slaves; must be picklable
779  @param args: List of constant arguments
780  @param kwargs: Dict of constant arguments
781  @return list of results from applying 'func' to dataList
782  """
783  return self.reduceNoBalance(context, None, func, dataList, *args, **kwargs)
784 
785  @abortOnError
786  @catchPicklingError
787  def reduceNoBalance(self, context, reducer, func, dataList, *args, **kwargs):
788  """!Scatter work to slaves and reduce the results
789 
790  Work is distributed statically, so there is no load balancing.
791 
792  Each slave applies the function to the data they're provided.
793  The slaves may optionally be passed a cache instance, which
794  they can store data in for subsequent executions (to ensure
795  subsequent data is distributed in the same pattern as before,
796  use the 'mapToPrevious' method). The cache also contains
797  data that has been stored on the slaves.
798 
799  The 'func' signature should be func(cache, data, *args, **kwargs)
800  if 'context' is true; otherwise func(data, *args, **kwargs).
801 
802  The 'reducer' signature should be reducer(old, new). If the 'reducer'
803  is None, then we will return the full list of results
804 
805  @param context: Namespace for cache
806  @param reducer: function for master to run to reduce slave results; or None
807  @param func: function for slaves to run; must be picklable
808  @param dataList: List of data to distribute to slaves; must be picklable
809  @param args: List of constant arguments
810  @param kwargs: Dict of constant arguments
811  @return reduced result (if reducer is non-None) or list of results
812  from applying 'func' to dataList
813  """
814  tags = Tags("result", "work")
815  num = len(dataList)
816  if self.size == 1 or num <= 1:
817  return self._reduceQueue(context, reducer, func, list(zip(range(num), dataList)), *args, **kwargs)
818 
819  self.command("mapNoBalance")
820 
821  # Send function
822  self.log("instruct")
823  self.comm.broadcast((tags, func, args, kwargs, context), root=self.root)
824 
825  # Divide up the jobs
826  # Try to give root the least to do, so it also has time to manage
827  queue = list(zip(range(num), dataList)) # index, data
828  if num < self.size:
829  distribution = [[queue[i]] for i in range(num)]
830  distribution.insert(self.rank, [])
831  for i in range(num, self.size - 1):
832  distribution.append([])
833  elif num % self.size == 0:
834  numEach = num//self.size
835  distribution = [queue[i*numEach:(i+1)*numEach] for i in range(self.size)]
836  else:
837  numEach = num//self.size
838  distribution = [queue[i*numEach:(i+1)*numEach] for i in range(self.size)]
839  for i in range(numEach*self.size, num):
840  distribution[(self.rank + 1) % self.size].append
841  distribution = list([] for i in range(self.size))
842  for i, job in enumerate(queue, self.rank + 1):
843  distribution[i % self.size].append(job)
844 
845  # Distribute jobs
846  for source in range(self.size):
847  if source == self.rank:
848  continue
849  self.log("send jobs to ", source)
850  self.comm.send(distribution[source], source, tag=tags.work)
851 
852  # Execute our own jobs
853  output = [None]*num if reducer is None else None
854 
855  def ingestResults(output, nodeResults, distList):
856  if reducer is None:
857  for i, result in enumerate(nodeResults):
858  index = distList[i][0]
859  output[index] = result
860  return output
861  if output is None:
862  output = nodeResults.pop(0)
863  for result in nodeResults:
864  output = reducer(output, result)
865  return output
866 
867  ourResults = self._processQueue(context, func, distribution[self.rank], *args, **kwargs)
868  output = ingestResults(output, ourResults, distribution[self.rank])
869 
870  # Collect results
871  pending = self.size - 1
872  while pending > 0:
873  status = mpi.Status()
874  slaveResults = self.comm.recv(status=status, tag=tags.result, source=mpi.ANY_SOURCE)
875  source = status.source
876  self.log("gather from slave", source)
877  output = ingestResults(output, slaveResults, distribution[source])
878  pending -= 1
879 
880  self.log("done")
881  return output
882 
883  def mapToPrevious(self, context, func, dataList, *args, **kwargs):
884  """!Scatter work to the same target as before
885 
886  Work is distributed so that each slave handles the same
887  indices in the dataList as when 'map' was called.
888  This allows the right data to go to the right cache.
889 
890  It is assumed that the dataList is the same length as when it was
891  passed to 'map'.
892 
893  The 'func' signature should be func(cache, data, *args, **kwargs).
894 
895  @param context: Namespace for cache
896  @param func: function for slaves to run; must be picklable
897  @param dataList: List of data to distribute to slaves; must be picklable
898  @param args: List of constant arguments
899  @param kwargs: Dict of constant arguments
900  @return list of results from applying 'func' to dataList
901  """
902  return self.reduceToPrevious(context, None, func, dataList, *args, **kwargs)
903 
904  @abortOnError
905  @catchPicklingError
906  def reduceToPrevious(self, context, reducer, func, dataList, *args, **kwargs):
907  """!Reduction where work goes to the same target as before
908 
909  Work is distributed so that each slave handles the same
910  indices in the dataList as when 'map' was called.
911  This allows the right data to go to the right cache.
912 
913  It is assumed that the dataList is the same length as when it was
914  passed to 'map'.
915 
916  The 'func' signature should be func(cache, data, *args, **kwargs).
917 
918  The 'reducer' signature should be reducer(old, new). If the 'reducer'
919  is None, then we will return the full list of results
920 
921  @param context: Namespace for cache
922  @param reducer: function for master to run to reduce slave results; or None
923  @param func: function for slaves to run; must be picklable
924  @param dataList: List of data to distribute to slaves; must be picklable
925  @param args: List of constant arguments
926  @param kwargs: Dict of constant arguments
927  @return reduced result (if reducer is non-None) or list of results
928  from applying 'func' to dataList
929  """
930  if context is None:
931  raise ValueError("context must be set to map to same nodes as previous context")
932  tags = Tags("result", "work")
933  num = len(dataList)
934  if self.size == 1 or num <= 1:
935  # Can do everything here
936  return self._reduceQueue(context, reducer, func, list(zip(range(num), dataList)), *args, **kwargs)
937  if self.size == num:
938  # We're shooting ourselves in the foot using dynamic distribution
939  return self.reduceNoBalance(context, reducer, func, dataList, *args, **kwargs)
940 
941  self.command("mapToPrevious")
942 
943  # Send function
944  self.log("instruct")
945  self.comm.broadcast((tags, func, args, kwargs, context), root=self.root)
946 
947  requestList = self.comm.gather(None, root=self.root)
948  self.log("listen", requestList)
949  initial = [dataList[index] if (index is not None and index >= 0) else None for index in requestList]
950  self.log("scatter jobs", initial)
951  self.comm.scatter(initial, root=self.root)
952  pending = min(num, self.size - 1)
953 
954  if reducer is None:
955  output = [None]*num
956  else:
957  thread = ReductionThread(reducer)
958  thread.start()
959 
960  while pending > 0:
961  status = mpi.Status()
962  index, result, nextIndex = self.comm.recv(status=status, tag=tags.result, source=mpi.ANY_SOURCE)
963  source = status.source
964  self.log("gather from slave", source)
965  if reducer is None:
966  output[index] = result
967  else:
968  thread.add(result)
969 
970  if nextIndex >= 0:
971  job = dataList[nextIndex]
972  self.log("send job to slave", source)
973  self.comm.send(job, source, tag=tags.work)
974  else:
975  pending -= 1
976 
977  self.log("waiting on", pending)
978 
979  if reducer is not None:
980  output = thread.join()
981 
982  self.log("done")
983  return output
984 
985  @abortOnError
986  @catchPicklingError
987  def storeSet(self, context, **kwargs):
988  """!Store data on slave for a particular context
989 
990  The data is made available to functions through the cache. The
991  stored data differs from the cache in that it is identical for
992  all operations, whereas the cache is specific to the data being
993  operated upon.
994 
995  @param context: namespace for store
996  @param kwargs: dict of name=value pairs
997  """
998  super(PoolMaster, self).storeSet(context, **kwargs)
999  self.command("storeSet")
1000  self.log("give data")
1001  self.comm.broadcast((context, kwargs), root=self.root)
1002  self.log("done")
1003 
1004  @abortOnError
1005  def storeDel(self, context, *nameList):
1006  """Delete stored data on slave for a particular context"""
1007  super(PoolMaster, self).storeDel(context, *nameList)
1008  self.command("storeDel")
1009  self.log("tell names")
1010  self.comm.broadcast((context, nameList), root=self.root)
1011  self.log("done")
1012 
1013  @abortOnError
1014  def storeClear(self, context):
1015  """Reset data store for a particular context on master and slaves"""
1016  super(PoolMaster, self).storeClear(context)
1017  self.command("storeClear")
1018  self.comm.broadcast(context, root=self.root)
1019 
1020  @abortOnError
1021  def cacheClear(self, context):
1022  """Reset cache for a particular context on master and slaves"""
1023  super(PoolMaster, self).cacheClear(context)
1024  self.command("cacheClear")
1025  self.comm.broadcast(context, root=self.root)
1026 
1027  @abortOnError
1028  def cacheList(self, context):
1029  """List cache contents for a particular context on master and slaves"""
1030  super(PoolMaster, self).cacheList(context)
1031  self.command("cacheList")
1032  self.comm.broadcast(context, root=self.root)
1033 
1034  @abortOnError
1035  def storeList(self, context):
1036  """List store contents for a particular context on master and slaves"""
1037  super(PoolMaster, self).storeList(context)
1038  self.command("storeList")
1039  self.comm.broadcast(context, root=self.root)
1040 
1041  def exit(self):
1042  """Command slaves to exit"""
1043  self.command("exit")
1044 
1045 
1047  """Slave node instance of MPI process pool"""
1048 
1049  def log(self, msg, *args):
1050  """Log a debugging message"""
1051  assert self.rank != self.root, "This is not the master node."
1052  self.debugger.log("Slave %d" % self.rank, msg, *args)
1053 
1054  @abortOnError
1055  def run(self):
1056  """Serve commands of master node
1057 
1058  Slave accepts commands, which are the names of methods to execute.
1059  This exits when a command returns a true value.
1060  """
1061  menu = dict((cmd, getattr(self, cmd)) for cmd in ("reduce", "mapNoBalance", "mapToPrevious",
1062  "storeSet", "storeDel", "storeClear", "storeList",
1063  "cacheList", "cacheClear", "exit",))
1064  self.log("waiting for command from", self.root)
1065  command = self.comm.broadcast(None, root=self.root)
1066  self.log("command", command)
1067  while not menu[command]():
1068  self.log("waiting for command from", self.root)
1069  command = self.comm.broadcast(None, root=self.root)
1070  self.log("command", command)
1071  self.log("exiting")
1072 
1073  @catchPicklingError
1074  def reduce(self):
1075  """Reduce scattered data and return results"""
1076  self.log("waiting for instruction")
1077  tags, func, reducer, args, kwargs, context = self.comm.broadcast(None, root=self.root)
1078  self.log("waiting for job")
1079  job = self.comm.scatter(None, root=self.root)
1080 
1081  out = None # Reduction result
1082  while not isinstance(job, NoOp):
1083  index, data = job
1084  self.log("running job")
1085  result = self._processQueue(context, func, [(index, data)], *args, **kwargs)[0]
1086  if reducer is None:
1087  report = (index, result)
1088  else:
1089  report = None
1090  out = reducer(out, result) if out is not None else result
1091  self.comm.send(report, self.root, tag=tags.request)
1092  self.log("waiting for job")
1093  job = self.comm.recv(tag=tags.work, source=self.root)
1094 
1095  if reducer is not None:
1096  self.comm.gather(out, root=self.root)
1097  self.log("done")
1098 
1099  @catchPicklingError
1100  def mapNoBalance(self):
1101  """Process bulk scattered data and return results"""
1102  self.log("waiting for instruction")
1103  tags, func, args, kwargs, context = self.comm.broadcast(None, root=self.root)
1104  self.log("waiting for job")
1105  queue = self.comm.recv(tag=tags.work, source=self.root)
1106 
1107  resultList = []
1108  for index, data in queue:
1109  self.log("running job", index)
1110  result = self._processQueue(context, func, [(index, data)], *args, **kwargs)[0]
1111  resultList.append(result)
1112 
1113  self.comm.send(resultList, self.root, tag=tags.result)
1114  self.log("done")
1115 
1116  @catchPicklingError
1117  def mapToPrevious(self):
1118  """Process the same scattered data processed previously"""
1119  self.log("waiting for instruction")
1120  tags, func, args, kwargs, context = self.comm.broadcast(None, root=self.root)
1121  queue = list(self._cache[context].keys()) if context in self._cache else None
1122  index = queue.pop(0) if queue else -1
1123  self.log("request job", index)
1124  self.comm.gather(index, root=self.root)
1125  self.log("waiting for job")
1126  data = self.comm.scatter(None, root=self.root)
1127 
1128  while index >= 0:
1129  self.log("running job")
1130  result = func(self._getCache(context, index), data, *args, **kwargs)
1131  self.log("pending", queue)
1132  nextIndex = queue.pop(0) if queue else -1
1133  self.comm.send((index, result, nextIndex), self.root, tag=tags.result)
1134  index = nextIndex
1135  if index >= 0:
1136  data = self.comm.recv(tag=tags.work, source=self.root)
1137 
1138  self.log("done")
1139 
1140  def storeSet(self):
1141  """Set value in store"""
1142  context, kwargs = self.comm.broadcast(None, root=self.root)
1143  super(PoolSlave, self).storeSet(context, **kwargs)
1144 
1145  def storeDel(self):
1146  """Delete value in store"""
1147  context, nameList = self.comm.broadcast(None, root=self.root)
1148  super(PoolSlave, self).storeDel(context, *nameList)
1149 
1150  def storeClear(self):
1151  """Reset data store"""
1152  context = self.comm.broadcast(None, root=self.root)
1153  super(PoolSlave, self).storeClear(context)
1154 
1155  def cacheClear(self):
1156  """Reset cache"""
1157  context = self.comm.broadcast(None, root=self.root)
1158  super(PoolSlave, self).cacheClear(context)
1159 
1160  def cacheList(self):
1161  """List cache contents"""
1162  context = self.comm.broadcast(None, root=self.root)
1163  super(PoolSlave, self).cacheList(context)
1164 
1165  def storeList(self):
1166  """List store contents"""
1167  context = self.comm.broadcast(None, root=self.root)
1168  super(PoolSlave, self).storeList(context)
1169 
1170  def exit(self):
1171  """Allow exit from loop in 'run'"""
1172  return True
1173 
1174 
1175 class PoolWrapperMeta(type):
1176  """Metaclass for PoolWrapper to add methods pointing to PoolMaster
1177 
1178  The 'context' is automatically supplied to these methods as the first argument.
1179  """
1180 
1181  def __call__(self, context="default"):
1182  instance = super(PoolWrapperMeta, self).__call__(context)
1183  pool = PoolMaster()
1184  for name in ("map", "mapNoBalance", "mapToPrevious",
1185  "reduce", "reduceNoBalance", "reduceToPrevious",
1186  "storeSet", "storeDel", "storeClear", "storeList",
1187  "cacheList", "cacheClear",):
1188  setattr(instance, name, partial(getattr(pool, name), context))
1189  return instance
1190 
1191 
1192 class PoolWrapper(with_metaclass(PoolWrapperMeta, object)):
1193  """Wrap PoolMaster to automatically provide context"""
1194 
1195  def __init__(self, context="default"):
1196  self._pool = PoolMaster._instance
1197  self._context = context
1198 
1199  def __getattr__(self, name):
1200  return getattr(self._pool, name)
1201 
1202 
1203 class Pool(PoolWrapper): # Just gives PoolWrapper a nicer name for the user
1204  """Process Pool
1205 
1206  Use this class to automatically provide 'context' to
1207  the PoolMaster class. If you want to call functions
1208  that don't take a 'cache' object, use the PoolMaster
1209  class directly, and specify context=None.
1210  """
1211  pass
1212 
1213 
1214 def startPool(comm=None, root=0, killSlaves=True):
1215  """!Start a process pool.
1216 
1217  Returns a PoolMaster object for the master node.
1218  Slave nodes are run and then optionally killed.
1219 
1220  If you elect not to kill the slaves, note that they
1221  will emerge at the point this function was called,
1222  which is likely very different from the point the
1223  master is at, so it will likely be necessary to put
1224  in some rank dependent code (e.g., look at the 'rank'
1225  attribute of the returned pools).
1226 
1227  Note that the pool objects should be deleted (either
1228  by going out of scope or explicit 'del') before program
1229  termination to avoid a segmentation fault.
1230 
1231  @param comm: MPI communicator
1232  @param root: Rank of root/master node
1233  @param killSlaves: Kill slaves on completion?
1234  """
1235  if comm is None:
1236  comm = Comm()
1237  if comm.rank == root:
1238  return PoolMaster(comm, root=root)
1239  slave = PoolSlave(comm, root=root)
1240  slave.run()
1241  if killSlaves:
1242  del slave # Required to prevent segmentation fault on exit
1243  sys.exit()
1244  return slave
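A hedged end-to-end sketch of a typical driver script (the names driver.py and processOne are hypothetical; launch with something like 'mpiexec -n 4 python driver.py'):

from lsst.ctrl.pool.pool import startPool, Pool


def processOne(cache, data):
    """Worker run on the slaves; 'cache' persists between calls on a node."""
    return data ** 2 + cache.offset


def main():
    pool = Pool("example")          # 'example' names the cache/store context
    pool.storeSet(offset=1.0)       # made available to workers as cache.offset
    results = pool.map(processOne, [1, 2, 3, 4])
    print(results)                  # [2.0, 5.0, 10.0, 17.0] on the master


if __name__ == "__main__":
    startPool()                     # slaves serve commands here; only the master returns
    main()
    # The PoolMaster singleton's __del__ broadcasts 'exit' to release the slaves;
    # let it be collected before interpreter teardown to avoid a segmentation fault.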