
Worker Pool: Correctly handle lost worker processes. See Python bug 9205

http://bugs.python.org/issue9205
Ask Solem 14 years ago
parent commit 69b56879e5
2 changed files with 64 additions and 13 deletions
  1. celery/concurrency/processes/pool.py  +57 -12
  2. celery/exceptions.py  +7 -1

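In short: before this change, the result handler's final drain loop blocked on get() from the outqueue, so a pool whose worker processes had all died could hang forever with jobs still in the cache (Python issue 9205). The patch polls with a timeout instead and, when nothing arrives and no live workers remain, fails the outstanding jobs with WorkerLostError. A rough, simplified sketch of that drain loop follows; the names and helper signatures here are illustrative, not the exact code in the diff below.

    # Simplified sketch of the new second-pass loop in ResultHandler.run().
    # WorkerLostError really lives in celery.exceptions (added by this commit);
    # it is redefined here only to keep the sketch self-contained.

    class WorkerLostError(Exception):
        """The worker processing a task has exited prematurely."""


    def finish_pending(poll, cache, workers_gone, timeout=0.2):
        """Drain remaining results, giving up once every worker is gone.

        poll(timeout) returns (True, task) or (False, None).
        cache maps job ids to result objects with a _set(i, value) method.
        workers_gone() reaps exited workers and returns True if none remain.
        """
        while cache:
            ready, task = poll(timeout)
            if ready:
                if task is None:        # extra sentinel, nothing to deliver
                    continue
                job, i, obj = task
                if job in cache:
                    cache[job]._set(i, obj)
            elif workers_gone():
                # No result and no live workers: the remaining jobs can never
                # complete, so mark them all as failed instead of blocking.
                err = WorkerLostError(
                        "The worker processing this job has terminated.")
                for job in cache.values():
                    job._set(None, (False, err))
                return  # in the real pool, _set() removes each job from cache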
celery/concurrency/processes/pool.py  +57 -12

@@ -25,6 +25,7 @@ from multiprocessing import Process, cpu_count, TimeoutError
 from multiprocessing.util import Finalize, debug

 from celery.exceptions import SoftTimeLimitExceeded, TimeLimitExceeded
+from celery.exceptions import WorkerLostError

 #
 # Constants representing the state of a pool
@@ -106,6 +107,13 @@ def worker(inqueue, outqueue, ackqueue, initializer=None, initargs=(),
             result = (True, func(*args, **kwds))
         except Exception, e:
             result = (False, e)
+        except BaseException, e:
+            # Job raised SystemExit or equivalent, so tell the result
+            # handler and exit the process so it can be replaced.
+            err = WorkerLostError(
+                    "Worker has terminated by user request: %r" % (e, ))
+            put((job, i, (False, err)))
+            raise
         try:
             put((job, i, result))
         except Exception, exc:
@@ -348,11 +356,13 @@ class TimeoutHandler(PoolThread):

 class ResultHandler(PoolThread):

-    def __init__(self, outqueue, get, cache, putlock):
+    def __init__(self, outqueue, get, cache, putlock, poll, workers_gone):
         self.outqueue = outqueue
         self.get = get
         self.cache = cache
         self.putlock = putlock
+        self.poll = poll
+        self.workers_gone = workers_gone
         super(ResultHandler, self).__init__()

     def run(self):
@@ -360,6 +370,8 @@ class ResultHandler(PoolThread):
         outqueue = self.outqueue
         cache = self.cache
         putlock = self.putlock
+        poll = self.poll
+        workers_gone = self.workers_gone

         debug('result handler starting')
         while 1:
@@ -399,20 +411,30 @@ class ResultHandler(PoolThread):

         while cache and self._state != TERMINATE:
             try:
-                task = get()
+                ready, task = poll(0.2)
             except (IOError, EOFError), exc:
                 debug('result handler got %s -- exiting',
                         exc.__class__.__name__)
                 return

-            if task is None:
-                debug('result handler ignoring extra sentinel')
-                continue
-            job, i, obj = task
-            try:
-                cache[job]._set(i, obj)
-            except KeyError:
-                pass
+            if ready:
+                if task is None:
+                    debug('result handler ignoring extra sentinel')
+                    continue
+
+                job, i, obj = task
+                try:
+                    cache[job]._set(i, obj)
+                except KeyError:
+                    pass
+            else:
+                if workers_gone():
+                    debug("%s active job(s), but no active workers! "
+                          "Terminating..." % (len(cache), ))
+                    err = WorkerLostError(
+                            "The worker processing this job has terminated.")
+                    for job in cache.values():
+                        job._set(None, (False, err))

         if hasattr(outqueue, '_reader'):
             debug('ensuring that outqueue is not full')
@@ -499,7 +521,8 @@ class Pool(object):
         # Thread processing results in the outqueue.
         self._result_handler = self.ResultHandler(self._outqueue,
                                        self._quick_get, self._cache,
-                                        self._putlock)
+                                        self._putlock, self._poll_result,
+                                        self._workers_gone)
         self._result_handler.start()

         self._terminate = Finalize(
@@ -525,6 +548,10 @@ class Pool(object):
         w.start()
         return w

+    def _workers_gone(self):
+        self._join_exited_workers()
+        return not self._pool
+
     def _join_exited_workers(self):
         """Cleanup after any worker processes which have exited due to
         reaching their specified lifetime. Returns True if any workers were
@@ -568,6 +595,13 @@ class Pool(object):
         self._quick_get = self._outqueue._reader.recv
         self._quick_get_ack = self._ackqueue._reader.recv

+        def _poll_result(timeout):
+            if self._outqueue._reader.poll(timeout):
+                return True, self._quick_get()
+            return False, None
+
+        self._poll_result = _poll_result
+
     def apply(self, func, args=(), kwds={}):
         '''
         Equivalent of `apply()` builtin
@@ -684,8 +718,11 @@ class Pool(object):
     def close(self):
         debug('closing pool')
         if self._state == RUN:
-            self._state = CLOSE
+            # Worker handler can't run while the result
+            # handler does its second pass, so wait for it to finish.
             self._worker_handler.close()
+            self._worker_handler.join()
+            self._state = CLOSE
             self._taskqueue.put(None)

     def terminate(self):
@@ -994,6 +1031,14 @@ class ThreadPool(Pool):
         self._quick_get = self._outqueue.get
         self._quick_get_ack = self._ackqueue.get

+        def _poll_result(timeout):
+            try:
+                return True, self._quick_get(timeout=timeout)
+            except Queue.Empty:
+                return False, None
+
+        self._poll_result = _poll_result
+
     @staticmethod
     def _help_stuff_finish(inqueue, task_handler, size):
         # put sentinels at head of inqueue to make workers finish

celery/exceptions.py  +7 -1

@@ -8,13 +8,14 @@ UNREGISTERED_FMT = """
 Task of kind %s is not registered, please make sure it's imported.
 """.strip()

-
 class QueueNotFound(KeyError):
     """Task routed to a queue not in CELERY_QUEUES."""
+    pass


 class TimeLimitExceeded(Exception):
     """The time limit has been exceeded and the job has been terminated."""
+    pass


 class SoftTimeLimitExceeded(Exception):
@@ -23,6 +24,11 @@ class SoftTimeLimitExceeded(Exception):
     pass


+class WorkerLostError(Exception):
+    """The worker processing a task has exited prematurely."""
+    pass
+
+
 class ImproperlyConfigured(Exception):
     """Celery is somehow improperly configured."""
     pass
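For callers, the visible effect of the worker-side change is that a job raising SystemExit (or a similar non-Exception) now surfaces as WorkerLostError on the result instead of silently killing the worker while the caller waits. A hypothetical usage sketch against the patched pool; it is not part of this commit, and the import paths are only assumed from the diff above.

    # Hypothetical usage sketch, not part of the commit.
    from celery.concurrency.processes.pool import Pool
    from celery.exceptions import WorkerLostError


    def bad_job():
        # Raising SystemExit inside a job now reports WorkerLostError to the
        # caller and lets the pool replace the exiting worker process.
        raise SystemExit("going away")


    if __name__ == "__main__":
        pool = Pool(processes=1)
        result = pool.apply_async(bad_job)
        try:
            result.get()
        except WorkerLostError, exc:
            print "job lost its worker: %r" % (exc, )
        pool.terminate()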