Browse Source

Merge branch 'group_partial'

Ask Solem 12 years ago
parent
commit
0e95530d88
2 changed files with 22 additions and 19 deletions
  1. 19 15
      celery/app/builtins.py
  2. 3 4
      celery/canvas.py

+ 19 - 15
celery/app/builtins.py

@@ -125,17 +125,18 @@ def add_group_task(app):
         name = 'celery.group'
         accept_magic_kwargs = False
 
-        def run(self, tasks, result, group_id):
+        def run(self, tasks, result, group_id, partial_args):
             app = self.app
             result = from_serializable(result)
+            # any partial args are added to the first task in the group
+            taskit = (subtask(task) if i else subtask(task).clone(partial_args)
+                        for i, task in enumerate(tasks))
             if self.request.is_eager or app.conf.CELERY_ALWAYS_EAGER:
                 return app.GroupResult(result.id,
-                        [subtask(task).apply(group_id=group_id)
-                            for task in tasks])
+                        [task.apply(group_id=group_id) for task in taskit])
             with app.default_producer() as pub:
-                [subtask(task).apply_async(group_id=group_id, publisher=pub,
-                                           add_to_parent=False)
-                        for task in tasks]
+                [task.apply_async(group_id=group_id, publisher=pub,
+                                  add_to_parent=False) for task in taskit]
             parent = get_current_worker_task()
             if parent:
                 parent.request.children.append(result)
@@ -158,17 +159,17 @@ def add_group_task(app):
             tasks, results = zip(*[prepare_member(task) for task in tasks])
             return tasks, self.app.GroupResult(group_id, results), group_id
 
-        def apply_async(self, args=(), kwargs={}, **options):
+        def apply_async(self, partial_args=(), kwargs={}, **options):
             if self.app.conf.CELERY_ALWAYS_EAGER:
                 return self.apply(args, kwargs, **options)
             tasks, result, gid = self.prepare(options, **kwargs)
-            super(Group, self).apply_async((
-                list(tasks), result.serializable(), gid), **options)
+            super(Group, self).apply_async((list(tasks),
+                result.serializable(), gid, partial_args), **options)
             return result
 
         def apply(self, args=(), kwargs={}, **options):
             return super(Group, self).apply(
-                    self.prepare(options, **kwargs), **options)
+                    self.prepare(options, **kwargs) + (args, ), **options)
     return Group
 
 
@@ -232,6 +233,8 @@ def add_chain_task(app):
             if task_id:
                 tasks[-1].set(task_id=task_id)
                 result = tasks[-1].type.AsyncResult(task_id)
+            print("TASKS[-1]: %r" % (tasks[-1], ))
+            print("ID: %r" % (tasks[-1].options, ))
             tasks[0].apply_async()
             return result
 
@@ -260,8 +263,8 @@ def add_chord_task(app):
         accept_magic_kwargs = False
         ignore_result = False
 
-        def run(self, header, body, interval=1, max_retries=None,
-                propagate=False, eager=False, **kwargs):
+        def run(self, header, body, partial_args=(), interval=1,
+                max_retries=None, propagate=False, eager=False, **kwargs):
             if not isinstance(header, group):
                 header = group(map(maybe_subtask, header))
             r = []
@@ -276,13 +279,13 @@ def add_chord_task(app):
                 opts['group_id'] = group_id
                 r.append(app.AsyncResult(tid))
             if eager:
-                return header.apply(task_id=group_id)
+                return header.apply(args=partial_args, task_id=group_id)
             app.backend.on_chord_apply(group_id, body,
                                        interval=interval,
                                        max_retries=max_retries,
                                        propagate=propagate,
                                        result=r)
-            return header(task_id=group_id)
+            return header(*partial_args, task_id=group_id)
 
         def apply_async(self, args=(), kwargs={}, task_id=None, **options):
             if self.app.conf.CELERY_ALWAYS_EAGER:
@@ -296,7 +299,8 @@ def add_chord_task(app):
             if chord:
                 body.set(chord=chord)
             callback_id = body.options.setdefault('task_id', task_id or uuid())
-            parent = super(Chord, self).apply_async((header, body), **options)
+            parent = super(Chord, self).apply_async((header, body, args),
+                                                    **options)
             body_result = self.AsyncResult(callback_id)
             body_result.parent = parent
             return body_result

+ 3 - 4
celery/canvas.py

@@ -296,18 +296,17 @@ class group(Signature):
     def __init__(self, *tasks, **options):
         if len(tasks) == 1:
             tasks = _maybe_group(tasks[0])
-        Signature.__init__(self, 'celery.group', (),
-                {'tasks': tasks}, options, immutable=True)
+        Signature.__init__(self, 'celery.group', (), {'tasks': tasks}, options)
         self.tasks, self.subtask_type = tasks, 'group'
 
     @classmethod
     def from_dict(self, d):
         return group(d['kwargs']['tasks'], **kwdict(d['options']))
 
-    def __call__(self, **options):
+    def __call__(self, *partial_args, **options):
         tasks, result, gid = self.type.prepare(options,
                                 map(Signature.clone, self.tasks))
-        return self.type(tasks, result, gid)
+        return self.type(tasks, result, gid, partial_args)
 
     def skew(self, start=1.0, stop=None, step=1.0):
         _next_skew = fxrange(start, stop, step, repeatlast=True).next