Browse Source

Chord unlock must be called after header. Closes #1700

Ask Solem 11 years ago
parent
commit
a31557793a

+ 5 - 13
celery/app/builtins.py

@@ -278,8 +278,6 @@ def add_chain_task(app):
                     tasks.append(task)
                 prev_task, prev_res = task, res
 
-            print(tasks)
-
             return tasks, results
 
         def apply_async(self, args=(), kwargs={}, group_id=None, chord=None,
@@ -357,17 +355,11 @@ def add_chord_task(app):
             results = [AsyncResult(prepare_member(task, body, group_id))
                        for task in header.tasks]
 
-            # - fallback implementations schedules the chord_unlock task here
-            app.backend.on_chord_apply(group_id, body,
-                                       interval=interval,
-                                       countdown=countdown,
-                                       max_retries=max_retries,
-                                       propagate=propagate,
-                                       result=results)
-            # - call the header group, returning the GroupResult.
-            final_res = header(*partial_args, task_id=group_id)
-
-            return final_res
+            return self.backend.apply_chord(
+                header, partial_args, group_id,
+                body, interval=interval, countdown=countdown,
+                max_retries=max_retries, propagate=propagate, result=results,
+            )
 
         def _prepare_member(self, task, body, group_id):
             opts = task.options

+ 11 - 6
celery/backends/base.py

@@ -311,7 +311,11 @@ class BaseBackend(object):
         self.app.tasks['celery.chord_unlock'].apply_async(
             (group_id, body, ), kwargs, countdown=countdown,
         )
-    on_chord_apply = fallback_chord_unlock
+
+    def apply_chord(self, header, partial_args, group_id, body, **options):
+        result = header(*partial_args, task_id=group_id)
+        self.fallback_chord_unlock(group_id, body, **options)
+        return result
 
     def current_task_children(self, request=None):
         request = request or getattr(current_task(), 'request', None)
@@ -335,6 +339,8 @@ class KeyValueStoreBackend(BaseBackend):
             self.key_t = self.key_t.__func__  # remove binding
         self._encode_prefixes()
         super(KeyValueStoreBackend, self).__init__(*args, **kwargs)
+        if self.implements_incr:
+            self.apply_chord = self._apply_chord_incr
 
     def _encode_prefixes(self):
         self.task_keyprefix = self.key_t(self.task_keyprefix)
@@ -459,11 +465,10 @@ class KeyValueStoreBackend(BaseBackend):
             meta['result'] = result_from_tuple(result, self.app)
             return meta
 
-    def on_chord_apply(self, group_id, body, result=None, **kwargs):
-        if self.implements_incr:
-            self.save_group(group_id, self.app.GroupResult(group_id, result))
-        else:
-            self.fallback_chord_unlock(group_id, body, result, **kwargs)
+    def _apply_chord_incr(self, header, partial_args, group_id, body,
+                          result=None, **options):
+        self.save_group(group_id, self.app.GroupResult(group_id, result))
+        return header(*partial_args, task_id=group_id)
 
     def on_chord_part_return(self, task, propagate=None):
         if not self.implements_incr:

+ 4 - 2
celery/backends/cache.py

@@ -128,9 +128,11 @@ class CacheBackend(KeyValueStoreBackend):
     def delete(self, key):
         return self.client.delete(key)
 
-    def on_chord_apply(self, group_id, body, result=None, **kwargs):
+    def _apply_chord_incr(self, header, partial_args, group_id, body, **opts):
         self.client.set(self.get_key_for_chord(group_id), '0', time=86400)
-        self.save_group(group_id, self.app.GroupResult(group_id, result))
+        return super(CacheBackend, self)._apply_chord_incr(
+            header, partial_args, group_id, body, **opts
+        )
 
     def incr(self, key):
         return self.client.incr(key)

+ 9 - 5
celery/tests/backends/test_base.py

@@ -14,6 +14,7 @@ from celery.utils.serialization import UnpickleableExceptionWrapper
 from celery.utils.serialization import get_pickleable_exception as gpe
 
 from celery import states
+from celery import group
 from celery.backends.base import (
     BaseBackend,
     KeyValueStoreBackend,
@@ -62,10 +63,10 @@ class test_BaseBackend_interface(AppCase):
     def test_on_chord_part_return(self):
         self.b.on_chord_part_return(None)
 
-    def test_on_chord_apply(self, unlock='celery.chord_unlock'):
+    def test_apply_chord(self, unlock='celery.chord_unlock'):
         self.app.tasks[unlock] = Mock()
-        self.b.on_chord_apply(
-            'dakj221', 'sdokqweok',
+        self.b.apply_chord(group(app=self.app), (),
+            'dakj221', None,
             result=[self.app.AsyncResult(x) for x in [1, 2, 3]],
         )
         self.assertTrue(self.app.tasks[unlock].apply_async.call_count)
@@ -364,9 +365,12 @@ class test_KeyValueStoreBackend(AppCase):
     def test_chord_apply_fallback(self):
         self.b.implements_incr = False
         self.b.fallback_chord_unlock = Mock()
-        self.b.on_chord_apply('group_id', 'body', 'result', foo=1)
+        self.b.apply_chord(
+            group(app=self.app), (), 'group_id', 'body',
+            result='result', foo=1,
+        )
         self.b.fallback_chord_unlock.assert_called_with(
-            'group_id', 'body', 'result', foo=1,
+            'group_id', 'body', result='result', foo=1,
         )
 
     def test_get_missing_meta(self):

+ 4 - 3
celery/tests/backends/test_cache.py

@@ -9,6 +9,7 @@ from kombu.utils.encoding import str_to_bytes
 
 from celery import signature
 from celery import states
+from celery import group
 from celery.backends.cache import CacheBackend, DummyClient
 from celery.exceptions import ImproperlyConfigured
 from celery.five import items, string, text_t
@@ -62,10 +63,10 @@ class test_CacheBackend(AppCase):
             self.assertEqual(self.tb.get_status(self.tid), states.FAILURE)
             self.assertIsInstance(self.tb.get_result(self.tid), KeyError)
 
-    def test_on_chord_apply(self):
+    def test_apply_chord(self):
         tb = CacheBackend(backend='memory://', app=self.app)
         gid, res = uuid(), [self.app.AsyncResult(uuid()) for _ in range(3)]
-        tb.on_chord_apply(gid, {}, result=res)
+        tb.apply_chord(group(app=self.app), (), gid, {}, result=res)
 
     @patch('celery.result.GroupResult.restore')
     def test_on_chord_part_return(self, restore):
@@ -82,7 +83,7 @@ class test_CacheBackend(AppCase):
 
         gid, res = uuid(), [self.app.AsyncResult(uuid()) for _ in range(3)]
         task.request.group = gid
-        tb.on_chord_apply(gid, {}, result=res)
+        tb.apply_chord(group(app=self.app), (), gid, {}, result=res)
 
         self.assertFalse(deps.join_native.called)
         tb.on_chord_part_return(task)

+ 4 - 3
celery/tests/backends/test_redis.py

@@ -8,6 +8,7 @@ from kombu.utils import cached_property, uuid
 
 from celery import signature
 from celery import states
+from celery import group
 from celery.datastructures import AttributeDict
 from celery.exceptions import ImproperlyConfigured
 from celery.utils.timeutils import timedelta_seconds
@@ -130,9 +131,9 @@ class test_RedisBackend(AppCase):
         b = self.Backend(expires=timedelta(minutes=1), app=self.app)
         self.assertEqual(b.expires, 60)
 
-    def test_on_chord_apply(self):
-        self.Backend(app=self.app).on_chord_apply(
-            'group_id', {},
+    def test_apply_chord(self):
+        self.Backend(app=self.app).apply_chord(
+            group(app=self.app), (), 'group_id', {},
             result=[self.app.AsyncResult(x) for x in [1, 2, 3]],
         )
 

+ 1 - 1
celery/tests/tasks/test_chord.py

@@ -224,4 +224,4 @@ class test_Chord_task(ChordCase):
         body = dict()
         Chord(group(self.add.subtask((i, i)) for i in range(5)), body)
         Chord([self.add.subtask((j, j)) for j in range(5)], body)
-        self.assertEqual(self.app.backend.on_chord_apply.call_count, 2)
+        self.assertEqual(self.app.backend.apply_chord.call_count, 2)