
Pool: Fixes race condition when MainThread is waiting for putlock. Closes #264

Ask Solem 14 years ago
parent
commit
fef61fb4a3
1 changed file with 41 additions and 13 deletions
  1. celery/concurrency/processes/pool.py (+41 -13)

+ 41 - 13
celery/concurrency/processes/pool.py
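
For context on the race the commit message refers to: apply_async() blocks the MainThread on the pool's putlock semaphore while waiting for a free slot, and before this change nothing guaranteed that semaphore would be released once the pool began shutting down, so the MainThread could wait forever. The patch adds a clear() method to LaxBoundedSemaphore and calls it from Pool.close(). Below is a minimal, self-contained sketch of that semaphore's intended semantics, written against the public threading API rather than the threading._Semaphore internals the real class subclasses; apart from the names LaxBoundedSemaphore and clear(), which come from the diff, the attribute names here are illustrative.

    import threading

    class LaxBoundedSemaphore(object):
        """Semaphore that never grows past its initial value and that can be
        force-cleared so threads blocked in acquire() are released."""

        def __init__(self, value=1):
            self._initial = value
            self._value = value
            self._cond = threading.Condition(threading.Lock())

        def acquire(self, timeout=None):
            with self._cond:
                while self._value == 0:
                    if not self._cond.wait(timeout):
                        return False        # timed out
                self._value -= 1
                return True

        def release(self):
            with self._cond:
                # Unlike a plain BoundedSemaphore, extra releases are ignored
                # instead of raising ValueError.
                if self._value < self._initial:
                    self._value += 1
                    self._cond.notify()

        def clear(self):
            # Restore the counter to its initial value, waking every waiter.
            # The patch calls this from Pool.close() so a MainThread blocked
            # in apply_async() waiting for a slot cannot hang at shutdown.
            with self._cond:
                while self._value < self._initial:
                    self._value += 1
                    self._cond.notify()
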

@@ -75,17 +75,19 @@ class LaxBoundedSemaphore(threading._Semaphore):
 
     def release(self):
         if self._Semaphore__value < self._initial_value:
-            return _Semaphore.release(self)
+            _Semaphore.release(self)
         if __debug__:
             self._note("%s.release: success, value=%s (unchanged)" % (
                 self, self._Semaphore__value))
 
+    def clear(self):
+        while self._Semaphore__value < self._initial_value:
+            _Semaphore.release(self)
 
 #
-# Code run by worker processes
+# Exceptions
 #
 
-
 class MaybeEncodingError(Exception):
     """Wraps unpickleable object."""
 
@@ -102,9 +104,18 @@ class MaybeEncodingError(Exception):
                     self.value, self.exc)
 
 
+class WorkersJoined(Exception):
+    """All workers have terminated."""
+
+
 def soft_timeout_sighandler(signum, frame):
     raise SoftTimeLimitExceeded()
 
+#
+# Code run by worker processes
+#
+
+
 
 def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None):
     pid = os.getpid()
@@ -410,10 +421,7 @@ class ResultHandler(PoolThread):
 
                 on_state_change(task)
 
-        # Notify waiting threads
-        if putlock is not None:
-            putlock.release()
-
+        time_terminate = None
         while cache and self._state != TERMINATE:
             try:
                 ready, task = poll(0.2)
@@ -427,7 +435,19 @@ class ResultHandler(PoolThread):
                     continue
 
                 on_state_change(task)
-            join_exited_workers()
+            try:
+                join_exited_workers(shutdown=True)
+            except WorkersJoined:
+                now = time.time()
+                if not time_terminate:
+                    time_terminate = now
+                else:
+                    if now - time_terminate > 5.0:
+                        debug('result handler exiting: timed out')
+                        break
+                    debug('result handler: all workers terminated, '
+                          'timeout in %ss' % (
+                              abs(min(now - time_terminate - 5.0, 0))))
 
         if hasattr(outqueue, '_reader'):
             debug('ensuring that outqueue is not full')
@@ -536,7 +556,7 @@ class Pool(object):
         w.start()
         return w
 
-    def _join_exited_workers(self, lost_worker_timeout=10.0):
+    def _join_exited_workers(self, shutdown=False, lost_worker_timeout=10.0):
         """Cleanup after any worker processes which have exited due to
         reaching their specified lifetime. Returns True if any workers were
         cleaned up.
@@ -552,13 +572,17 @@ class Pool(object):
                 err = WorkerLostError("Worker exited prematurely.")
                 job._set(None, (False, err))
 
+        if shutdown and not len(self._pool):
+            raise WorkersJoined()
+
         cleaned = []
         for i in reversed(range(len(self._pool))):
             worker = self._pool[i]
             if worker.exitcode is not None:
                 # worker exited
-                debug('cleaning up worker %d' % i)
+                debug('Supervisor: cleaning up worker %d' % i)
                 worker.join()
+                debug('Supervisor: worker %d joined' % i)
                 cleaned.append(worker.pid)
                 del self._pool[i]
         if cleaned:
@@ -711,8 +735,10 @@ class Pool(object):
         result = ApplyResult(self._cache, callback,
                              accept_callback, timeout_callback,
                              error_callback)
-        if waitforslot:
+        if waitforslot and self._putlock is not None:
             self._putlock.acquire()
+            if self._state != RUN:
+                return
         self._taskqueue.put(([(result._job, None, func, args, kwds)], None))
         return result
 
@@ -758,6 +784,8 @@ class Pool(object):
             self._worker_handler.close()
             self._worker_handler.join()
             self._taskqueue.put(None)
+            if self._putlock:
+                self._putlock.clear()
 
     def terminate(self):
         debug('terminating pool')
@@ -773,6 +801,7 @@ class Pool(object):
         self._task_handler.join()
         debug('joining result handler')
         self._result_handler.join()
+        debug('result handler joined')
         for i, p in enumerate(self._pool):
             debug('joining worker %s/%s (%r)' % (i, len(self._pool), p, ))
             p.join()
@@ -802,8 +831,6 @@ class Pool(object):
         debug('helping task handler/workers to finish')
         cls._help_stuff_finish(inqueue, task_handler, len(pool))
 
-        assert result_handler.is_alive() or len(cache) == 0
-
         result_handler.terminate()
         outqueue.put(None)                  # sentinel
 
@@ -834,6 +861,7 @@ class Pool(object):
                     # worker has not yet exited
                     debug('cleaning up worker %d' % p.pid)
                     p.join()
+            debug('pool workers joined')
 DynamicPool = Pool
 
 #
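
As a follow-up, a toy demonstration (not celery code) of the producer-side half of the fix: after being woken from the putlock, apply_async() re-checks the pool state and returns None rather than enqueueing into a pool that is closing. The names RUN, CLOSE, _putlock and _taskqueue mirror the diff; the ToyPool class and everything else here are hypothetical, and a plain Semaphore released in a loop stands in for LaxBoundedSemaphore.clear().

    import threading
    import queue

    RUN, CLOSE = 0, 1

    class ToyPool(object):
        def __init__(self, slots=1):
            self._state = RUN
            self._taskqueue = queue.Queue()
            self._putlock = threading.Semaphore(slots)

        def apply_async(self, func, args=()):
            self._putlock.acquire()
            if self._state != RUN:
                # Woken by close(); do not enqueue into a dying pool.
                return None
            self._taskqueue.put((func, args))
            return "result-handle"

        def close(self):
            self._state = CLOSE
            # Release any producer still blocked in apply_async(); this is
            # the role LaxBoundedSemaphore.clear() plays in the patch.
            for _ in range(16):
                self._putlock.release()

    if __name__ == "__main__":
        pool = ToyPool(slots=1)
        pool.apply_async(sum, ([1, 2, 3],))     # takes the only slot

        # Second producer blocks in apply_async() waiting for a slot.
        blocked = threading.Thread(target=pool.apply_async, args=(sum, ([4, 5],)))
        blocked.start()
        pool.close()                            # wakes it; it returns None
        blocked.join(timeout=2)
        print("producer exited cleanly:", not blocked.is_alive())
        print("tasks actually queued:", pool._taskqueue.qsize())   # 1, not 2
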