Преглед на файлове

Worker Pool: Correctly handle lost worker processes. See Python bug 9205

http://bugs.python.org/issue9205
Ask Solem преди 14 години
родител
ревизия
69b56879e5
променени са 2 файла, в които са добавени 64 реда и са изтрити 13 реда
  1. 57 12
      celery/concurrency/processes/pool.py
  2. 7 1
      celery/exceptions.py

+ 57 - 12
celery/concurrency/processes/pool.py

@@ -25,6 +25,7 @@ from multiprocessing import Process, cpu_count, TimeoutError
 from multiprocessing.util import Finalize, debug
 
 from celery.exceptions import SoftTimeLimitExceeded, TimeLimitExceeded
+from celery.exceptions import WorkerLostError
 
 #
 # Constants representing the state of a pool
@@ -106,6 +107,13 @@ def worker(inqueue, outqueue, ackqueue, initializer=None, initargs=(),
             result = (True, func(*args, **kwds))
         except Exception, e:
             result = (False, e)
+        except BaseException, e:
+            # Job raised SystemExit or equivalent, so tell the result
+            # handler and exit the process so it can be replaced.
+            err = WorkerLostError(
+                    "Worker has terminated by user request: %r" % (e, ))
+            put((job, i, (False, err)))
+            raise
         try:
             put((job, i, result))
         except Exception, exc:
@@ -348,11 +356,13 @@ class TimeoutHandler(PoolThread):
 
 class ResultHandler(PoolThread):
 
-    def __init__(self, outqueue, get, cache, putlock):
+    def __init__(self, outqueue, get, cache, putlock, poll, workers_gone):
         self.outqueue = outqueue
         self.get = get
         self.cache = cache
         self.putlock = putlock
+        self.poll = poll
+        self.workers_gone = workers_gone
         super(ResultHandler, self).__init__()
 
     def run(self):
@@ -360,6 +370,8 @@ class ResultHandler(PoolThread):
         outqueue = self.outqueue
         cache = self.cache
         putlock = self.putlock
+        poll = self.poll
+        workers_gone = self.workers_gone
 
         debug('result handler starting')
         while 1:
@@ -399,20 +411,30 @@ class ResultHandler(PoolThread):
 
         while cache and self._state != TERMINATE:
             try:
-                task = get()
+                ready, task = poll(0.2)
             except (IOError, EOFError), exc:
                 debug('result handler got %s -- exiting',
                         exc.__class__.__name__)
                 return
 
-            if task is None:
-                debug('result handler ignoring extra sentinel')
-                continue
-            job, i, obj = task
-            try:
-                cache[job]._set(i, obj)
-            except KeyError:
-                pass
+            if ready:
+                if task is None:
+                    debug('result handler ignoring extra sentinel')
+                    continue
+
+                job, i, obj = task
+                try:
+                    cache[job]._set(i, obj)
+                except KeyError:
+                    pass
+            else:
+                if workers_gone():
+                    debug("%s active job(s), but no active workers! "
+                          "Terminating..." % (len(cache), ))
+                    err = WorkerLostError(
+                            "The worker processing this job has terminated.")
+                    for job in cache.values():
+                        job._set(None, (False, err))
 
         if hasattr(outqueue, '_reader'):
             debug('ensuring that outqueue is not full')
@@ -499,7 +521,8 @@ class Pool(object):
         # Thread processing results in the outqueue.
         self._result_handler = self.ResultHandler(self._outqueue,
                                         self._quick_get, self._cache,
-                                        self._putlock)
+                                        self._putlock, self._poll_result,
+                                        self._workers_gone)
         self._result_handler.start()
 
         self._terminate = Finalize(
@@ -525,6 +548,10 @@ class Pool(object):
         w.start()
         return w
 
+    def _workers_gone(self):
+        self._join_exited_workers()
+        return not self._pool
+
     def _join_exited_workers(self):
         """Cleanup after any worker processes which have exited due to
         reaching their specified lifetime. Returns True if any workers were
@@ -568,6 +595,13 @@ class Pool(object):
         self._quick_get = self._outqueue._reader.recv
         self._quick_get_ack = self._ackqueue._reader.recv
 
+        def _poll_result(timeout):
+            if self._outqueue._reader.poll(timeout):
+                return True, self._quick_get()
+            return False, None
+
+        self._poll_result = _poll_result
+
     def apply(self, func, args=(), kwds={}):
         '''
         Equivalent of `apply()` builtin
@@ -684,8 +718,11 @@ class Pool(object):
     def close(self):
         debug('closing pool')
         if self._state == RUN:
-            self._state = CLOSE
+            # Worker handler can't run while the result
+            # handler does its second pass, so wait for it to finish.
             self._worker_handler.close()
+            self._worker_handler.join()
+            self._state = CLOSE
             self._taskqueue.put(None)
 
     def terminate(self):
@@ -994,6 +1031,14 @@ class ThreadPool(Pool):
         self._quick_get = self._outqueue.get
         self._quick_get_ack = self._ackqueue.get
 
+        def _poll_result(timeout):
+            try:
+                return True, self._quick_get(timeout=timeout)
+            except Queue.Empty:
+                return False, None
+
+        self._poll_result = _poll_result
+
     @staticmethod
     def _help_stuff_finish(inqueue, task_handler, size):
         # put sentinels at head of inqueue to make workers finish

+ 7 - 1
celery/exceptions.py

@@ -8,13 +8,14 @@ UNREGISTERED_FMT = """
 Task of kind %s is not registered, please make sure it's imported.
 """.strip()
 
-
 class QueueNotFound(KeyError):
     """Task routed to a queue not in CELERY_QUEUES."""
+    pass
 
 
 class TimeLimitExceeded(Exception):
     """The time limit has been exceeded and the job has been terminated."""
+    pass
 
 
 class SoftTimeLimitExceeded(Exception):
@@ -23,6 +24,11 @@ class SoftTimeLimitExceeded(Exception):
     pass
 
 
+class WorkerLostError(Exception):
+    """The worker processing a task has exited prematurely."""
+    pass
+
+
 class ImproperlyConfigured(Exception):
     """Celery is somehow improperly configured."""
     pass