Browse Source

Support chords with empty headers (#4443)

Alex Hill 7 years ago
parent
commit
25f5e29610

+ 8 - 17
celery/backends/base.py

@@ -412,23 +412,19 @@ class Backend(object):
     def on_chord_part_return(self, request, state, result, **kwargs):
         pass
 
-    def fallback_chord_unlock(self, group_id, body, result=None,
-                              countdown=1, **kwargs):
-        kwargs['result'] = [r.as_tuple() for r in result]
+    def fallback_chord_unlock(self, header_result, body, countdown=1,
+                              **kwargs):
+        kwargs['result'] = [r.as_tuple() for r in header_result]
         self.app.tasks['celery.chord_unlock'].apply_async(
-            (group_id, body,), kwargs, countdown=countdown,
+            (header_result.id, body,), kwargs, countdown=countdown,
         )
 
     def ensure_chords_allowed(self):
         pass
 
-    def apply_chord(self, header, partial_args, group_id, body,
-                    options={}, **kwargs):
+    def apply_chord(self, header_result, body, **kwargs):
         self.ensure_chords_allowed()
-        fixed_options = {k: v for k, v in items(options) if k != 'task_id'}
-        result = header(*partial_args, task_id=group_id, **fixed_options or {})
-        self.fallback_chord_unlock(group_id, body, **kwargs)
-        return result
+        self.fallback_chord_unlock(header_result, body, **kwargs)
 
     def current_task_children(self, request=None):
         request = request or getattr(get_current_task(), 'request', None)
@@ -683,14 +679,9 @@ class BaseKeyValueStoreBackend(Backend):
             meta['result'] = result_from_tuple(result, self.app)
             return meta
 
-    def _apply_chord_incr(self, header, partial_args, group_id, body,
-                          result=None, options={}, **kwargs):
+    def _apply_chord_incr(self, header_result, body, **kwargs):
         self.ensure_chords_allowed()
-        self.save_group(group_id, self.app.GroupResult(group_id, result))
-
-        fixed_options = {k: v for k, v in items(options) if k != 'task_id'}
-
-        return header(*partial_args, task_id=group_id, **fixed_options or {})
+        header_result.save(backend=self)
 
     def on_chord_part_return(self, request, state, result, **kwargs):
         if not self.implements_incr:

+ 4 - 3
celery/backends/cache.py

@@ -132,10 +132,11 @@ class CacheBackend(KeyValueStoreBackend):
     def delete(self, key):
         return self.client.delete(key)
 
-    def _apply_chord_incr(self, header, partial_args, group_id, body, **opts):
-        self.client.set(self.get_key_for_chord(group_id), 0, time=self.expires)
+    def _apply_chord_incr(self, header_result, body, **kwargs):
+        chord_key = self.get_key_for_chord(header_result.id)
+        self.client.set(chord_key, 0, time=self.expires)
         return super(CacheBackend, self)._apply_chord_incr(
-            header, partial_args, group_id, body, **opts)
+            header_result, body, **kwargs)
 
     def incr(self, key):
         return self.client.incr(key)

+ 2 - 4
celery/backends/redis.py

@@ -251,15 +251,13 @@ class RedisBackend(base.BaseKeyValueStoreBackend, async.AsyncBackendMixin):
             raise ChordError('Dependency {0} raised {1!r}'.format(tid, retval))
         return retval
 
-    def apply_chord(self, header, partial_args, group_id, body,
-                    result=None, options={}, **kwargs):
+    def apply_chord(self, header_result, body, **kwargs):
         # Overrides this to avoid calling GroupResult.save
         # pylint: disable=method-hidden
         # Note that KeyValueStoreBackend.__init__ sets self.apply_chord
         # if the implements_incr attr is set.  Redis backend doesn't set
         # this flag.
-        options['task_id'] = group_id
-        return header(*partial_args, **options or {})
+        pass
 
     def on_chord_part_return(self, request, state, result,
                              propagate=None, **kwargs):

+ 19 - 8
celery/canvas.py

@@ -1268,21 +1268,32 @@ class chord(Signature):
             options.pop('task_id', None)
             body.options.update(options)
 
-        results = header.freeze(
-            group_id=group_id, chord=body, root_id=root_id).results
         bodyres = body.freeze(task_id, root_id=root_id)
 
         # Chains should not be passed to the header tasks. See #3771
         options.pop('chain', None)
         # Neither should chords, for deeply nested chords to work
         options.pop('chord', None)
+        options.pop('task_id', None)
+
+        header.freeze(group_id=group_id, chord=body, root_id=root_id)
+        header_result = header(*partial_args, task_id=group_id, **options)
+
+        if len(header_result) > 0:
+            app.backend.apply_chord(
+                header_result,
+                body,
+                interval=interval,
+                countdown=countdown,
+                max_retries=max_retries,
+            )
+        # The execution of a chord body is normally triggered by its header's
+        # tasks completing. If the header is empty this will never happen, so
+        # we execute the body manually here.
+        else:
+            body.delay([])
 
-        parent = app.backend.apply_chord(
-            header, partial_args, group_id, body,
-            interval=interval, countdown=countdown,
-            options=options, max_retries=max_retries,
-            result=results)
-        bodyres.parent = parent
+        bodyres.parent = header_result
         return bodyres
 
     def clone(self, *args, **kwargs):

+ 14 - 0
t/integration/test_canvas.py

