Browse Source

Refactors Chain.apply_async

Ask Solem 12 years ago
parent
commit
aa3ccf2aba
2 changed files with 32 additions and 24 deletions
  1. 32 20
      celery/app/builtins.py
  2. 0 4
      celery/result.py

+ 32 - 20
celery/app/builtins.py

@@ -10,6 +10,7 @@
 from __future__ import absolute_import
 from __future__ import with_statement
 
+from collections import deque
 from itertools import starmap
 
 from celery._state import get_current_worker_task
@@ -181,37 +182,48 @@ def add_chain_task(app):
         name = 'celery.chain'
         accept_magic_kwargs = False
 
+        def prepare_steps(self, tasks, opts):
+            steps = deque(tasks)
+            next_step = prev_task = prev_res = None
+            tasks, results = [], []
+            while steps:
+                task = maybe_subtask(steps.popleft())
+                tid = task.options.get('task_id') or uuid()
+                res = task.type.AsyncResult(tid)
+
+                # automatically upgrade group(..) | s to chord(group, s)
+                if isinstance(task, group):
+                    try:
+                        next_step = steps.popleft()
+                    except IndexError:
+                        next_step = None
+                if next_step is not None:
+                    task = chord(task, body=next_step, task_id=tid, **opts)
+                else:
+                    task = task.clone(task_id=tid, **opts)
+                if prev_task:
+                    # link previous task to this task.
+                    prev_task.link(task)
+                    # set the results parent attribute.
+                    res.parent = prev_res
+
+                results.append(res)
+                tasks.append(task)
+                prev_task, prev_res = task, res
+            return tasks, results
+
         def apply_async(self, args=(), kwargs={}, **options):
             if self.app.conf.CELERY_ALWAYS_EAGER:
                 return self.apply(args, kwargs, **options)
             options.pop('publisher', None)
             group_id = options.pop('group_id', None)
             chord_id = options.pop('chord', None)
-
-            def prepare_steps(tasks, opts):
-                i, size = 0, len(tasks);
-                while i < size:
-                    sig = maybe_subtask(tasks[i])
-                    task_id = sig.options.get('task_id')
-                    if isinstance(sig, group) and i + 1 < size:
-                        i += 1
-                        sig = chord(sig, body=tasks[i],
-                                    task_id=task_id or uuid(), **opts)
-                    else:
-                        sig = sig.clone(task_id=task_id or uuid(), **opts)
-                    yield sig
-                    i += 1
-
-            tasks = list(prepare_steps(kwargs['tasks'], options))
-            reduce(lambda a, b: a.link(b), tasks)
+            tasks, results = self.prepare_steps(kwargs['tasks'], options)
             if group_id:
                 tasks[-1].set(group_id=group_id)
             if chord_id:
                 tasks[-1].set(chord=chord_id)
             tasks[0].apply_async()
-            results = [task.type.AsyncResult(task.options['task_id'])
-                            for task in tasks]
-            reduce(lambda a, b: a.set_parent(b), reversed(results))
             return results[-1]
 
         def apply(self, args=(), kwargs={}, **options):

+ 0 - 4
celery/result.py

@@ -217,10 +217,6 @@ class AsyncResult(ResultBase):
     def __reduce_args__(self):
         return self.id, self.backend, self.task_name, self.parent
 
-    def set_parent(self, parent):
-        self.parent = parent
-        return parent
-
     @cached_property
     def graph(self):
         return self.build_graph()