Browse Source

multiprocessing.pool: Proposed fix for Python Issue #9205 (http://bugs.python.org/issue9205)

Initial tests seem to show it's working fine, but it needs more thorough testing.
Ask Solem 14 years ago
parent
commit
409811c9c2
1 changed file with 151 additions and 124 deletions
  1. 151 124
      celery/concurrency/processes/pool.py

+ 151 - 124
celery/concurrency/processes/pool.py

@@ -34,9 +34,24 @@ RUN = 0
 CLOSE = 1
 CLOSE = 1
 TERMINATE = 2
 TERMINATE = 2
 
 
+#
+# Constants representing the state of a job
+#
+
+ACK = 0
+READY = 1
+
 # Signal used for soft time limits.
 # Signal used for soft time limits.
 SIG_SOFT_TIMEOUT = getattr(signal, "SIGUSR1", None)
 SIG_SOFT_TIMEOUT = getattr(signal, "SIGUSR1", None)
 
 
+#
+# Exceptions
+#
+
+class WorkerLostError(Exception):
+    """The worker processing a job has exited prematurely."""
+    pass
+
 #
 #
 # Miscellaneous
 # Miscellaneous
 #
 #
@@ -71,13 +86,11 @@ def soft_timeout_sighandler(signum, frame):
     raise SoftTimeLimitExceeded()
     raise SoftTimeLimitExceeded()
 
 
 
 
-def worker(inqueue, outqueue, ackqueue, initializer=None, initargs=(),
-        maxtasks=None):
+def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None):
     assert maxtasks is None or (type(maxtasks) == int and maxtasks > 0)
     assert maxtasks is None or (type(maxtasks) == int and maxtasks > 0)
     pid = os.getpid()
     pid = os.getpid()
     put = outqueue.put
     put = outqueue.put
     get = inqueue.get
     get = inqueue.get
-    ack = ackqueue.put
     if hasattr(inqueue, '_writer'):
     if hasattr(inqueue, '_writer'):
         inqueue._writer.close()
         inqueue._writer.close()
         outqueue._reader.close()
         outqueue._reader.close()