@@ -223,6 +223,20 @@ class test_chord:
         res2 = c2()
         assert res2.get(timeout=TIMEOUT) == [16]
 
+    def test_empty_header_chord(self, manager):
+        try:
+            manager.app.backend.ensure_chords_allowed()
+        except NotImplementedError as e:
+            raise pytest.skip(e.args[0])
+
+        c1 = chord([], body=add_to_all.s(9))
+        res1 = c1()
+        assert res1.get(timeout=TIMEOUT) == []
+
+        c2 = group([]) | add_to_all.s(9)
+        res2 = c2()
+        assert res2.get(timeout=TIMEOUT) == []
+
     @flaky
     def test_nested_chord(self, manager):
         try:

+ 10 - 6
t/unit/backends/test_base.py

@@ -76,10 +76,11 @@ class test_BaseBackend_interface:
 
     def test_apply_chord(self, unlock='celery.chord_unlock'):
         self.app.tasks[unlock] = Mock()
-        self.b.apply_chord(
-            group(app=self.app), (), 'dakj221', None,
-            result=[self.app.AsyncResult(x) for x in [1, 2, 3]],
+        header_result = self.app.GroupResult(
+            uuid(),
+            [self.app.AsyncResult(x) for x in range(3)],
         )
+        self.b.apply_chord(header_result, None)
         assert self.app.tasks[unlock].apply_async.call_count
 
 
@@ -527,12 +528,15 @@ class test_KeyValueStoreBackend:
     def test_chord_apply_fallback(self):
         self.b.implements_incr = False
         self.b.fallback_chord_unlock = Mock()
+        header_result = self.app.GroupResult(
+            'group_id',
+            [self.app.AsyncResult(x) for x in range(3)],
+        )
         self.b.apply_chord(
-            group(app=self.app), (), 'group_id', 'body',
-            result='result', foo=1,
+            header_result, 'body', foo=1,
         )
         self.b.fallback_chord_unlock.assert_called_with(
-            'group_id', 'body', result='result', foo=1,
+            header_result, 'body', foo=1,
         )
 
     def test_get_missing_meta(self):

+ 13 - 6
t/unit/backends/test_cache.py

@@ -8,7 +8,7 @@ import pytest
 from case import Mock, mock, patch, skip
 from kombu.utils.encoding import ensure_bytes, str_to_bytes
 
-from celery import group, signature, states, uuid
+from celery import signature, states, uuid
 from celery.backends.cache import CacheBackend, DummyClient, backends
 from celery.exceptions import ImproperlyConfigured
 from celery.five import bytes_if_py2, items, string, text_t
@@ -65,8 +65,12 @@ class test_CacheBackend:
 
     def test_apply_chord(self):
         tb = CacheBackend(backend='memory://', app=self.app)
-        gid, res = uuid(), [self.app.AsyncResult(uuid()) for _ in range(3)]
-        tb.apply_chord(group(app=self.app), (), gid, {}, result=res)
+        result = self.app.GroupResult(
+            uuid(),
+            [self.app.AsyncResult(uuid()) for _ in range(3)],
+        )
+        tb.apply_chord(result, None)
+        assert self.app.GroupResult.restore(result.id, backend=tb) == result
 
     @patch('celery.result.GroupResult.restore')
     def test_on_chord_part_return(self, restore):
@@ -81,9 +85,12 @@ class test_CacheBackend:
         self.app.tasks['foobarbaz'] = task
         task.request.chord = signature(task)
 
-        gid, res = uuid(), [self.app.AsyncResult(uuid()) for _ in range(3)]
-        task.request.group = gid
-        tb.apply_chord(group(app=self.app), (), gid, {}, result=res)
+        result = self.app.GroupResult(
+            uuid(),
+            [self.app.AsyncResult(uuid()) for _ in range(3)],
+        )
+        task.request.group = result.id
+        tb.apply_chord(result, None)
 
         deps.join_native.assert_not_called()
         tb.on_chord_part_return(task.request, 'SUCCESS', 10)

+ 7 - 7
t/unit/backends/test_redis.py

@@ -266,14 +266,14 @@ class test_RedisBackend:
         self.b.expire('foo', 300)
         self.b.client.expire.assert_called_with('foo', 300)
 
-    def test_apply_chord(self):
-        header = Mock(name='header')
-        header.results = [Mock(name='t1'), Mock(name='t2')]
-        self.b.apply_chord(
-            header, (1, 2), 'gid', None,
-            options={'max_retries': 10},
+    def test_apply_chord(self, unlock='celery.chord_unlock'):
+        self.app.tasks[unlock] = Mock()
+        header_result = self.app.GroupResult(
+            uuid(),
+            [self.app.AsyncResult(x) for x in range(3)],
         )
-        header.assert_called_with(1, 2, max_retries=10, task_id='gid')
+        self.b.apply_chord(header_result, None)
+        assert self.app.tasks[unlock].apply_async.call_count == 0
 
     def test_unpack_chord_result(self):
         self.b.exception_to_python = Mock(name='etp')

+ 1 - 1
t/unit/backends/test_rpc.py

@@ -28,7 +28,7 @@ class test_RPCBackend:
 
     def test_apply_chord(self):
         with pytest.raises(NotImplementedError):
-            self.b.apply_chord([], (), 'gid', Mock(name='body'))
+            self.b.apply_chord(self.app.GroupResult(), None)
 
     @pytest.mark.celery(result_backend='rpc')
     def test_chord_raises_error(self):