Browse Source

task_replace chord inside chord fix (fixes #4368) (#4369)

* task_replace chord inside chord fix

* Complete fix for replace inside chords with tests

* Add integration tests for add_to_chord

* Fix JSON serialisation in tests

* Raise exception when replacing signature has a chord
Denis Shirokov 7 years ago
parent
commit
442f42b708
4 changed files with 91 additions and 19 deletions
  1. 17 14
      celery/app/task.py
  2. 25 0
      t/integration/tasks.py
  3. 42 3
      t/integration/test_canvas.py
  4. 7 2
      t/unit/tasks/test_tasks.py

+ 17 - 14
celery/app/task.py

@@ -11,7 +11,8 @@ from kombu.utils.uuid import uuid
 from celery import current_app, group, states
 from celery._state import _task_stack
 from celery.canvas import signature
-from celery.exceptions import Ignore, MaxRetriesExceededError, Reject, Retry
+from celery.exceptions import (Ignore, ImproperlyConfigured,
+                               MaxRetriesExceededError, Reject, Retry)
 from celery.five import items, python_2_unicode_compatible
 from celery.local import class_property
 from celery.result import EagerResult, denied_join_result
@@ -839,27 +840,26 @@ class Task(object):
         """
         chord = self.request.chord
         if 'chord' in sig.options:
-            if chord:
-                chord = sig.options['chord'] | chord
-            else:
-                chord = sig.options['chord']
+            raise ImproperlyConfigured(
+                "A signature replacing a task must not be part of a chord"
+            )
 
         if isinstance(sig, group):
             sig |= self.app.tasks['celery.accumulate'].s(index=0).set(
-                chord=chord,
                 link=self.request.callbacks,
                 link_error=self.request.errbacks,
             )
-            chord = None
 
         if self.request.chain:
             for t in reversed(self.request.chain):
                 sig |= signature(t, app=self.app)
 
-        sig.freeze(self.request.id,
-                   group_id=self.request.group,
-                   chord=chord,
-                   root_id=self.request.root_id)
+        sig.set(
+            chord=chord,
+            group_id=self.request.group,
+            root_id=self.request.root_id,
+        )
+        sig.freeze(self.request.id)
 
         sig.delay()
         raise Ignore('Replaced by new task')
@@ -878,9 +878,12 @@ class Task(object):
         """
         if not self.request.chord:
             raise ValueError('Current task is not member of any chord')
-        result = sig.freeze(group_id=self.request.group,
-                            chord=self.request.chord,
-                            root_id=self.request.root_id)
+        sig.set(
+            group_id=self.request.group,
+            chord=self.request.chord,
+            root_id=self.request.root_id,
+        )
+        result = sig.freeze()
         self.backend.add_to_chord(self.request.group, result)
         return sig.delay() if not lazy else sig
 

+ 25 - 0
t/integration/tasks.py

@@ -10,6 +10,11 @@ from celery.utils.log import get_task_logger
 logger = get_task_logger(__name__)
 
 
+@shared_task
+def identity(x):
+    return x
+
+
 @shared_task
 def add(x, y):
     """Add two numbers."""
@@ -35,6 +40,12 @@ def delayed_sum_with_soft_guard(numbers, pause_time=1):
         return 0
 
 
+@shared_task
+def tsum(nums):
+    """Sum an iterable of numbers"""
+    return sum(nums)
+
+
 @shared_task(bind=True)
 def add_replaced(self, x, y):
     """Add two numbers (via the add task)."""
@@ -48,6 +59,20 @@ def add_to_all(self, nums, val):
     raise self.replace(group(*subtasks))
 
 
+@shared_task(bind=True)
+def add_to_all_to_chord(self, nums, val):
+    for num in nums:
+        self.add_to_chord(add.s(num, val))
+    return 0
+
+
+@shared_task(bind=True)
+def add_chord_to_chord(self, nums, val):
+    subtasks = [add.s(num, val) for num in nums]
+    self.add_to_chord(group(subtasks) | tsum.s())
+    return 0
+
+
 @shared_task
 def print_unicode(log_message='hå它 valmuefrø', print_message='hiöäüß'):
     """Task that both logs and print strings containing funny characters."""

+ 42 - 3
t/integration/test_canvas.py

@@ -10,9 +10,10 @@ from celery.exceptions import TimeoutError
 from celery.result import AsyncResult, GroupResult
 
 from .conftest import flaky
-from .tasks import (add, add_replaced, add_to_all, collect_ids, delayed_sum,
-                    delayed_sum_with_soft_guard, ids, redis_echo,
-                    second_order_replace1)
+from .tasks import (add, add_chord_to_chord, add_replaced, add_to_all,
+                    add_to_all_to_chord, collect_ids, delayed_sum,
+                    delayed_sum_with_soft_guard, identity, ids, redis_echo,
+                    second_order_replace1, tsum)
 
 TIMEOUT = 120
 
@@ -211,6 +212,44 @@ class test_chord:
             len(redis_client.execute_command('PUBSUB CHANNELS'))
         assert channels_after < channels_before
 
+    @flaky
+    def test_replaced_nested_chord(self, manager):
+        try:
+            manager.app.backend.ensure_chords_allowed()
+        except NotImplementedError as e:
+            raise pytest.skip(e.args[0])
+
+        c1 = chord([
+            chord(
+                [add.s(1, 2), add_replaced.s(3, 4)],
+                add_to_all.s(5),
+            ) | tsum.s(),
+            chord(
+                [add_replaced.s(6, 7), add.s(0, 0)],
+                add_to_all.s(8),
+            ) | tsum.s(),
+        ], add_to_all.s(9))
+        res1 = c1()
+        assert res1.get(timeout=TIMEOUT) == [29, 38]
+
+    @flaky
+    def test_add_to_chord(self, manager):
+        if not manager.app.conf.result_backend.startswith('redis'):
+            raise pytest.skip('Requires redis result backend.')
+
+        c = group([add_to_all_to_chord.s([1, 2, 3], 4)]) | identity.s()
+        res = c()
+        assert res.get() == [0, 5, 6, 7]
+
+    @flaky
+    def test_add_chord_to_chord(self, manager):
+        if not manager.app.conf.result_backend.startswith('redis'):
+            raise pytest.skip('Requires redis result backend.')
+
+        c = group([add_chord_to_chord.s([1, 2, 3], 4)]) | identity.s()
+        res = c()
+        assert res.get() == [0, 5 + 6 + 7]
+
     @flaky
     def test_group_chain(self, manager):
         if not manager.app.conf.result_backend.startswith('redis'):

+ 7 - 2
t/unit/tasks/test_tasks.py

@@ -10,7 +10,7 @@ from kombu import Queue
 
 from celery import Task, group, uuid
 from celery.app.task import _reprtask
-from celery.exceptions import Ignore, Retry
+from celery.exceptions import Ignore, ImproperlyConfigured, Retry
 from celery.five import items, range, string_t
 from celery.result import EagerResult
 from celery.utils.time import parse_iso8601
@@ -589,6 +589,12 @@ class test_tasks(TasksCase):
         with pytest.raises(Ignore):
             self.mytask.replace(sig1)
 
+    def test_replace_with_chord(self):
+        sig1 = Mock(name='sig1')
+        sig1.options = {'chord': None}
+        with pytest.raises(ImproperlyConfigured):
+            self.mytask.replace(sig1)
+
     @pytest.mark.usefixtures('depends_on_current_app')
     def test_replace_callback(self):
         c = group([self.mytask.s()], app=self.app)
@@ -617,7 +623,6 @@ class test_tasks(TasksCase):
             self.mytask.replace(c)
         except Ignore:
             mocked_signature.return_value.set.assert_called_with(
-                chord=None,
                 link='callbacks',
                 link_error='errbacks',
             )