@@ -101,16 +114,16 @@ def worker(inqueue, outqueue, ackqueue, initializer=None, initargs=(),
             break
             break
 
 
         job, i, func, args, kwds = task
         job, i, func, args, kwds = task
-        ack((job, i, time.time(), pid))
+        put((ACK, (job, i, time.time(), pid)))
         try:
         try:
             result = (True, func(*args, **kwds))
             result = (True, func(*args, **kwds))
         except Exception, e:
         except Exception, e:
             result = (False, e)
             result = (False, e)
         try:
         try:
-            put((job, i, result))
+            put((READY, (job, i, result)))
         except Exception, exc:
         except Exception, exc:
             wrapped = MaybeEncodingError(exc, result[1])
             wrapped = MaybeEncodingError(exc, result[1])
-            put((job, i, (False, wrapped)))
+            put((READY, (job, i, (False, wrapped))))
 
 
         completed += 1
         completed += 1
     debug('worker exiting after %d tasks' % completed)
     debug('worker exiting after %d tasks' % completed)
@@ -199,65 +212,6 @@ class TaskHandler(PoolThread):
         debug('task handler exiting')
         debug('task handler exiting')
 
 
 
 
-class AckHandler(PoolThread):
-
-    def __init__(self, ackqueue, get, cache):
-        self.ackqueue = ackqueue
-        self.get = get
-        self.cache = cache
-
-        super(AckHandler, self).__init__()
-
-    def run(self):
-        debug('ack handler starting')
-        get = self.get
-        cache = self.cache
-
-        while 1:
-            try:
-                task = get()
-            except (IOError, EOFError), exc:
-                debug('ack handler got %s -- exiting',
-                        exc.__class__.__name__)
-
-            if self._state:
-                assert self._state == TERMINATE
-                debug('ack handler found thread._state=TERMINATE')
-                break
-
-            if task is None:
-                debug('ack handler got sentinel')
-                break
-
-            job, i, time_accepted, pid = task
-            try:
-                cache[job]._ack(i, time_accepted, pid)
-            except (KeyError, AttributeError), exc:
-                # Object gone, or doesn't support _ack (e.g. IMapIterator)
-                pass
-
-        while cache and self._state != TERMINATE:
-            try:
-                task = get()
-            except (IOError, EOFError), exc:
-                debug('ack handler got %s -- exiting',
-                        exc.__class__.__name__)
-                return
-
-            if task is None:
-                debug('result handler ignoring extra sentinel')
-                continue
-
-            job, i, time_accepted, pid = task
-            try:
-                cache[job]._ack(i, time_accepted, pid)
-            except KeyError:
-                pass
-
-        debug('ack handler exiting: len(cache)=%s, thread._state=%s',
-                len(cache), self._state)
-
-
 class TimeoutHandler(PoolThread):
 class TimeoutHandler(PoolThread):
 
 
     def __init__(self, processes, cache, t_soft, t_hard):
     def __init__(self, processes, cache, t_soft, t_hard):
@@ -348,10 +302,13 @@ class TimeoutHandler(PoolThread):
 
 
 class ResultHandler(PoolThread):
 class ResultHandler(PoolThread):
 
 
-    def __init__(self, outqueue, get, cache, putlock):
+    def __init__(self, outqueue, get, cache, poll,
+            join_exited_workers, putlock):
         self.outqueue = outqueue
         self.outqueue = outqueue
         self.get = get
         self.get = get
         self.cache = cache
         self.cache = cache
+        self.poll = poll
+        self.join_exited_workers = join_exited_workers
         self.putlock = putlock
         self.putlock = putlock
         super(ResultHandler, self).__init__()
         super(ResultHandler, self).__init__()
 
 
@@ -359,37 +316,58 @@ class ResultHandler(PoolThread):
         get = self.get
         get = self.get
         outqueue = self.outqueue
         outqueue = self.outqueue
         cache = self.cache
         cache = self.cache
+        poll = self.poll
+        join_exited_workers = self.join_exited_workers
         putlock = self.putlock
         putlock = self.putlock
 
 
+        def on_ack(job, i, time_accepted, pid):
+            try:
+                cache[job]._ack(i, time_accepted, pid)
+            except (KeyError, AttributeError):
+                # Object gone or doesn't support _ack (e.g. IMAPIterator).
+                pass
+
+        def on_ready(job, i, obj):
+            try:
+                cache[job]._set(i, obj)
+            except KeyError:
+                pass
+
+        state_handlers = {ACK: on_ack, READY: on_ready}
+
+        def on_state_change(task):
+            state, args = task
+            try:
+                state_handlers[state](*args)
+            except KeyError:
+                debug("Unknown job state: %s (args=%s)" % (state, args))
+
         debug('result handler starting')
         debug('result handler starting')
         while 1:
         while 1:
             try:
             try:
-                task = get()
+                ready, task = poll(0.2)
             except (IOError, EOFError), exc:
             except (IOError, EOFError), exc:
-                debug('result handler got %s -- exiting',
-                        exc.__class__.__name__)
+                debug('result handler got %r -- exiting' % (exc, ))
                 return
                 return
 
 
-            if putlock is not None:
-                try:
-                    putlock.release()
-                except ValueError:
-                    pass
-
             if self._state:
             if self._state:
                 assert self._state == TERMINATE
                 assert self._state == TERMINATE
                 debug('result handler found thread._state=TERMINATE')
                 debug('result handler found thread._state=TERMINATE')
                 break
                 break
 
 
-            if task is None:
-                debug('result handler got sentinel')
-                break
+            if ready:
+                if task is None:
+                    debug('result handler got sentinel')
+                    break
+
+                if putlock is not None:
+                    try:
+                        putlock.release()
+                    except ValueError:
+                        pass
+
+                on_state_change(task)
 
 
-            job, i, obj = task
-            try:
-                cache[job]._set(i, obj)
-            except KeyError:
-                pass
 
 
         if putlock is not None:
         if putlock is not None:
             try:
             try:
@@ -399,15 +377,19 @@ class ResultHandler(PoolThread):
 
 
         while cache and self._state != TERMINATE:
         while cache and self._state != TERMINATE:
             try:
             try:
-                task = get()
+                ready, task = poll(0.2)
             except (IOError, EOFError), exc:
             except (IOError, EOFError), exc:
-                debug('result handler got %s -- exiting',
-                        exc.__class__.__name__)
+                debug('result handler got %r -- exiting' % (exc, ))
                 return
                 return
 
 
-            if task is None:
-                debug('result handler ignoring extra sentinel')
-                continue
+            if ready:
+                if task is None:
+                    debug('result handler ignoring extra sentinel')
+                    continue
+
+                on_state_change(task)
+            join_exited_workers()
+
             job, i, obj = task
             job, i, obj = task
             try:
             try:
                 cache[job]._set(i, obj)
                 cache[job]._set(i, obj)
@@ -438,7 +420,6 @@ class Pool(object):
     Process = Process
     Process = Process
     Supervisor = Supervisor
     Supervisor = Supervisor
     TaskHandler = TaskHandler
     TaskHandler = TaskHandler
-    AckHandler = AckHandler
     TimeoutHandler = TimeoutHandler
     TimeoutHandler = TimeoutHandler
     ResultHandler = ResultHandler
     ResultHandler = ResultHandler
     SoftTimeLimitExceeded = SoftTimeLimitExceeded
     SoftTimeLimitExceeded = SoftTimeLimitExceeded
@@ -478,15 +459,12 @@ class Pool(object):
 
 
         self._putlock = threading.BoundedSemaphore(self._processes)
         self._putlock = threading.BoundedSemaphore(self._processes)
 
 
-        self._task_handler = self.TaskHandler(self._taskqueue, self._quick_put,
-                                         self._outqueue, self._pool)
+        self._task_handler = self.TaskHandler(self._taskqueue,
+                                              self._quick_put,
+                                              self._outqueue,
+                                              self._pool)
         self._task_handler.start()
         self._task_handler.start()
 
 
-        # Thread processing acknowledgements from the ackqueue.
-        self._ack_handler = self.AckHandler(self._ackqueue,
-                self._quick_get_ack, self._cache)
-        self._ack_handler.start()
-
         # Thread killing timedout jobs.
         # Thread killing timedout jobs.
         if self.timeout or self.soft_timeout:
         if self.timeout or self.soft_timeout:
             self._timeout_handler = self.TimeoutHandler(
             self._timeout_handler = self.TimeoutHandler(
@@ -499,14 +477,15 @@ class Pool(object):
         # Thread processing results in the outqueue.
         # Thread processing results in the outqueue.
         self._result_handler = self.ResultHandler(self._outqueue,
         self._result_handler = self.ResultHandler(self._outqueue,
                                         self._quick_get, self._cache,
                                         self._quick_get, self._cache,
+                                        self._poll_result,
+                                        self._join_exited_workers,
                                         self._putlock)
                                         self._putlock)
         self._result_handler.start()
         self._result_handler.start()
 
 
         self._terminate = Finalize(
         self._terminate = Finalize(
             self, self._terminate_pool,
             self, self._terminate_pool,
             args=(self._taskqueue, self._inqueue, self._outqueue,
             args=(self._taskqueue, self._inqueue, self._outqueue,
-                  self._ackqueue, self._pool, self._ack_handler,
-                  self._worker_handler, self._task_handler,
+                  self._pool, self._worker_handler, self._task_handler,
                   self._result_handler, self._cache,
                   self._result_handler, self._cache,
                   self._timeout_handler),
                   self._timeout_handler),
             exitpriority=15,
             exitpriority=15,
@@ -515,7 +494,7 @@ class Pool(object):
     def _create_worker_process(self):
     def _create_worker_process(self):
         w = self.Process(
         w = self.Process(
             target=worker,
             target=worker,
-            args=(self._inqueue, self._outqueue, self._ackqueue,
+            args=(self._inqueue, self._outqueue,
                     self._initializer, self._initargs,
                     self._initializer, self._initargs,
                     self._maxtasksperchild),
                     self._maxtasksperchild),
             )
             )
@@ -530,6 +509,7 @@ class Pool(object):
         reaching their specified lifetime. Returns True if any workers were
         reaching their specified lifetime. Returns True if any workers were
         cleaned up.
         cleaned up.
         """
         """
+        cleaned = []
         for i in reversed(range(len(self._pool))):
         for i in reversed(range(len(self._pool))):
             worker = self._pool[i]
             worker = self._pool[i]
             if worker.exitcode is not None:
             if worker.exitcode is not None:
@@ -541,8 +521,17 @@ class Pool(object):
                     except ValueError:
                     except ValueError:
                         pass
                         pass
                 worker.join()
                 worker.join()
+                cleaned.append(worker.pid)
                 del self._pool[i]
                 del self._pool[i]
-        return len(self._pool) < self._processes
+        if cleaned:
+            for job in self._cache.values():
+                for worker_pid in job.worker_pids():
+                    if worker_pid in cleaned:
+                        err = WorkerLostError("Worker exited prematurely.")
+                        job._set(None, (False, err))
+                        continue
+            return True
+        return False
 
 
     def _repopulate_pool(self):
     def _repopulate_pool(self):
         """Bring the number of pool processes up to the specified number,
         """Bring the number of pool processes up to the specified number,
@@ -550,6 +539,8 @@ class Pool(object):
         """
         """
         debug('repopulating pool')
         debug('repopulating pool')
         for i in range(self._processes - len(self._pool)):
         for i in range(self._processes - len(self._pool)):
+            if self._state != RUN:
+                return
             self._create_worker_process()
             self._create_worker_process()
             debug('added worker')
             debug('added worker')
 
 
@@ -563,10 +554,14 @@ class Pool(object):
         from multiprocessing.queues import SimpleQueue
         from multiprocessing.queues import SimpleQueue
         self._inqueue = SimpleQueue()
         self._inqueue = SimpleQueue()
         self._outqueue = SimpleQueue()
         self._outqueue = SimpleQueue()
-        self._ackqueue = SimpleQueue()
         self._quick_put = self._inqueue._writer.send
         self._quick_put = self._inqueue._writer.send
         self._quick_get = self._outqueue._reader.recv
         self._quick_get = self._outqueue._reader.recv
-        self._quick_get_ack = self._ackqueue._reader.recv
+
+        def _poll_result(timeout):
+            if self._outqueue._reader.poll(timeout):
+                return True, self._quick_get()
+            return False, None
+        self._poll_result = _poll_result
 
 
     def apply(self, func, args=(), kwds={}):
     def apply(self, func, args=(), kwds={}):
         '''
         '''
@@ -686,6 +681,7 @@ class Pool(object):
         if self._state == RUN:
         if self._state == RUN:
             self._state = CLOSE
             self._state = CLOSE
             self._worker_handler.close()
             self._worker_handler.close()
+            self._worker_handler.join()
             self._taskqueue.put(None)
             self._taskqueue.put(None)
 
 
     def terminate(self):
     def terminate(self):
@@ -696,12 +692,15 @@ class Pool(object):
 
 
     def join(self):
     def join(self):
         assert self._state in (CLOSE, TERMINATE)
         assert self._state in (CLOSE, TERMINATE)
+        debug('joining worker handler')
         self._worker_handler.join()
         self._worker_handler.join()
+        debug('joining task handler')
         self._task_handler.join()
         self._task_handler.join()
+        debug('joining result handler')
         self._result_handler.join()
         self._result_handler.join()
-        for p in self._pool:
+        for i, p in enumerate(self._pool):
+            debug('joining worker %s/%s (%r)' % (i, len(self._pool), p, ))
             p.join()
             p.join()
-        debug('after join()')
 
 
     @staticmethod
     @staticmethod
     def _help_stuff_finish(inqueue, task_handler, size):
     def _help_stuff_finish(inqueue, task_handler, size):
@@ -713,8 +712,8 @@ class Pool(object):
             time.sleep(0)
             time.sleep(0)
 
 
     @classmethod
     @classmethod
-    def _terminate_pool(cls, taskqueue, inqueue, outqueue, ackqueue, pool,
-                        ack_handler, worker_handler, task_handler,
+    def _terminate_pool(cls, taskqueue, inqueue, outqueue, pool,
+                        worker_handler, task_handler,
                         result_handler, cache, timeout_handler):
                         result_handler, cache, timeout_handler):
 
 
         # this is guaranteed to only be called once
         # this is guaranteed to only be called once
@@ -733,9 +732,6 @@ class Pool(object):
         result_handler.terminate()
         result_handler.terminate()
         outqueue.put(None)                  # sentinel
         outqueue.put(None)                  # sentinel
 
 
-        ack_handler.terminate()
-        ackqueue.put(None)                  # sentinel
-
         if timeout_handler is not None:
         if timeout_handler is not None:
             timeout_handler.terminate()
             timeout_handler.terminate()
 
 
@@ -752,9 +748,6 @@ class Pool(object):
         debug('joining result handler')
         debug('joining result handler')
         result_handler.join(1e100)
         result_handler.join(1e100)
 
 
-        debug('joining ack handler')
-        ack_handler.join(1e100)
-
         if timeout_handler is not None:
         if timeout_handler is not None:
             debug('joining timeout handler')
             debug('joining timeout handler')
             timeout_handler.join(1e100)
             timeout_handler.join(1e100)
@@ -779,14 +772,15 @@ class ApplyResult(object):
         self._cond = threading.Condition(threading.Lock())
         self._cond = threading.Condition(threading.Lock())
         self._job = job_counter.next()
         self._job = job_counter.next()
         self._cache = cache
         self._cache = cache
-        self._accepted = False
-        self._accept_pid = None
-        self._time_accepted = None
         self._ready = False
         self._ready = False
         self._callback = callback
         self._callback = callback
-        self._errback = error_callback
         self._accept_callback = accept_callback
         self._accept_callback = accept_callback
+        self._errback = error_callback
         self._timeout_callback = timeout_callback
         self._timeout_callback = timeout_callback
+
+        self._accepted = False
+        self._worker_pid = None
+        self._time_accepted = None
         cache[self._job] = self
         cache[self._job] = self
 
 
     def ready(self):
     def ready(self):
@@ -799,6 +793,12 @@ class ApplyResult(object):
         assert self._ready
         assert self._ready
         return self._success
         return self._success
 
 
+    def accepted(self):
+        return self._accepted
+
+    def worker_pids(self):
+        return filter(None, [self._worker_pid])
+
     def wait(self, timeout=None):
     def wait(self, timeout=None):
         self._cond.acquire()
         self._cond.acquire()
         try:
         try:
@@ -829,16 +829,16 @@ class ApplyResult(object):
         finally:
         finally:
             self._cond.release()
             self._cond.release()
         if self._accepted:
         if self._accepted:
-            del self._cache[self._job]
+            self._cache.pop(self._job, None)
 
 
     def _ack(self, i, time_accepted, pid):
     def _ack(self, i, time_accepted, pid):
         self._accepted = True
         self._accepted = True
         self._time_accepted = time_accepted
         self._time_accepted = time_accepted
-        self._accept_pid = pid
+        self._worker_pid = pid
         if self._accept_callback:
         if self._accept_callback:
             self._accept_callback()
             self._accept_callback()
         if self._ready:
         if self._ready:
-            del self._cache[self._job]
+            self._cache.pop(self._job, None)
 
 
 #
 #
 # Class whose instances are returned by `Pool.map_async()`
 # Class whose instances are returned by `Pool.map_async()`
@@ -849,7 +849,11 @@ class MapResult(ApplyResult):
     def __init__(self, cache, chunksize, length, callback):
     def __init__(self, cache, chunksize, length, callback):
         ApplyResult.__init__(self, cache, callback)
         ApplyResult.__init__(self, cache, callback)
         self._success = True
         self._success = True
+        self._length = length
         self._value = [None] * length
         self._value = [None] * length
+        self._accepted = [False] * length
+        self._worker_pid = [None] * length
+        self._time_accepted = [None] * length
         self._chunksize = chunksize
         self._chunksize = chunksize
         if chunksize <= 0:
         if chunksize <= 0:
             self._number_left = 0
             self._number_left = 0
@@ -865,7 +869,8 @@ class MapResult(ApplyResult):
             if self._number_left == 0:
             if self._number_left == 0:
                 if self._callback:
                 if self._callback:
                     self._callback(self._value)
                     self._callback(self._value)
-                del self._cache[self._job]
+                if self._accepted:
+                    self._cache.pop(self._job, None)
                 self._cond.acquire()
                 self._cond.acquire()
                 try:
                 try:
                     self._ready = True
                     self._ready = True
@@ -876,7 +881,8 @@ class MapResult(ApplyResult):
         else:
         else:
             self._success = False
             self._success = False
             self._value = result
             self._value = result
-            del self._cache[self._job]
+            if self._accepted:
+                self._cache.pop(self._job, None)
             self._cond.acquire()
             self._cond.acquire()
             try:
             try:
                 self._ready = True
                 self._ready = True
@@ -884,6 +890,22 @@ class MapResult(ApplyResult):
             finally:
             finally:
                 self._cond.release()
                 self._cond.release()
 
 
+    def _ack(self, i, time_accepted, pid):
+        start = i * self._chunksize
+        stop = (i + 1) * self._chunksize
+        for j in range(start, stop):
+            self._accepted[j] = True
+            self._worker_pid[j] = pid
+            self._time_accepted[j] = time_accepted
+        if self._ready:
+            self._cache.pop(self._job, None)
+
+    def accepted(self):
+        return all(self._accepted)
+
+    def worker_pids(self):
+        return filter(None, self._worker_pid)
+
 #
 #
 # Class whose instances are returned by `Pool.imap()`
 # Class whose instances are returned by `Pool.imap()`
 #
 #
@@ -989,10 +1011,15 @@ class ThreadPool(Pool):
     def _setup_queues(self):
     def _setup_queues(self):
         self._inqueue = Queue.Queue()
         self._inqueue = Queue.Queue()
         self._outqueue = Queue.Queue()
         self._outqueue = Queue.Queue()
-        self._ackqueue = Queue.Queue()
         self._quick_put = self._inqueue.put
         self._quick_put = self._inqueue.put
         self._quick_get = self._outqueue.get
         self._quick_get = self._outqueue.get
-        self._quick_get_ack = self._ackqueue.get
+
+        def _poll_result(timeout):
+            try:
+                return True, self._quick_get(timeout=timeout)
+            except Queue.Empty:
+                return False, None
+        self._poll_result = _poll_result
 
 
     @staticmethod
     @staticmethod
     def _help_stuff_finish(inqueue, task_handler, size):
     def _help_stuff_finish(inqueue, task_handler, size):