Browse Source

Fix #3725: Task replaced with group does not complete (#3731)

* Add add_to_all test task

* Add new failing test test_chord.test_group_chain

- Exercises a task replaced with a group

* Fix calculation of __length_hint__ on chord.tasks when it's a group

Before group.__iter__() was removed in 8c7ac5d, passing either a list or
a group through list() would yield the same thing: the list of tasks.
After the removal of group.__iter__(), list(g) yields the same thing as
g.keys(), so we need to handle the group case explicitly now.

Using len(g.keys()) for __length_hint__ will only happen to be the
correct chord size sometimes; the rest of the time it will result in
chords that never get unlocked due to a mismatch between the computed
chord size and the size of the complete set of results returned.

Makes test_chord.test_group_chain pass.

Fixes celery/celery#3725
Morgan Doocy 8 years ago
parent
commit
96177c6e5d
3 changed files with 25 additions and 4 deletions
  1. 4 2
      celery/canvas.py
  2. 8 1
      t/integration/tasks.py
  3. 13 1
      t/integration/test_canvas.py

+ 4 - 2
celery/canvas.py

@@ -1253,7 +1253,7 @@ class chord(Signature):
         )
         )
 
 
     def _traverse_tasks(self, tasks, value=None):
     def _traverse_tasks(self, tasks, value=None):
-        stack = deque(list(tasks))
+        stack = deque(tasks)
         while stack:
         while stack:
             task = stack.popleft()
             task = stack.popleft()
             if isinstance(task, group):
             if isinstance(task, group):
@@ -1262,7 +1262,9 @@ class chord(Signature):
                 yield task if value is None else value
                 yield task if value is None else value
 
 
     def __length_hint__(self):
     def __length_hint__(self):
-        return sum(self._traverse_tasks(self.tasks, 1))
+        tasks = (self.tasks.tasks if isinstance(self.tasks, group)
+                 else self.tasks)
+        return sum(self._traverse_tasks(tasks, 1))
 
 
     def run(self, header, body, partial_args, app=None, interval=None,
     def run(self, header, body, partial_args, app=None, interval=None,
             countdown=1, max_retries=None, eager=False,
             countdown=1, max_retries=None, eager=False,

+ 8 - 1
t/integration/tasks.py

@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 # -*- coding: utf-8 -*-
 from __future__ import absolute_import, unicode_literals
 from __future__ import absolute_import, unicode_literals
 from time import sleep
 from time import sleep
-from celery import shared_task
+from celery import shared_task, group
 from celery.utils.log import get_task_logger
 from celery.utils.log import get_task_logger
 
 
 logger = get_task_logger(__name__)
 logger = get_task_logger(__name__)
@@ -19,6 +19,13 @@ def add_replaced(self, x, y):
     raise self.replace(add.s(x, y))
     raise self.replace(add.s(x, y))
 
 
 
 
+@shared_task(bind=True)
+def add_to_all(self, nums, val):
+    """Add the given value to all supplied numbers."""
+    subtasks = [add.s(num, val) for num in nums]
+    raise self.replace(group(*subtasks))
+
+
 @shared_task
 @shared_task
 def print_unicode(log_message='hå它 valmuefrø', print_message='hiöäüß'):
 def print_unicode(log_message='hå它 valmuefrø', print_message='hiöäüß'):
     """Task that both logs and print strings containing funny characters."""
     """Task that both logs and print strings containing funny characters."""

+ 13 - 1
t/integration/test_canvas.py

@@ -4,7 +4,7 @@ from celery import chain, chord, group
 from celery.exceptions import TimeoutError
 from celery.exceptions import TimeoutError
 from celery.result import AsyncResult, GroupResult
 from celery.result import AsyncResult, GroupResult
 from .conftest import flaky
 from .conftest import flaky
-from .tasks import add, add_replaced, collect_ids, ids
+from .tasks import add, add_replaced, add_to_all, collect_ids, ids
 
 
 TIMEOUT = 120
 TIMEOUT = 120
 
 
@@ -88,6 +88,18 @@ def assert_ids(r, expected_value, expected_root_id, expected_parent_id):
 
 
 class test_chord:
 class test_chord:
 
 
+    @flaky
+    def test_group_chain(self, manager):
+        if not manager.app.conf.result_backend.startswith('redis'):
+            raise pytest.skip('Requires redis result backend.')
+        c = (
+            add.s(2, 2) |
+            group(add.s(i) for i in range(4)) |
+            add_to_all.s(8)
+        )
+        res = c()
+        assert res.get(timeout=TIMEOUT) == [12, 13, 14, 15]
+
     @flaky
     @flaky
     def test_parent_ids(self, manager):
     def test_parent_ids(self, manager):
         if not manager.app.conf.result_backend.startswith('redis'):
         if not manager.app.conf.result_backend.startswith('redis'):