Browse Source

Canvas: Unroll groups within groups to a single group. Closes #1509

Ask Solem 11 years ago
parent
commit
0673da5c09
1 changed file with 59 additions and 20 deletions
  1. 59 20
      celery/canvas.py

+ 59 - 20
celery/canvas.py

@@ -341,6 +341,7 @@ class Signature(dict):
 
 @Signature.register_type
 class chain(Signature):
+    tasks = _getitem_property('kwargs.tasks')
 
     def __init__(self, *tasks, **options):
         tasks = (regen(tasks[0]) if len(tasks) == 1 and is_list(tasks[0])
@@ -348,7 +349,6 @@ class chain(Signature):
         Signature.__init__(
             self, 'celery.chain', (), {'tasks': tasks}, **options
         )
-        self.tasks = tasks
         self.subtask_type = 'chain'
 
     def __call__(self, *args, **kwargs):
@@ -557,6 +557,7 @@ def _maybe_group(tasks):
 
 @Signature.register_type
 class group(Signature):
+    tasks = _getitem_property('kwargs.tasks')
 
     def __init__(self, *tasks, **options):
         if len(tasks) == 1:
@@ -564,7 +565,7 @@ class group(Signature):
         Signature.__init__(
             self, 'celery.group', (), {'tasks': tasks}, **options
         )
-        self.tasks, self.subtask_type = tasks, 'group'
+        self.subtask_type = 'group'
 
     @classmethod
     def from_dict(self, d, app=None):
@@ -586,9 +587,17 @@ class group(Signature):
                 else:
                     # serialized sigs must be converted to Signature.
                     task = from_dict(task)
-            if partial_args and not task.immutable:
-                task.args = tuple(partial_args) + tuple(task.args)
-            yield task, task.freeze(group_id=group_id, root_id=root_id)
+                if isinstance(task, group):
+                    # would be ``yield from``; re-yield each pair by hand.
+                    unroll = task._prepared(
+                        task.tasks, partial_args, group_id, root_id,
+                    )
+                    for taskN, resN in unroll:
+                        yield taskN, resN
+                else:
+                    if partial_args and not task.immutable:
+                        task.args = tuple(partial_args) + tuple(task.args)
+                    yield task, task.freeze(group_id=group_id, root_id=root_id)
 
     def _apply_tasks(self, tasks, producer=None, app=None, **options):
         app = app or self.app
@@ -650,6 +659,17 @@ class group(Signature):
     def __call__(self, *partial_args, **options):
         return self.apply_async(partial_args, **options)
 
+    def _freeze_unroll(self, new_tasks, group_id, chord, root_id):
+        stack = deque(self.tasks)
+        while stack:
+            task = maybe_signature(stack.popleft(), app=self._app).clone()
+            if isinstance(task, group):
+                stack.extendleft(task.tasks)
+            else:
+                new_tasks.append(task)
+                yield task.freeze(group_id=group_id,
+                                  chord=chord, root_id=root_id)
+
     def freeze(self, _id=None, group_id=None, chord=None, root_id=None):
         opts = self.options
         try:
@@ -659,16 +679,18 @@ class group(Signature):
         if group_id:
             opts['group_id'] = group_id
         if chord:
-            opts['chord'] = group_id
+            opts['chord'] = chord
         root_id = opts.setdefault('root_id', root_id)
-        new_tasks, results = [], []
-        for task in self.tasks:
-            task = maybe_signature(task, app=self._app).clone()
-            results.append(task.freeze(
-                group_id=group_id, chord=chord, root_id=root_id,
-            ))
-            new_tasks.append(task)
-        self.tasks = self.kwargs['tasks'] = new_tasks
+        new_tasks = []
+        # Need to unroll subgroups early so that chord gets the
+        # right result instance for chord_unlock etc.
+        results = list(self._freeze_unroll(
+            new_tasks, group_id, chord, root_id,
+        ))
+        if isinstance(self.tasks, MutableSequence):
+            self.tasks[:] = new_tasks
+        else:
+            self.tasks = new_tasks
         return self.app.GroupResult(gid, results)
     _freeze = freeze
 
@@ -689,7 +711,7 @@ class group(Signature):
         app = self._app
         if app is None:
             try:
-                app = self.tasks[0]._app
+                app = self.tasks[0].app
             except (KeyError, IndexError):
                 pass
         return app if app is not None else current_app
@@ -723,11 +745,14 @@ class chord(Signature):
 
     @cached_property
     def app(self):
+        return self._get_app(self.body)
+
+    def _get_app(self, body=None):
         app = self._app
         if app is None:
             app = self.tasks[0]._app
-            if app is None:
-                app = self.body._app
+            if app is None and body is not None:
+                app = body._app
         return app if app is not None else current_app
 
     def apply_async(self, args=(), kwargs={}, task_id=None,
@@ -736,7 +761,7 @@ class chord(Signature):
         body = kwargs.get('body') or self.kwargs['body']
         kwargs = dict(self.kwargs, **kwargs)
         body = body.clone(**options)
-        app = self.app
+        app = self._get_app(body)
         tasks = (self.tasks.clone() if isinstance(self.tasks, group)
                  else group(self.tasks))
         if app.conf.CELERY_ALWAYS_EAGER:
@@ -752,15 +777,29 @@ class chord(Signature):
             args=(tasks.apply().get(propagate=propagate), ),
         )
 
+    def _traverse_tasks(self, tasks, value=None):
+        stack = deque(tasks)
+        while stack:
+            task = stack.popleft()
+            if isinstance(task, group):
+                stack.extend(task.tasks)
+            else:
+                yield task if value is None else value
+
+    def __length_hint__(self):
+        return sum(self._traverse_tasks(self.tasks, 1))
+
     def run(self, header, body, partial_args, app=None, interval=None,
             countdown=1, max_retries=None, propagate=None, eager=False,
             task_id=None, **options):
-        app = app or self.app
+        app = app or self._get_app(body)
         propagate = (app.conf.CELERY_CHORD_PROPAGATES
                      if propagate is None else propagate)
         group_id = uuid()
         root_id = body.options.get('root_id')
-        body.setdefault('chord_size', len(header.tasks))
+        if 'chord_size' not in body:
+            body['chord_size'] = self.__length_hint__()
+
         results = header.freeze(
             group_id=group_id, chord=body, root_id=root_id).results
         bodyres = body.freeze(task_id, root_id=root_id)