Переглянути джерело

Allow Extraction of Chord Results On Error (#4888)

* Keep group ID in task results

* Don't delete group results on error

* Tolerant group persistance in result storage

Not everything that gets passed here has a group attribute, and even
Request objects sometimes don't have the necessary data in their dict

* Test using stored group ID to recover chord result

* Accept all args to chord error callback

* isort-check fix for chord error handling test

* Fix test_chord_on_error fail in full integration

propagate=False stops working?

* Require redis for chord error handling test

* Explain test structure more

* Test storage of group_id in result meta
Nicholas Pilon 6 роки тому
батько
коміт
97fd3acac6

+ 2 - 0
celery/backends/base.py

@@ -661,6 +661,8 @@ class BaseKeyValueStoreBackend(Backend):
             'children': self.current_task_children(request),
             'task_id': bytes_to_str(task_id),
         }
+        if request and getattr(request, 'group', None):
+            meta['group_id'] = request.group
         self.set(self.get_key_for_task(task_id), self.encode(meta))
         return result
 

+ 6 - 3
celery/backends/redis.py

@@ -360,13 +360,16 @@ class RedisBackend(BaseKeyValueStoreBackend, AsyncBackendMixin):
             if readycount == total:
                 decode, unpack = self.decode, self._unpack_chord_result
                 with client.pipeline() as pipe:
-                    resl, _, _ = pipe \
+                    resl, = pipe \
                         .lrange(jkey, 0, total) \
-                        .delete(jkey) \
-                        .delete(tkey) \
                         .execute()
                 try:
                     callback.delay([unpack(tup, decode) for tup in resl])
+                    with client.pipeline() as pipe:
+                        _, _ = pipe \
+                            .delete(jkey) \
+                            .delete(tkey) \
+                            .execute()
                 except Exception as exc:  # pylint: disable=broad-except
                     logger.exception(
                         'Chord callback for %r raised: %r', request.group, exc)

+ 1 - 1
celery/worker/request.py

@@ -498,7 +498,7 @@ class Request(object):
     def group(self):
         # used by backend.on_chord_part_return when failures reported
         # by parent process
-        return self.request_dict['group']
+        return self.request_dict.get('group')
 
 
 def create_request_cls(base, task, pool, hostname, eventer,

+ 14 - 0
t/integration/tasks.py

@@ -178,3 +178,17 @@ def build_chain_inside_task(self):
     )
     result = test_chain()
     return result
+
+
+class ExpectedException(Exception):
+    pass
+
+
+@shared_task
+def fail(*args):
+    raise ExpectedException('Task expected to fail')
+
+
+@shared_task
+def chord_error(*args):
+    return args

+ 62 - 3
t/integration/test_canvas.py

@@ -10,9 +10,10 @@ from celery.result import AsyncResult, GroupResult, ResultSet
 
 from .conftest import flaky, get_active_redis_channels, get_redis_connection
 from .tasks import (add, add_chord_to_chord, add_replaced, add_to_all,
-                    add_to_all_to_chord, build_chain_inside_task, collect_ids,
-                    delayed_sum, delayed_sum_with_soft_guard, identity, ids,
-                    print_unicode, redis_echo, second_order_replace1, tsum)
+                    add_to_all_to_chord, build_chain_inside_task, chord_error,
+                    collect_ids, delayed_sum, delayed_sum_with_soft_guard,
+                    fail, identity, ids, print_unicode, redis_echo,
+                    second_order_replace1, tsum)
 
 TIMEOUT = 120
 
@@ -521,3 +522,61 @@ class test_chord:
         assert value == 1
         assert root_id == expected_root_id
         assert parent_id is None
+
+    def test_chord_on_error(self, manager):
+        from celery import states
+        from .tasks import ExpectedException
+        import time
+
+        if not manager.app.conf.result_backend.startswith('redis'):
+            raise pytest.skip('Requires redis result backend.')
+
+        # Run the chord and wait for the error callback to finish.
+        c1 = chord(
+            header=[add.s(1, 2), add.s(3, 4), fail.s()],
+            body=print_unicode.s('This should not be called').on_error(
+                chord_error.s()),
+        )
+        res = c1()
+        try:
+            res.wait(propagate=False)
+        except ExpectedException:
+            pass
+        # Got to wait for children to populate.
+        while not res.children:
+            time.sleep(0.1)
+        try:
+            res.children[0].children[0].wait(propagate=False)
+        except ExpectedException:
+            pass
+
+        # Extract the results of the successful tasks from the chord.
+        #
+        # We could do this inside the error handler, and probably would in a
+        #  real system, but for the purposes of the test it's obnoxious to get
+        #  data out of the error handler.
+        #
+        # So for clarity of our test, we instead do it here.
+
+        # Use the error callback's result to find the failed task.
+        error_callback_result = AsyncResult(
+            res.children[0].children[0].result[0])
+        failed_task_id = error_callback_result.result.args[0].split()[3]
+
+        # Use new group_id result metadata to get group ID.
+        failed_task_result = AsyncResult(failed_task_id)
+        original_group_id = failed_task_result._get_task_meta()['group_id']
+
+        # Use group ID to get preserved group result.
+        backend = fail.app.backend
+        j_key = backend.get_key_for_group(original_group_id, '.j')
+        redis_connection = get_redis_connection()
+        chord_results = [backend.decode(t) for t in
+                         redis_connection.lrange(j_key, 0, 3)]
+
+        # Validate group result
+        assert [cr[3] for cr in chord_results if cr[2] == states.SUCCESS] == \
+               [3, 7]
+
+        assert len([cr for cr in chord_results if cr[2] != states.SUCCESS]
+                   ) == 1

+ 13 - 0
t/unit/backends/test_base.py

@@ -422,6 +422,19 @@ class test_KeyValueStoreBackend:
         self.b.forget(tid)
         assert self.b.get_state(tid) == states.PENDING
 
+    def test_store_result_group_id(self):
+        tid = uuid()
+        state = 'SUCCESS'
+        result = 10
+        request = Mock()
+        request.group = 'gid'
+        request.children = []
+        self.b.store_result(
+            tid, state=state, result=result, request=request,
+        )
+        stored_meta = self.b.decode(self.b.get(self.b.get_key_for_task(tid)))
+        assert stored_meta['group_id'] == request.group
+
     def test_strip_prefix(self):
         x = self.b.get_key_for_task('x1b34')
         assert self.b._strip_prefix(x) == 'x1b34'