Browse Source

Fix #3725: Task replaced with group does not complete (#3731)

* Add add_to_all test task

* Add new failing test test_chord.test_group_chain

- Exercises a task replaced with a group

* Fix calculation of __length_hint__ on chord.tasks when it's a group

Before group.__iter__() was removed in 8c7ac5d, passing either a list or
a group through list() would yield the same thing: the list of tasks.
After the removal of group.__iter__(), list(g) yields the same thing as
g.keys(), so we need to handle the group case explicitly now.

Using len(g.keys()) for __length_hint__ will only happen to be the
correct chord size sometimes; the rest of the time it will result in
chords that never get unlocked due to a mismatch between the computed
chord size and the size of the complete set of results returned.

Makes test_chord.test_group_chain pass.

Fixes celery/celery#3725
Morgan Doocy 8 years ago
parent
commit
96177c6e5d
3 changed files with 25 additions and 4 deletions
  1. 4 2
      celery/canvas.py
  2. 8 1
      t/integration/tasks.py
  3. 13 1
      t/integration/test_canvas.py

+ 4 - 2
celery/canvas.py

@@ -1253,7 +1253,7 @@ class chord(Signature):
         )
 
     def _traverse_tasks(self, tasks, value=None):
-        stack = deque(list(tasks))
+        stack = deque(tasks)
         while stack:
             task = stack.popleft()
             if isinstance(task, group):
@@ -1262,7 +1262,9 @@ class chord(Signature):
                 yield task if value is None else value
 
     def __length_hint__(self):
-        return sum(self._traverse_tasks(self.tasks, 1))
+        tasks = (self.tasks.tasks if isinstance(self.tasks, group)
+                 else self.tasks)
+        return sum(self._traverse_tasks(tasks, 1))
 
     def run(self, header, body, partial_args, app=None, interval=None,
             countdown=1, max_retries=None, eager=False,

+ 8 - 1
t/integration/tasks.py

@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 from __future__ import absolute_import, unicode_literals
 from time import sleep
-from celery import shared_task
+from celery import shared_task, group
 from celery.utils.log import get_task_logger
 
 logger = get_task_logger(__name__)
@@ -19,6 +19,13 @@ def add_replaced(self, x, y):
     raise self.replace(add.s(x, y))
 
 
+@shared_task(bind=True)
+def add_to_all(self, nums, val):
+    """Add the given value to all supplied numbers."""
+    subtasks = [add.s(num, val) for num in nums]
+    raise self.replace(group(*subtasks))
+
+
 @shared_task
 def print_unicode(log_message='hå它 valmuefrø', print_message='hiöäüß'):
     """Task that both logs and print strings containing funny characters."""

+ 13 - 1
t/integration/test_canvas.py

@@ -4,7 +4,7 @@ from celery import chain, chord, group
 from celery.exceptions import TimeoutError
 from celery.result import AsyncResult, GroupResult
 from .conftest import flaky
-from .tasks import add, add_replaced, collect_ids, ids
+from .tasks import add, add_replaced, add_to_all, collect_ids, ids
 
 TIMEOUT = 120
 
@@ -88,6 +88,18 @@ def assert_ids(r, expected_value, expected_root_id, expected_parent_id):
 
 class test_chord:
 
+    @flaky
+    def test_group_chain(self, manager):
+        if not manager.app.conf.result_backend.startswith('redis'):
+            raise pytest.skip('Requires redis result backend.')
+        c = (
+            add.s(2, 2) |
+            group(add.s(i) for i in range(4)) |
+            add_to_all.s(8)
+        )
+        res = c()
+        assert res.get(timeout=TIMEOUT) == [12, 13, 14, 15]
+
     @flaky
     def test_parent_ids(self, manager):
         if not manager.app.conf.result_backend.startswith('redis'):