Jelajahi Sumber

refactors some hairy canvas code

Ask Solem 12 tahun lalu
induk
melakukan
c0b455011c
3 mengubah file dengan 27 tambahan dan 21 penghapusan
  1. 3 21
      celery/app/builtins.py
  2. 1 0
      celery/bin/celery.py
  3. 23 0
      celery/canvas.py

+ 3 - 21
celery/app/builtins.py

@@ -186,7 +186,6 @@ def add_group_task(app):
 @shared_task
 @shared_task
 def add_chain_task(app):
 def add_chain_task(app):
     from celery.canvas import chord, group, maybe_subtask
     from celery.canvas import chord, group, maybe_subtask
-    from celery.result import GroupResult
     _app = app
     _app = app
 
 
     class Chain(app.Task):
     class Chain(app.Task):
@@ -203,33 +202,16 @@ def add_chain_task(app):
                 # First task get partial args from chain.
                 # First task get partial args from chain.
                 task = maybe_subtask(steps.popleft())
                 task = maybe_subtask(steps.popleft())
                 task = task.clone() if i else task.clone(args)
                 task = task.clone() if i else task.clone(args)
-                AsyncResult = task.type.AsyncResult
+                res = task._freeze()
                 i += 1
                 i += 1
-                tid = task.options.get('task_id')
-                if tid is None:
-                    tid = task.options['task_id'] = uuid()
-                res = AsyncResult(tid)
 
 
-                # groups must be turned into GroupResults
                 if isinstance(task, group):
                 if isinstance(task, group):
-                    #
-                    gid = task.options.get('group')
-                    if gid is None:
-                        gid = task.options['group'] = uuid()
-                    group_results = []
-                    for sub in task.tasks:
-                        tid = sub.options.get('task_id')
-                        if tid is None:
-                            tid = sub.options['task_id'] = uuid()
-                        group_results.append(AsyncResult(tid))
-                    res = GroupResult(gid, group_results)
-
                     # automatically upgrade group(..) | s to chord(group, s)
                     # automatically upgrade group(..) | s to chord(group, s)
                     try:
                     try:
                         next_step = steps.popleft()
                         next_step = steps.popleft()
-                        task = chord(task, body=next_step, task_id=tid)
+                        task = chord(task, body=next_step, task_id=res.task_id)
                     except IndexError:
                     except IndexError:
-                        res = GroupResult(gid, group_results)
+                        pass
                 if prev_task:
                 if prev_task:
                     # link previous task to this task.
                     # link previous task to this task.
                     prev_task.link(task)
                     prev_task.link(task)

+ 1 - 0
celery/bin/celery.py

@@ -11,6 +11,7 @@ from __future__ import with_statement
 
 
 import anyjson
 import anyjson
 import heapq
 import heapq
+import os
 import sys
 import sys
 import warnings
 import warnings
 
 

+ 23 - 0
celery/canvas.py

@@ -20,6 +20,7 @@ from kombu.utils import cached_property, fxrange, kwdict, reprcall, uuid
 from celery import current_app
 from celery import current_app
 from celery.local import Proxy
 from celery.local import Proxy
 from celery.utils.compat import chain_from_iterable
 from celery.utils.compat import chain_from_iterable
+from celery.result import GroupResult
 from celery.utils.functional import (
 from celery.utils.functional import (
     maybe_list, is_list, regen,
     maybe_list, is_list, regen,
     chunks as _chunks,
     chunks as _chunks,
@@ -128,6 +129,14 @@ class Signature(dict):
         return s
         return s
     partial = clone
     partial = clone
 
 
+    def _freeze(self, _id=None):
+        opts = self.options
+        try:
+            tid = opts['task_id']
+        except KeyError:
+            tid = opts['task_id'] = _id or uuid()
+        return self.type.AsyncResult(tid)
+
     def replace(self, args=None, kwargs=None, options=None):
     def replace(self, args=None, kwargs=None, options=None):
         s = self.clone()
         s = self.clone()
         if args is not None:
         if args is not None:
@@ -320,6 +329,20 @@ class group(Signature):
                     map(Signature.clone, self.tasks), partial_args)
                     map(Signature.clone, self.tasks), partial_args)
         return self.type(tasks, result, gid, args)
         return self.type(tasks, result, gid, args)
 
 
+    def _freeze(self, _id=None):
+        opts = self.options
+        try:
+            gid = opts['group']
+        except KeyError:
+            gid = opts['group'] = uuid()
+        new_tasks, results = [], []
+        for task in self.tasks:
+            task = maybe_subtask(task).clone()
+            results.append(task._freeze())
+            new_tasks.append(task)
+        self.tasks = self.kwargs['tasks'] = new_tasks
+        return GroupResult(gid, results)
+
     def skew(self, start=1.0, stop=None, step=1.0):
     def skew(self, start=1.0, stop=None, step=1.0):
         _next_skew = fxrange(start, stop, step, repeatlast=True).next
         _next_skew = fxrange(start, stop, step, repeatlast=True).next
         for task in self.tasks:
         for task in self.tasks: