Selaa lähdekoodia

multiprocessing.pool: Properly handle encoding errors, so pickling errors don't crash the worker processes.

Ask Solem 14 vuotta sitten
vanhempi
commit
eaa4d5ddc0
2 muutettua tiedostoa jossa 36 lisäystä ja 5 poistoa
  1. 7 1
      celery/concurrency/processes/__init__.py
  2. 29 4
      celery/concurrency/processes/pool.py

+ 7 - 1
celery/concurrency/processes/__init__.py

@@ -8,7 +8,7 @@ from celery import log
 from celery.datastructures import ExceptionInfo
 from celery.utils.functional import curry
 
-from celery.concurrency.processes.pool import Pool, RUN
+from celery.concurrency.processes.pool import Pool, RUN, MaybeEncodingError
 
 
 class TaskPool(object):
@@ -81,6 +81,7 @@ class TaskPool(object):
         errbacks = errbacks or []
 
         on_ready = curry(self.on_ready, callbacks, errbacks)
+        on_worker_error = curry(self.on_worker_error, errbacks)
 
         self.logger.debug("TaskPool: Apply %s (args:%s kwargs:%s)" % (
             target, args, kwargs))
@@ -89,8 +90,13 @@ class TaskPool(object):
                                       callback=on_ready,
                                       accept_callback=accept_callback,
                                       timeout_callback=timeout_callback,
+                                      error_callback=on_worker_error,
                                       waitforslot=self.putlocks)
 
+    def on_worker_error(self, errbacks, exc):
+        einfo = ExceptionInfo((exc.__class__, exc, None))
+        [errback(einfo) for errback in errbacks]
+
     def on_ready(self, callbacks, errbacks, ret_value):
         """What to do when a worker task is ready and its return value has
         been collected."""

+ 29 - 4
celery/concurrency/processes/pool.py

@@ -55,6 +55,21 @@ def soft_timeout_sighandler(signum, frame):
     raise SoftTimeLimitExceeded()
 
 
+class MaybeEncodingError(Exception):
+    """Wraps unpickleable object."""
+
+    def __init__(self, exc, value):
+        self.exc = str(exc)
+        self.value = repr(value)
+        super(MaybeEncodingError, self).__init__(self.exc, self.value)
+
+    def __str__(self):
+        return "Error sending result: '%s'. Reason: '%s'." % (self.value,
+                                                              self.exc)
+    def __repr__(self):
+        return "<MaybeEncodingError: %s>" % str(self)
+
+
 def worker(inqueue, outqueue, ackqueue, initializer=None, initargs=(),
         maxtasks=None):
     assert maxtasks is None or (type(maxtasks) == int and maxtasks > 0)
@@ -90,7 +105,13 @@ def worker(inqueue, outqueue, ackqueue, initializer=None, initargs=(),
             result = (True, func(*args, **kwds))
         except Exception, e:
             result = (False, e)
-        put((job, i, result))
+        try:
+            put((job, i, result))
+        except Exception, exc:
+            wrapped = MaybeEncodingError(exc, result[1])
+            debug('Got possible encoding error while sending result: %s' % wrapped)
+            put((job, i, (False, wrapped)))
+
         completed += 1
     debug('worker exiting after %d tasks' % completed)
 
@@ -589,7 +610,7 @@ class Pool(object):
 
     def apply_async(self, func, args=(), kwds={},
             callback=None, accept_callback=None, timeout_callback=None,
-            waitforslot=False):
+            waitforslot=False, error_callback=None):
         '''
         Asynchronous equivalent of `apply()` builtin.
 
@@ -607,7 +628,8 @@ class Pool(object):
         '''
         assert self._state == RUN
         result = ApplyResult(self._cache, callback,
-                             accept_callback, timeout_callback)
+                             accept_callback, timeout_callback,
+                             error_callback)
         if waitforslot:
             self._putlock.acquire()
         self._taskqueue.put(([(result._job, None, func, args, kwds)], None))
@@ -742,7 +764,7 @@ DynamicPool = Pool
 class ApplyResult(object):
 
     def __init__(self, cache, callback, accept_callback=None,
-            timeout_callback=None):
+            timeout_callback=None, error_callback=None):
         self._cond = threading.Condition(threading.Lock())
         self._job = job_counter.next()
         self._cache = cache
@@ -751,6 +773,7 @@ class ApplyResult(object):
         self._time_accepted = None
         self._ready = False
         self._callback = callback
+        self._errback = error_callback
         self._accept_callback = accept_callback
         self._timeout_callback = timeout_callback
         cache[self._job] = self
@@ -786,6 +809,8 @@ class ApplyResult(object):
         self._success, self._value = obj
         if self._callback and self._success:
             self._callback(self._value)
+        if self._errback and not self._success:
+            self._errback(self._value)
         self._cond.acquire()
         try:
             self._ready = True