ソースを参照

Chains can now be combined and accept partial args

Ask Solem 12 年 前
コミット
ec51735fd0
2 ファイル変更22 行追加15 行削除
  1. 21 13
      celery/app/builtins.py
  2. 1 2
      celery/canvas.py

+ 21 - 13
celery/app/builtins.py

@@ -182,13 +182,16 @@ def add_chain_task(app):
         name = 'celery.chain'
         accept_magic_kwargs = False
 
-        def prepare_steps(self, tasks, opts):
+        def prepare_steps(self, args, tasks):
             steps = deque(tasks)
             next_step = prev_task = prev_res = None
             tasks, results = [], []
             while steps:
-                task = maybe_subtask(steps.popleft())
-                tid = task.options.get('task_id') or uuid()
+                # First task get partial args from chain.
+                task = maybe_subtask(steps.popleft()).clone()
+                tid = task.options.get('task_id')
+                if tid is None:
+                    tid = task.options['task_id'] = uuid()
                 res = task.type.AsyncResult(tid)
 
                 # automatically upgrade group(..) | s to chord(group, s)
@@ -198,9 +201,7 @@ def add_chain_task(app):
                     except IndexError:
                         next_step = None
                 if next_step is not None:
-                    task = chord(task, body=next_step, task_id=tid, **opts)
-                else:
-                    task = task.clone(task_id=tid, **opts)
+                    task = chord(task, body=next_step)
                 if prev_task:
                     # link previous task to this task.
                     prev_task.link(task)
@@ -210,21 +211,28 @@ def add_chain_task(app):
                 results.append(res)
                 tasks.append(task)
                 prev_task, prev_res = task, res
+
+            # First task receives partial args for chain()
+            if args and not tasks[0].immutable:
+                tasks[0].args = tuple(args) + tuple(tasks[0].args or ())
             return tasks, results
 
-        def apply_async(self, args=(), kwargs={}, **options):
+        def apply_async(self, args=(), kwargs={}, group_id=None, chord=None,
+                task_id=None, **options):
             if self.app.conf.CELERY_ALWAYS_EAGER:
                 return self.apply(args, kwargs, **options)
             options.pop('publisher', None)
-            group_id = options.pop('group_id', None)
-            chord_id = options.pop('chord', None)
-            tasks, results = self.prepare_steps(kwargs['tasks'], options)
+            tasks, results = self.prepare_steps(args, kwargs['tasks'])
+            result = results[-1]
             if group_id:
                 tasks[-1].set(group_id=group_id)
-            if chord_id:
-                tasks[-1].set(chord=chord_id)
+            if chord:
+                tasks[-1].set(chord=chord)
+            if task_id:
+                tasks[-1].set(task_id=task_id)
+                result = tasks[-1].type.AsyncResult(task_id)
             tasks[0].apply_async()
-            return results[-1]
+            return result
 
         def apply(self, args=(), kwargs={}, **options):
             tasks = [maybe_subtask(task).clone() for task in kwargs['tasks']]

+ 1 - 2
celery/canvas.py

@@ -201,8 +201,7 @@ class chain(Signature):
 
     def __init__(self, *tasks, **options):
         tasks = tasks[0] if len(tasks) == 1 and is_list(tasks[0]) else tasks
-        Signature.__init__(self, 'celery.chain', (), {'tasks': tasks},
-                           options, immutable=True)
+        Signature.__init__(self, 'celery.chain', (), {'tasks': tasks}, options)
         self.tasks = tasks
         self.subtask_type = 'chain'