浏览代码

Fix #3726 - Chaining of replaced tasks (#3730)

* Add add_replaced test task

* Make test_complex_chain fail by adding a replaced task

- Update expected output

* Copy replaced task's request chain in reverse

- Make t/integration/test_canvas.py::test_chain::test_complex_chain pass
Morgan Doocy 8 年之前
父节点
当前提交
9d2566e9c0
共有 3 个文件被更改,包括 10 次插入4 次删除
  1. 1 1
      celery/app/task.py
  2. 6 0
      t/integration/tasks.py
  3. 3 3
      t/integration/test_canvas.py

+ 1 - 1
celery/app/task.py

@@ -850,7 +850,7 @@ class Task(object):
             chord = None
             chord = None
 
 
         if self.request.chain:
         if self.request.chain:
-            for t in self.request.chain:
+            for t in reversed(self.request.chain):
                 sig |= signature(t, app=self.app)
                 sig |= signature(t, app=self.app)
 
 
         sig.freeze(self.request.id,
         sig.freeze(self.request.id,

+ 6 - 0
t/integration/tasks.py

@@ -13,6 +13,12 @@ def add(x, y):
     return x + y
     return x + y
 
 
 
 
+@shared_task(bind=True)
+def add_replaced(self, x, y):
+    """Add two numbers (via the add task)."""
+    raise self.replace(add.s(x, y))
+
+
 @shared_task
 @shared_task
 def print_unicode(log_message='hå它 valmuefrø', print_message='hiöäüß'):
 def print_unicode(log_message='hå它 valmuefrø', print_message='hiöäüß'):
     """Task that both logs and print strings containing funny characters."""
     """Task that both logs and print strings containing funny characters."""

+ 3 - 3
t/integration/test_canvas.py

@@ -4,7 +4,7 @@ from celery import chain, chord, group
 from celery.exceptions import TimeoutError
 from celery.exceptions import TimeoutError
 from celery.result import AsyncResult, GroupResult
 from celery.result import AsyncResult, GroupResult
 from .conftest import flaky
 from .conftest import flaky
-from .tasks import add, collect_ids, ids
+from .tasks import add, add_replaced, collect_ids, ids
 
 
 TIMEOUT = 120
 TIMEOUT = 120
 
 
@@ -20,12 +20,12 @@ class test_chain:
     def test_complex_chain(self, manager):
     def test_complex_chain(self, manager):
         c = (
         c = (
             add.s(2, 2) | (
             add.s(2, 2) | (
-                add.s(4) | add.s(8) | add.s(16)
+                add.s(4) | add_replaced.s(8) | add.s(16) | add.s(32)
             ) |
             ) |
             group(add.s(i) for i in range(4))
             group(add.s(i) for i in range(4))
         )
         )
         res = c()
         res = c()
-        assert res.get(timeout=TIMEOUT) == [32, 33, 34, 35]
+        assert res.get(timeout=TIMEOUT) == [64, 65, 66, 67]
 
 
     @flaky
     @flaky
     def test_parent_ids(self, manager, num=10):
     def test_parent_ids(self, manager, num=10):