Przeglądaj źródła

Worker: Now calls errbacks for tasks even when result stored by parent process. Closes #2510

Ask Solem 9 lat temu
rodzic
commit
e6fb53488e

+ 13 - 13
celery/app/trace.py

@@ -141,15 +141,17 @@ class TraceInfo(object):
         self.state = state
         self.retval = retval
 
-    def handle_error_state(self, task, req, eager=False):
+    def handle_error_state(self, task, req,
+                           eager=False, call_errbacks=True):
         store_errors = not eager
         if task.ignore_result:
             store_errors = task.store_errors_even_if_ignored
-
         return {
             RETRY: self.handle_retry,
             FAILURE: self.handle_failure,
-        }[self.state](task, req, store_errors=store_errors)
+        }[self.state](task, req,
+                      store_errors=store_errors,
+                      call_errbacks=call_errbacks)
 
     def handle_reject(self, task, req, **kwargs):
         self._log_error(task, req, ExceptionInfo())
@@ -157,7 +159,7 @@ class TraceInfo(object):
     def handle_ignore(self, task, req, **kwargs):
         self._log_error(task, req, ExceptionInfo())
 
-    def handle_retry(self, task, req, store_errors=True):
+    def handle_retry(self, task, req, store_errors=True, **kwargs):
         """Handle retry exception."""
         # the exception raised is the Retry semi-predicate,
         # and it's exc' attribute is the original exception raised (if any).
@@ -180,7 +182,7 @@ class TraceInfo(object):
         finally:
             del(tb)
 
-    def handle_failure(self, task, req, store_errors=True):
+    def handle_failure(self, task, req, store_errors=True, call_errbacks=True):
         """Handle exception."""
         type_, _, tb = sys.exc_info()
         try:
@@ -189,7 +191,9 @@ class TraceInfo(object):
             einfo.exception = get_pickleable_exception(einfo.exception)
             einfo.type = get_pickleable_etype(einfo.type)
             task.backend.mark_as_failure(
-                req.id, exc, einfo.traceback, req, store_errors,
+                req.id, exc, einfo.traceback,
+                request=req, store_result=store_errors,
+                call_errbacks=call_errbacks,
             )
             task.on_failure(exc, req.id, req.args, req.kwargs, einfo)
             signals.task_failure.send(sender=task, task_id=req.id,
@@ -306,13 +310,9 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
         if propagate:
             raise
         I = Info(state, exc)
-        R = I.handle_error_state(task, request, eager=eager)
-        if call_errbacks:
-            root_id = request.root_id or uuid
-            group(
-                [signature(errback, app=app)
-                 for errback in request.errbacks or []], app=app,
-            ).apply_async((uuid,), parent_id=uuid, root_id=root_id)
+        R = I.handle_error_state(
+            task, request, eager=eager, call_errbacks=call_errbacks,
+        )
         return I, R, I.state, I.retval
 
     def trace_task(uuid, args, kwargs, request=None):

+ 12 - 4
celery/backends/base.py

@@ -26,7 +26,7 @@ from kombu.serialization import (
 from kombu.utils.encoding import bytes_to_str, ensure_bytes, from_utf8
 
 from celery import states
-from celery import current_app, maybe_signature
+from celery import current_app, group, maybe_signature
 from celery.app import current_task
 from celery.exceptions import ChordError, TimeoutError, TaskRevokedError
 from celery.five import items
@@ -121,14 +121,22 @@ class BaseBackend(object):
             self.on_chord_part_return(request, state, result)
 
     def mark_as_failure(self, task_id, exc,
-                        traceback=None, request=None, store_result=True,
+                        traceback=None, request=None,
+                        store_result=True, call_errbacks=True,
                         state=states.FAILURE):
         """Mark task as executed with failure. Stores the exception."""
         if store_result:
             self.store_result(task_id, exc, state,
                               traceback=traceback, request=request)
-        if request and request.chord:
-            self.on_chord_part_return(request, state, exc)
+        if request:
+            if request.chord:
+                self.on_chord_part_return(request, state, exc)
+            if call_errbacks:
+                root_id = request.root_id or task_id
+                group(
+                    [self.app.signature(errback)
+                     for errback in request.errbacks or []], app=self.app,
+                ).apply_async((task_id,), parent_id=task_id, root_id=root_id)
 
     def mark_as_revoked(self, task_id, reason='',
                         request=None, store_result=True, state=states.REVOKED):

+ 2 - 0
celery/tests/backends/test_base.py

@@ -270,6 +270,7 @@ class test_BaseBackend_dict(AppCase):
         b = BaseBackend(app=self.app)
         b._store_result = Mock()
         request = Mock(name='request')
+        request.errbacks = []
         b.on_chord_part_return = Mock()
         exc = KeyError()
         b.mark_as_failure('id', exc, request=request)
@@ -279,6 +280,7 @@ class test_BaseBackend_dict(AppCase):
         b = BaseBackend(app=self.app)
         b._store_result = Mock()
         request = Mock(name='request')
+        request.errbacks = []
         b.on_chord_part_return = Mock()
         b.mark_as_revoked('id', 'revoked', request=request)
         b.on_chord_part_return.assert_called_with(request, states.REVOKED, ANY)

+ 1 - 0
celery/tests/tasks/test_trace.py

@@ -319,6 +319,7 @@ class test_TraceInfo(TraceCase):
         x.handle_failure.assert_called_with(
             self.add_cast, self.add_cast.request,
             store_errors=self.add_cast.store_errors_even_if_ignored,
+            call_errbacks=True,
         )
 
     @patch('celery.app.trace.ExceptionInfo')

+ 8 - 1
celery/worker/request.py

@@ -466,11 +466,18 @@ class Request(object):
 
     @cached_property
     def chord(self):
-        # used by backend.on_chord_part_return when failures reported
+        # used by backend.mark_as_failure when failure is reported
         # by parent process
         _, _, embed = self._payload
         return embed.get('chord')
 
+    @cached_property
+    def errbacks(self):
+        # used by backend.mark_as_failure when failure is reported
+        # by parent process
+        _, _, embed = self._payload
+        return embed.get('errbacks')
+
     @cached_property
     def group(self):
         # used by backend.on_chord_part_return when failures reported