Преглед изворни кода

Implements the new chain field in task protocol 2. Closes #1078

The chain is now stored in reverse order, so the first task in the list
is the last.  This means we can do a quick pop instead of a slow head remove.
Ask Solem пре 9 година
родитељ
комит
d3e1282664
5 измењених фајлова са 51 додато и 18 уклоњено
  1. 2 2
      celery/app/amqp.py
  2. 2 2
      celery/app/base.py
  3. 1 0
      celery/app/task.py
  4. 6 0
      celery/app/trace.py
  5. 40 14
      celery/canvas.py

+ 2 - 2
celery/app/amqp.py

@@ -297,7 +297,7 @@ class AMQP(object):
                    callbacks=None, errbacks=None, reply_to=None,
                    time_limit=None, soft_time_limit=None,
                    create_sent_event=False, root_id=None, parent_id=None,
-                   shadow=None, now=None, timezone=None):
+                   shadow=None, chain=None, now=None, timezone=None):
         args = args or ()
         kwargs = kwargs or {}
         utc = self.utc
@@ -354,7 +354,7 @@ class AMQP(object):
                 args, kwargs, {
                     'callbacks': callbacks,
                     'errbacks': errbacks,
-                    'chain': None,  # TODO
+                    'chain': chain,
                     'chord': chord,
                 },
             ),

+ 2 - 2
celery/app/base.py

@@ -612,7 +612,7 @@ class Celery(object):
                   add_to_parent=True, group_id=None, retries=0, chord=None,
                   reply_to=None, time_limit=None, soft_time_limit=None,
                   root_id=None, parent_id=None, route_name=None,
-                  shadow=None, **options):
+                  shadow=None, chain=None, **options):
         """Send task by name.
 
         :param name: Name of task to call (e.g. `"tasks.add"`).
@@ -639,7 +639,7 @@ class Celery(object):
             maybe_list(link), maybe_list(link_error),
             reply_to or self.oid, time_limit, soft_time_limit,
             self.conf.task_send_sent_event,
-            root_id, parent_id, shadow,
+            root_id, parent_id, shadow, chain,
         )
 
         if connection:

+ 1 - 0
celery/app/task.py

@@ -86,6 +86,7 @@ class Context(object):
     taskset = None   # compat alias to group
     group = None
     chord = None
+    chain = None
     utc = None
     called_directly = True
     callbacks = None

+ 6 - 0
celery/app/trace.py

@@ -394,6 +394,12 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
                                     group(sigs).apply_async((retval,))
                             else:
                                 signature(callbacks[0], app=app).delay(retval)
+
+                        # execute first task in chain
+                        chain = task.request.chain
+                        if chain:
+                            signature(chain.pop(), app=app).apply_async(
+                                    (retval,), chain=chain)
                         mark_as_done(
                             uuid, retval, task_request, publish_result,
                         )

+ 40 - 14
celery/canvas.py

@@ -27,8 +27,7 @@ from celery.local import try_import
 from celery.result import GroupResult
 from celery.utils import abstract
 from celery.utils.functional import (
-    maybe_list, is_list, regen,
-    chunks as _chunks,
+    maybe_list, is_list, noop, regen, chunks as _chunks,
 )
 from celery.utils.text import truncate
 
@@ -383,6 +382,7 @@ class chain(Signature):
         Signature.__init__(
             self, 'celery.chain', (), {'tasks': tasks}, **options
         )
+        self._use_link = options.pop('use_link', None)
         self.subtask_type = 'chain'
         self._frozen = None
 
@@ -402,6 +402,7 @@ class chain(Signature):
             task_id=None, link=None, link_error=None,
             publisher=None, producer=None, root_id=None, app=None, **options):
         app = app or self.app
+        use_link = self._use_link
         args = (tuple(args) + tuple(self.args)
                 if args and not self.immutable else self.args)
 
@@ -413,12 +414,22 @@ class chain(Signature):
                 task_id, group_id, chord,
             )
 
+
         if results:
             # make sure we can do a link() and link_error() on a chain object.
-            if link:
-                tasks[-1].set(link=link)
-            tasks[0].apply_async(**options)
-            return results[-1]
+            if self._use_link:
+                # old task protocol used link for chains, last is last.
+                if link:
+                    tasks[-1].set(link=link)
+                tasks[0].apply_async(**options)
+                return results[-1]
+            else:
+                # -- using chain message field means last task is first.
+                if link:
+                    tasks[0].set(link=link)
+                first_task = tasks.pop()
+                first_task.apply_async(chain=tasks, **options)
+                return results[0]
 
     def freeze(self, _id=None, group_id=None, chord=None, root_id=None):
         _, results = self._frozen = self.prepare_steps(
@@ -432,12 +443,25 @@ class chain(Signature):
                       last_task_id=None, group_id=None, chord_body=None,
                       clone=True, from_dict=Signature.from_dict):
         app = app or self.app
+        # use chain message field for protocol 2 and later.
+        # this avoids pickle blowing the stack on the recursion
+        # required by linking task together in a tree structure.
+        # (why is pickle using recursion? or better yet why cannot python
+        #  do tail call optimization making recursion actually useful?)
+        use_link = self._use_link
+        if use_link is None and app.conf.task_protocol > 1:
+            use_link = False
         steps = deque(tasks)
+
+        steps_pop = steps.popleft if use_link else steps.pop
+        steps_extend = steps.extendleft if use_link else steps.extend
+        extend_order = reverse if use_link else noop
+
         next_step = prev_task = prev_res = None
         tasks, results = [], []
         i = 0
         while steps:
-            task = steps.popleft()
+            task = steps_pop()
 
             if not isinstance(task, abstract.CallableSignature):
                 task = from_dict(task, app=app)
@@ -452,12 +476,12 @@ class chain(Signature):
 
             if isinstance(task, chain):
                 # splice the chain
-                steps.extendleft(reversed(task.tasks))
+                steps_extend(extend_order(task.tasks))
                 continue
             elif isinstance(task, group) and steps:
                 # automatically upgrade group(...) | s to chord(group, s)
                 try:
-                    next_step = steps.popleft()
+                    next_step = steps_pop()
                     # for chords we freeze by pretending it's a normal
                     # signature instead of a group.
                     res = Signature.freeze(next_step, root_id=root_id)
@@ -484,11 +508,13 @@ class chain(Signature):
             i += 1
 
             if prev_task:
-                # link previous task to this task.
-                prev_task.link(task)
-                # set AsyncResult.parent
-                if not res.parent:
-                    res.parent = prev_res
+                if use_link:
+                    # link previous task to this task.
+                    prev_task.link(task)
+                    if not res.parent:
+                        res.parent = prev_res
+                else:
+                    prev_res.parent = res
 
             if link_error:
                 task.set(link_error=link_error)