Ver Fonte

Report chord errors when task process terminated

Ask Solem há 9 anos atrás
pai
commit
cd48cd34ae

+ 0 - 5
celery/app/trace.py

@@ -291,7 +291,6 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
     pop_request = request_stack.pop
     push_task = _task_stack.push
     pop_task = _task_stack.pop
-    on_chord_part_return = backend.on_chord_part_return
     _does_info = logger.isEnabledFor(logging.INFO)
 
     prerun_receivers = signals.task_prerun.receivers
@@ -368,8 +367,6 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
                     )
                 except Exception as exc:
                     I, R, state, retval = on_error(task_request, exc, uuid)
-                    if task_request.chord:
-                        on_chord_part_return(task, state, exc)
                 except BaseException as exc:
                     raise
                 else:
@@ -404,8 +401,6 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
                     except EncodeError as exc:
                         I, R, state, retval = on_error(task_request, exc, uuid)
                     else:
-                        if task_request.chord:
-                            on_chord_part_return(task, state, retval)
                         if task_on_success:
                             task_on_success(retval, uuid, args, kwargs)
                         if success_receivers:

+ 17 - 13
celery/backends/base.py

@@ -112,15 +112,19 @@ class BaseBackend(object):
         """Mark a task as started"""
         return self.store_result(task_id, meta, status=states.STARTED)
 
-    def mark_as_done(self, task_id, result, request=None):
+    def mark_as_done(self, task_id, result, request=None, state=states.SUCCESS):
         """Mark task as successfully executed."""
-        return self.store_result(task_id, result,
-                                 status=states.SUCCESS, request=request)
+        self.store_result(task_id, result, status=state, request=request)
+        if request and request.chord:
+            self.on_chord_part_return(request, state, result)
 
-    def mark_as_failure(self, task_id, exc, traceback=None, request=None):
+    def mark_as_failure(self, task_id, exc,
+                        traceback=None, request=None, state=states.FAILURE):
         """Mark task as executed with failure. Stores the exception."""
-        return self.store_result(task_id, exc, status=states.FAILURE,
-                                 traceback=traceback, request=request)
+        self.store_result(task_id, exc, status=state,
+                          traceback=traceback, request=request)
+        if request and request.chord:
+            self.on_chord_part_return(request, state, exc)
 
     def chord_error_from_stack(self, callback, exc=None):
         from celery import group
@@ -346,7 +350,7 @@ class BaseBackend(object):
     def add_to_chord(self, chord_id, result):
         raise NotImplementedError('Backend does not support add_to_chord')
 
