Browse Source

Fix #3726 - Chaining of replaced tasks (#3730)

* Add add_replaced test task

* Make test_complex_chain fail by adding a replaced task

- Update expected output

* Copy replaced task's request chain in reverse

- Make t/integration/test_canvas.py::test_chain::test_complex_chain pass
Morgan Doocy 8 years ago
parent
commit
9d2566e9c0
3 changed files with 10 additions and 4 deletions
  1. 1 1
      celery/app/task.py
  2. 6 0
      t/integration/tasks.py
  3. 3 3
      t/integration/test_canvas.py

+ 1 - 1
celery/app/task.py

@@ -850,7 +850,7 @@ class Task(object):
             chord = None
             chord = None
 
 
         if self.request.chain:
         if self.request.chain:
-            for t in self.request.chain:
+            for t in reversed(self.request.chain):
                 sig |= signature(t, app=self.app)
                 sig |= signature(t, app=self.app)
 
 
         sig.freeze(self.request.id,
         sig.freeze(self.request.id,

+ 6 - 0
t/integration/tasks.py

@@ -13,6 +13,12 @@ def add(x, y):
     return x + y
     return x + y
 
 
 
 
+@shared_task(bind=True)
+def add_replaced(self, x, y):
+    """Add two numbers (via the add task)."""
+    raise self.replace(add.s(x, y))
+
+
 @shared_task
 @shared_task
 def print_unicode(log_message='hå它 valmuefrø', print_message='hiöäüß'):
 def print_unicode(log_message='hå它 valmuefrø', print_message='hiöäüß'):
     """Task that both logs and print strings containing funny characters."""
     """Task that both logs and print strings containing funny characters."""

+ 3 - 3
t/integration/test_canvas.py

@@ -4,7 +4,7 @@ from celery import chain, chord, group
 from celery.exceptions import TimeoutError
 from celery.exceptions import TimeoutError
 from celery.result import AsyncResult, GroupResult
 from celery.result import AsyncResult, GroupResult
 from .conftest import flaky
 from .conftest import flaky
-from .tasks import add, collect_ids, ids
+from .tasks import add, add_replaced, collect_ids, ids
 
 
 TIMEOUT = 120
 TIMEOUT = 120
 
 
@@ -20,12 +20,12 @@ class test_chain:
     def test_complex_chain(self, manager):
     def test_complex_chain(self, manager):
         c = (
         c = (
             add.s(2, 2) | (
             add.s(2, 2) | (
-                add.s(4) | add.s(8) | add.s(16)
+                add.s(4) | add_replaced.s(8) | add.s(16) | add.s(32)
             ) |
             ) |
             group(add.s(i) for i in range(4))
             group(add.s(i) for i in range(4))
         )
         )
         res = c()
         res = c()
-        assert res.get(timeout=TIMEOUT) == [32, 33, 34, 35]
+        assert res.get(timeout=TIMEOUT) == [64, 65, 66, 67]
 
 
     @flaky
     @flaky
     def test_parent_ids(self, manager, num=10):
     def test_parent_ids(self, manager, num=10):