Browse Source

Pass parameters to the first task in chain().apply_async(params) (#4690)

Fixes #4643
tothegump 7 years ago
parent
commit
73c50e4782
2 changed files with 10 additions and 7 deletions
  1. 7 4
      celery/canvas.py
  2. 3 3
      t/unit/tasks/test_canvas.py

+ 7 - 4
celery/canvas.py

@@ -569,7 +569,7 @@ class _chain(Signature):
                 if args and not self.immutable else self.args)
 
         tasks, results = self.prepare_steps(
-            args, self.tasks, root_id, parent_id, link_error, app,
+            args, kwargs, self.tasks, root_id, parent_id, link_error, app,
             task_id, group_id, chord,
         )
 
@@ -589,12 +589,12 @@ class _chain(Signature):
         # pylint: disable=redefined-outer-name
         #   XXX chord is also a class in outer scope.
         _, results = self._frozen = self.prepare_steps(
-            self.args, self.tasks, root_id, parent_id, None,
+            self.args, self.kwargs, self.tasks, root_id, parent_id, None,
             self.app, _id, group_id, chord, clone=False,
         )
         return results[0]
 
-    def prepare_steps(self, args, tasks,
+    def prepare_steps(self, args, kwargs, tasks,
                       root_id=None, parent_id=None, link_error=None, app=None,
                       last_task_id=None, group_id=None, chord_body=None,
                       clone=True, from_dict=Signature.from_dict):
@@ -632,7 +632,10 @@ class _chain(Signature):
 
             # first task gets partial args from chain
             if clone:
-                task = task.clone(args) if is_first_task else task.clone()
+                if is_first_task:
+                    task = task.clone(args, kwargs)
+                else:
+                    task = task.clone()
             elif is_first_task:
                 task.args = tuple(args) + tuple(task.args)
 

+ 3 - 3
t/unit/tasks/test_canvas.py

@@ -333,7 +333,7 @@ class test_chain(CanvasCase):
             self.add.s(30)
         )
         c._use_link = True
-        tasks, results = c.prepare_steps((), c.tasks)
+        tasks, results = c.prepare_steps((), {}, c.tasks)
 
         assert tasks[-1].args[0] == 5
         assert isinstance(tasks[-2], chord)
@@ -347,7 +347,7 @@ class test_chain(CanvasCase):
 
         c2 = self.add.s(2, 2) | group(self.add.s(i, i) for i in range(10))
         c2._use_link = True
-        tasks2, _ = c2.prepare_steps((), c2.tasks)
+        tasks2, _ = c2.prepare_steps((), {}, c2.tasks)
         assert isinstance(tasks2[0], group)
 
     def test_group_to_chord__protocol_2__or(self):
@@ -372,7 +372,7 @@ class test_chain(CanvasCase):
 
         c2 = self.add.s(2, 2) | group(self.add.s(i, i) for i in range(10))
         c2._use_link = False
-        tasks2, _ = c2.prepare_steps((), c2.tasks)
+        tasks2, _ = c2.prepare_steps((), {}, c2.tasks)
         assert isinstance(tasks2[0], group)
 
     def test_apply_options(self):