-    def on_chord_part_return(self, task, state, result, propagate=False):
+    def on_chord_part_return(self, request, state, result, propagate=False):
         pass
 
     def fallback_chord_unlock(self, group_id, body, result=None,
@@ -540,20 +544,20 @@ class KeyValueStoreBackend(BaseBackend):
 
         return header(*partial_args, task_id=group_id, **fixed_options or {})
 
-    def on_chord_part_return(self, task, state, result, propagate=None):
+    def on_chord_part_return(self, request, state, result, propagate=None):
         if not self.implements_incr:
             return
         app = self.app
         if propagate is None:
             propagate = app.conf.CELERY_CHORD_PROPAGATES
-        gid = task.request.group
+        gid = request.group
         if not gid:
             return
         key = self.get_key_for_chord(gid)
         try:
-            deps = GroupResult.restore(gid, backend=task.backend)
+            deps = GroupResult.restore(gid, backend=self)
         except Exception as exc:
-            callback = maybe_signature(task.request.chord, app=app)
+            callback = maybe_signature(request.chord, app=app)
             logger.error('Chord %r raised: %r', gid, exc, exc_info=1)
             return self.chord_error_from_stack(
                 callback,
@@ -563,7 +567,7 @@ class KeyValueStoreBackend(BaseBackend):
             try:
                 raise ValueError(gid)
             except ValueError as exc:
-                callback = maybe_signature(task.request.chord, app=app)
+                callback = maybe_signature(request.chord, app=app)
                 logger.error('Chord callback %r raised: %r', gid, exc,
                              exc_info=1)
                 return self.chord_error_from_stack(
@@ -576,7 +580,7 @@ class KeyValueStoreBackend(BaseBackend):
             logger.warning('Chord counter incremented too many times for %r',
                            gid)
         elif val == size:
-            callback = maybe_signature(task.request.chord, app=app)
+            callback = maybe_signature(request.chord, app=app)
             j = deps.join_native if deps.supports_native_join else deps.join
             try:
                 with allow_join_result():

+ 1 - 2
celery/backends/redis.py

@@ -196,9 +196,8 @@ class RedisBackend(KeyValueStoreBackend):
         options['task_id'] = group_id
         return header(*partial_args, **options or {})
 
-    def _new_chord_return(self, task, state, result, propagate=None):
+    def _new_chord_return(self, request, state, result, propagate=None):
         app = self.app
-        request = task.request
         tid, gid = request.id, request.group
         if not gid or not tid:
             return

+ 11 - 5
celery/tests/backends/test_base.py

@@ -298,7 +298,9 @@ class test_KeyValueStoreBackend(AppCase):
         self.b.get_key_for_chord.side_effect = AssertionError(
             'should not get here',
         )
-        self.assertIsNone(self.b.on_chord_part_return(task, state, result))
+        self.assertIsNone(
+            self.b.on_chord_part_return(task.request, state, result),
+        )
 
     @contextmanager
     def _chord_part_context(self, b):
@@ -326,14 +328,18 @@ class test_KeyValueStoreBackend(AppCase):
 
     def test_chord_part_return_propagate_set(self):
         with self._chord_part_context(self.b) as (task, deps, _):
-            self.b.on_chord_part_return(task, 'SUCCESS', 10, propagate=True)
+            self.b.on_chord_part_return(
+                task.request, 'SUCCESS', 10, propagate=True,
+            )
             self.assertFalse(self.b.expire.called)
             deps.delete.assert_called_with()
             deps.join_native.assert_called_with(propagate=True, timeout=3.0)
 
     def test_chord_part_return_propagate_default(self):
         with self._chord_part_context(self.b) as (task, deps, _):
-            self.b.on_chord_part_return(task, 'SUCCESS', 10, propagate=None)
+            self.b.on_chord_part_return(
+                task.request, 'SUCCESS', 10, propagate=None,
+            )
             self.assertFalse(self.b.expire.called)
             deps.delete.assert_called_with()
             deps.join_native.assert_called_with(
@@ -345,7 +351,7 @@ class test_KeyValueStoreBackend(AppCase):
         with self._chord_part_context(self.b) as (task, deps, callback):
             deps._failed_join_report = lambda: iter([])
             deps.join_native.side_effect = KeyError('foo')
-            self.b.on_chord_part_return(task, 'SUCCESS', 10)
+            self.b.on_chord_part_return(task.request, 'SUCCESS', 10)
             self.assertTrue(self.b.fail_from_current_stack.called)
             args = self.b.fail_from_current_stack.call_args
             exc = args[1]['exc']
@@ -359,7 +365,7 @@ class test_KeyValueStoreBackend(AppCase):
                 self.app.AsyncResult('culprit'),
             ])
             deps.join_native.side_effect = KeyError('foo')
-            b.on_chord_part_return(task, 'SUCCESS', 10)
+            b.on_chord_part_return(task.request, 'SUCCESS', 10)
             self.assertTrue(b.fail_from_current_stack.called)
             args = b.fail_from_current_stack.call_args
             exc = args[1]['exc']

+ 2 - 2
celery/tests/backends/test_cache.py

@@ -87,10 +87,10 @@ class test_CacheBackend(AppCase):
         tb.apply_chord(group(app=self.app), (), gid, {}, result=res)
 
         self.assertFalse(deps.join_native.called)
-        tb.on_chord_part_return(task, 'SUCCESS', 10)
+        tb.on_chord_part_return(task.request, 'SUCCESS', 10)
         self.assertFalse(deps.join_native.called)
 
-        tb.on_chord_part_return(task, 'SUCCESS', 10)
+        tb.on_chord_part_return(task.request, 'SUCCESS', 10)
         deps.join_native.assert_called_with(propagate=True, timeout=3.0)
         deps.delete.assert_called_with()
 

+ 1 - 1
celery/tests/backends/test_redis.py

@@ -259,7 +259,7 @@ class test_RedisBackend(AppCase):
         tasks = [create_task() for i in range(10)]
 
         for i in range(10):
-            b.on_chord_part_return(tasks[i], states.SUCCESS, i)
+            b.on_chord_part_return(tasks[i].request, states.SUCCESS, i)
             self.assertTrue(b.client.rpush.call_count)
             b.client.rpush.reset_mock()
         self.assertTrue(b.client.lrange.call_count)

+ 13 - 2
celery/tests/tasks/test_trace.py

@@ -103,8 +103,19 @@ class test_trace(TraceCase):
             return x + y
         add.backend = Mock()
 
-        self.trace(add, (2, 2), {}, request={'chord': uuid()})
-        add.backend.on_chord_part_return.assert_called_with(add, 'SUCCESS', 4)
+        class TestRequest(object):
+
+            def __init__(self, request):
+                self.request = request
+
+            def __eq__(self, other):
+                return self.request['chord'] == other['chord']
+
+        request = {'chord': uuid()}
+        self.trace(add, (2, 2), {}, request=request)
+        add.backend.on_chord_part_return.assert_called_with(
+            TestRequest(request), 'SUCCESS', 4,
+        )
 
     def test_when_backend_cleanup_raises(self):
 

+ 19 - 4
celery/worker/request.py

@@ -211,7 +211,7 @@ class Request(object):
             self.acknowledge()
 
         request = self.request_dict
-        args, kwargs, embed = self.message.payload
+        args, kwargs, embed = self._payload
         request.update({'loglevel': loglevel, 'logfile': logfile,
                         'hostname': self.hostname, 'is_eager': False,
                         'args': args, 'kwargs': kwargs}, **embed or {})
@@ -348,9 +348,7 @@ class Request(object):
                     'terminated', True, string(exc), False)
                 send_failed_event = False  # already sent revoked event
             elif isinstance(exc, WorkerLostError) or not return_ok:
-                self.task.backend.mark_as_failure(
-                    self.id, exc, request=self,
-                )
+                self.task.backend.mark_as_failure(self.id, exc, request=self)
         # (acks_late) acknowledge after result stored.
         if self.task.acks_late:
             reject_and_requeue = (
@@ -453,6 +451,23 @@ class Request(object):
         # used similarly to reply_to
         return self.request_dict['correlation_id']
 
+    @cached_property
+    def _payload(self):
+        return self.message.payload
+
+    @cached_property
+    def chord(self):
+        # Read by the result backend (mark_as_failure) when the parent process
+        # reports the failure; None (not KeyError) if the task has no chord.
+        _, _, embed = self._payload
+        return (embed or {}).get('chord')
+
+    @cached_property
+    def group(self):
+        # Read by backend.on_chord_part_return when the parent process
+        # reports the failure; None (not KeyError) when no group id is set.
+        return self.request_dict.get('group')
+
 
 def create_request_cls(base, task, pool, hostname, eventer,
                        ref=ref, revoked_tasks=revoked_tasks,