Browse Source

Pool: Fixes race condition when MainThread is waiting for putlock. Closes #264

Ask Solem — committed 14 years ago
commit fef61fb4a3
1 changed file with 41 additions and 13 deletions:
    celery/concurrency/processes/pool.py (+41, −13)

+ 41 - 13
celery/concurrency/processes/pool.py

@@ -75,17 +75,19 @@ class LaxBoundedSemaphore(threading._Semaphore):
 
 
     def release(self):
     def release(self):
         if self._Semaphore__value < self._initial_value:
         if self._Semaphore__value < self._initial_value:
-            return _Semaphore.release(self)
+            _Semaphore.release(self)
         if __debug__:
         if __debug__:
             self._note("%s.release: success, value=%s (unchanged)" % (
             self._note("%s.release: success, value=%s (unchanged)" % (
                 self, self._Semaphore__value))
                 self, self._Semaphore__value))
 
 
+    def clear(self):
+        while self._Semaphore__value < self._initial_value:
+            _Semaphore.release(self)
 
 
 #
 #
-# Code run by worker processes
+# Exceptions
 #
 #
 
 
-
 class MaybeEncodingError(Exception):
 class MaybeEncodingError(Exception):
     """Wraps unpickleable object."""
     """Wraps unpickleable object."""
 
 
@@ -102,9 +104,18 @@ class MaybeEncodingError(Exception):
                     self.value, self.exc)
                     self.value, self.exc)
 
 
 
 
+class WorkersJoined(Exception):
+    """All workers have terminated."""
+
+
 def soft_timeout_sighandler(signum, frame):
 def soft_timeout_sighandler(signum, frame):
     raise SoftTimeLimitExceeded()
     raise SoftTimeLimitExceeded()
 
 
+#
+# Code run by worker processes
+#
+
+
 
 
 def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None):
 def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None):
     pid = os.getpid()
     pid = os.getpid()
@@ -410,10 +421,7 @@ class ResultHandler(PoolThread):
 
 
                 on_state_change(task)
                 on_state_change(task)
 
 
-        # Notify waiting threads
-        if putlock is not None:
-            putlock.release()
-
+        time_terminate = None
         while cache and self._state != TERMINATE:
         while cache and self._state != TERMINATE:
             try:
             try:
                 ready, task = poll(0.2)
                 ready, task = poll(0.2)
@@ -427,7 +435,19 @@ class ResultHandler(PoolThread):
                     continue
                     continue
 
 
                 on_state_change(task)
                 on_state_change(task)
-            join_exited_workers()
+            try:
+                join_exited_workers(shutdown=True)
+            except WorkersJoined:
+                now = time.time()
+                if not time_terminate:
+                    time_terminate = now
+                else:
+                    if now - time_terminate > 5.0:
+                        debug('result handler exiting: timed out')
+                        break
+                    debug('result handler: all workers terminated, '
+                          'timeout in %ss' % (
+                              abs(min(now - time_terminate - 5.0, 0))))
 
 
         if hasattr(outqueue, '_reader'):
         if hasattr(outqueue, '_reader'):
             debug('ensuring that outqueue is not full')
             debug('ensuring that outqueue is not full')
@@ -536,7 +556,7 @@ class Pool(object):
         w.start()
         w.start()
         return w
         return w
 
 
-    def _join_exited_workers(self, lost_worker_timeout=10.0):
+    def _join_exited_workers(self, shutdown=False, lost_worker_timeout=10.0):
         """Cleanup after any worker processes which have exited due to
         """Cleanup after any worker processes which have exited due to
         reaching their specified lifetime. Returns True if any workers were
         reaching their specified lifetime. Returns True if any workers were
         cleaned up.
         cleaned up.
@@ -552,13 +572,17 @@ class Pool(object):
                 err = WorkerLostError("Worker exited prematurely.")
                 err = WorkerLostError("Worker exited prematurely.")
                 job._set(None, (False, err))
                 job._set(None, (False, err))
 
 
+        if shutdown and not len(self._pool):
+            raise WorkersJoined()
+
         cleaned = []
         cleaned = []
         for i in reversed(range(len(self._pool))):
         for i in reversed(range(len(self._pool))):
             worker = self._pool[i]
             worker = self._pool[i]
             if worker.exitcode is not None:
             if worker.exitcode is not None:
                 # worker exited
                 # worker exited
-                debug('cleaning up worker %d' % i)
+                debug('Supervisor: cleaning up worker %d' % i)
                 worker.join()
                 worker.join()
+                debug('Supervisor: worker %d joined' % i)
                 cleaned.append(worker.pid)
                 cleaned.append(worker.pid)
                 del self._pool[i]
                 del self._pool[i]
         if cleaned:
         if cleaned:
@@ -711,8 +735,10 @@ class Pool(object):
         result = ApplyResult(self._cache, callback,
         result = ApplyResult(self._cache, callback,
                              accept_callback, timeout_callback,
                              accept_callback, timeout_callback,
                              error_callback)
                              error_callback)
-        if waitforslot:
+        if waitforslot and self._putlock is not None:
             self._putlock.acquire()
             self._putlock.acquire()
+            if self._state != RUN:
+                return
         self._taskqueue.put(([(result._job, None, func, args, kwds)], None))
         self._taskqueue.put(([(result._job, None, func, args, kwds)], None))
         return result
         return result
 
 
@@ -758,6 +784,8 @@ class Pool(object):
             self._worker_handler.close()
             self._worker_handler.close()
             self._worker_handler.join()
             self._worker_handler.join()
             self._taskqueue.put(None)
             self._taskqueue.put(None)
+            if self._putlock:
+                self._putlock.clear()
 
 
     def terminate(self):
     def terminate(self):
         debug('terminating pool')
         debug('terminating pool')
@@ -773,6 +801,7 @@ class Pool(object):
         self._task_handler.join()
         self._task_handler.join()
         debug('joining result handler')
         debug('joining result handler')
         self._result_handler.join()
         self._result_handler.join()
+        debug('result handler joined')
         for i, p in enumerate(self._pool):
         for i, p in enumerate(self._pool):
             debug('joining worker %s/%s (%r)' % (i, len(self._pool), p, ))
             debug('joining worker %s/%s (%r)' % (i, len(self._pool), p, ))
             p.join()
             p.join()
@@ -802,8 +831,6 @@ class Pool(object):
         debug('helping task handler/workers to finish')
         debug('helping task handler/workers to finish')
         cls._help_stuff_finish(inqueue, task_handler, len(pool))
         cls._help_stuff_finish(inqueue, task_handler, len(pool))
 
 
-        assert result_handler.is_alive() or len(cache) == 0
-
         result_handler.terminate()
         result_handler.terminate()
         outqueue.put(None)                  # sentinel
         outqueue.put(None)                  # sentinel
 
 
@@ -834,6 +861,7 @@ class Pool(object):
                     # worker has not yet exited
                     # worker has not yet exited
                     debug('cleaning up worker %d' % p.pid)
                     debug('cleaning up worker %d' % p.pid)
                     p.join()
                     p.join()
+            debug('pool workers joined')
 DynamicPool = Pool
 DynamicPool = Pool
 
 
 #
 #