Browse Source

Adds parent_id + root_id task message fields, and to events. Closes #1318

Ask Solem 9 years ago
parent
commit
6066a45700

+ 2 - 2
celery/app/amqp.py

@@ -360,8 +360,8 @@ class AMQP(object):
             ),
             ),
             sent_event={
             sent_event={
                 'uuid': task_id,
                 'uuid': task_id,
-                'root': root_id,
-                'parent': parent_id,
+                'root_id': root_id,
+                'parent_id': parent_id,
                 'name': name,
                 'name': name,
                 'args': argsrepr,
                 'args': argsrepr,
                 'kwargs': kwargsrepr,
                 'kwargs': kwargsrepr,

+ 13 - 1
celery/app/base.py

@@ -622,6 +622,7 @@ class Celery(object):
         Otherwise supports the same arguments as :meth:`@-Task.apply_async`.
         Otherwise supports the same arguments as :meth:`@-Task.apply_async`.
 
 
         """
         """
+        parent = have_parent = None
         amqp = self.amqp
         amqp = self.amqp
         task_id = task_id or uuid()
         task_id = task_id or uuid()
         producer = producer or publisher  # XXX compat
         producer = producer or publisher  # XXX compat
@@ -633,6 +634,16 @@ class Celery(object):
             ), stacklevel=2)
             ), stacklevel=2)
         options = router.route(options, route_name or name, args, kwargs)
         options = router.route(options, route_name or name, args, kwargs)
 
 
+        if root_id is None:
+            parent, have_parent = get_current_worker_task(), True
+            if parent:
+                root_id = parent.request.root_id or parent.request.id
+        if parent_id is None:
+            if not have_parent:
+                parent, have_parent = get_current_worker_task(), True
+            if parent:
+                parent_id = parent.request.id
+
         message = amqp.create_task_message(
         message = amqp.create_task_message(
             task_id, name, args, kwargs, countdown, eta, group_id,
             task_id, name, args, kwargs, countdown, eta, group_id,
             expires, retries, chord,
             expires, retries, chord,
@@ -649,7 +660,8 @@ class Celery(object):
             amqp.send_task_message(P, name, message, **options)
             amqp.send_task_message(P, name, message, **options)
         result = (result_cls or self.AsyncResult)(task_id)
         result = (result_cls or self.AsyncResult)(task_id)
         if add_to_parent:
         if add_to_parent:
-            parent = get_current_worker_task()
+            if not have_parent:
+                parent, have_parent = get_current_worker_task(), True
             if parent:
             if parent:
                 parent.add_trail(result)
                 parent.add_trail(result)
         return result
         return result

+ 19 - 8
celery/app/trace.py

@@ -306,10 +306,11 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
         I = Info(state, exc)
         I = Info(state, exc)
         R = I.handle_error_state(task, request, eager=eager)
         R = I.handle_error_state(task, request, eager=eager)
         if call_errbacks:
         if call_errbacks:
+            root_id = request.root_id or uuid
             group(
             group(
                 [signature(errback, app=app)
                 [signature(errback, app=app)
                  for errback in request.errbacks or []], app=app,
                  for errback in request.errbacks or []], app=app,
-            ).apply_async((uuid,))
+            ).apply_async((uuid,), parent_id=uuid, root_id=root_id)
         return I, R, I.state, I.retval
         return I, R, I.state, I.retval
 
 
     def trace_task(uuid, args, kwargs, request=None):
     def trace_task(uuid, args, kwargs, request=None):
@@ -336,6 +337,7 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
             push_task(task)
             push_task(task)
             task_request = Context(request or {}, args=args,
             task_request = Context(request or {}, args=args,
                                    called_directly=False, kwargs=kwargs)
                                    called_directly=False, kwargs=kwargs)
+            root_id = task_request.root_id or uuid
             push_request(task_request)
             push_request(task_request)
             try:
             try:
                 # -*- PRE -*-
                 # -*- PRE -*-
@@ -363,8 +365,7 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
                     I.handle_ignore(task, task_request)
                     I.handle_ignore(task, task_request)
                 except Retry as exc:
                 except Retry as exc:
                     I, R, state, retval = on_error(
                     I, R, state, retval = on_error(
-                        task_request, exc, uuid, RETRY, call_errbacks=False,
-                    )
+                        task_request, exc, uuid, RETRY, call_errbacks=False)
                 except Exception as exc:
                 except Exception as exc:
                     I, R, state, retval = on_error(task_request, exc, uuid)
                     I, R, state, retval = on_error(task_request, exc, uuid)
                 except BaseException as exc:
                 except BaseException as exc:
@@ -389,17 +390,27 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
                                     else:
                                     else:
                                         sigs.append(sig)
                                         sigs.append(sig)
                                 for group_ in groups:
                                 for group_ in groups:
-                                    group.apply_async((retval,))
+                                    group.apply_async(
+                                        (retval,),
+                                        parent_id=uuid, root_id=root_id,
+                                    )
                                 if sigs:
                                 if sigs:
-                                    group(sigs).apply_async((retval,))
+                                    group(sigs).apply_async(
+                                        (retval,),
+                                        parent_id=uuid, root_id=root_id,
+                                    )
                             else:
                             else:
-                                signature(callbacks[0], app=app).delay(retval)
+                                signature(callbacks[0], app=app).apply_async(
+                                    (retval,), parent_id=uuid, root_id=root_id,
+                                )
 
 
                         # execute first task in chain
                         # execute first task in chain
-                        chain = task.request.chain
+                        chain = task_request.chain
                         if chain:
                         if chain:
                             signature(chain.pop(), app=app).apply_async(
                             signature(chain.pop(), app=app).apply_async(
-                                    (retval,), chain=chain)
+                                (retval,), chain=chain,
+                                parent_id=uuid, root_id=root_id,
+                            )
                         mark_as_done(
                         mark_as_done(
                             uuid, retval, task_request, publish_result,
                             uuid, retval, task_request, publish_result,
                         )
                         )

+ 90 - 40
celery/canvas.py

@@ -216,13 +216,17 @@ class Signature(dict):
         return s
         return s
     partial = clone
     partial = clone
 
 
-    def freeze(self, _id=None, group_id=None, chord=None, root_id=None):
+    def freeze(self, _id=None, group_id=None, chord=None,
+               root_id=None, parent_id=None):
         opts = self.options
         opts = self.options
         try:
         try:
             tid = opts['task_id']
             tid = opts['task_id']
         except KeyError:
         except KeyError:
             tid = opts['task_id'] = _id or uuid()
             tid = opts['task_id'] = _id or uuid()
-        root_id = opts.setdefault('root_id', root_id)
+        if root_id:
+            opts['root_id'] = root_id
+        if parent_id:
+            opts['parent_id'] = parent_id
         if 'reply_to' not in opts:
         if 'reply_to' not in opts:
             opts['reply_to'] = self.app.oid
             opts['reply_to'] = self.app.oid
         if group_id:
         if group_id:
@@ -251,6 +255,9 @@ class Signature(dict):
     def set_immutable(self, immutable):
     def set_immutable(self, immutable):
         self.immutable = immutable
         self.immutable = immutable
 
 
+    def set_parent_id(self, parent_id):
+        self.parent_id = parent_id
+
     def apply_async(self, args=(), kwargs={}, route_name=None, **options):
     def apply_async(self, args=(), kwargs={}, route_name=None, **options):
         try:
         try:
             _apply = self._apply_async
             _apply = self._apply_async
@@ -362,6 +369,8 @@ class Signature(dict):
         except KeyError:
         except KeyError:
             return _partial(self.app.send_task, self['task'])
             return _partial(self.app.send_task, self['task'])
     id = _getitem_property('options.task_id')
     id = _getitem_property('options.task_id')
+    parent_id = _getitem_property('options.parent_id')
+    root_id = _getitem_property('options.root_id')
     task = _getitem_property('task')
     task = _getitem_property('task')
     args = _getitem_property('args')
     args = _getitem_property('args')
     kwargs = _getitem_property('kwargs')
     kwargs = _getitem_property('kwargs')
@@ -399,8 +408,8 @@ class chain(Signature):
             dict(self.options, **options) if options else self.options))
             dict(self.options, **options) if options else self.options))
 
 
     def run(self, args=(), kwargs={}, group_id=None, chord=None,
     def run(self, args=(), kwargs={}, group_id=None, chord=None,
-            task_id=None, link=None, link_error=None,
-            publisher=None, producer=None, root_id=None, app=None, **options):
+            task_id=None, link=None, link_error=None, publisher=None,
+            producer=None, root_id=None, parent_id=None, app=None, **options):
         app = app or self.app
         app = app or self.app
         use_link = self._use_link
         use_link = self._use_link
         args = (tuple(args) + tuple(self.args)
         args = (tuple(args) + tuple(self.args)
@@ -410,7 +419,7 @@ class chain(Signature):
             tasks, results = self._frozen
             tasks, results = self._frozen
         else:
         else:
             tasks, results = self.prepare_steps(
             tasks, results = self.prepare_steps(
-                args, self.tasks, root_id, link_error, app,
+                args, self.tasks, root_id, parent_id, link_error, app,
                 task_id, group_id, chord,
                 task_id, group_id, chord,
             )
             )
 
 
@@ -422,15 +431,16 @@ class chain(Signature):
                 chain=tasks if not use_link else None, **options)
                 chain=tasks if not use_link else None, **options)
             return results[0]
             return results[0]
 
 
-    def freeze(self, _id=None, group_id=None, chord=None, root_id=None):
+    def freeze(self, _id=None, group_id=None, chord=None,
+               root_id=None, parent_id=None):
         _, results = self._frozen = self.prepare_steps(
         _, results = self._frozen = self.prepare_steps(
-            self.args, self.tasks, root_id, None,
+            self.args, self.tasks, root_id, parent_id, None,
             self.app, _id, group_id, chord, clone=False,
             self.app, _id, group_id, chord, clone=False,
         )
         )
         return results[-1]
         return results[-1]
 
 
     def prepare_steps(self, args, tasks,
     def prepare_steps(self, args, tasks,
-                      root_id=None, link_error=None, app=None,
+                      root_id=None, parent_id=None, link_error=None, app=None,
                       last_task_id=None, group_id=None, chord_body=None,
                       last_task_id=None, group_id=None, chord_body=None,
                       clone=True, from_dict=Signature.from_dict):
                       clone=True, from_dict=Signature.from_dict):
         app = app or self.app
         app = app or self.app
@@ -447,7 +457,8 @@ class chain(Signature):
         steps_pop = steps.pop
         steps_pop = steps.pop
         steps_extend = steps.extend
         steps_extend = steps.extend
 
 
-        next_step = prev_task = prev_res = None
+        next_step = prev_task = prev_prev_task = None
+        prev_res = prev_prev_res = None
         tasks, results = [], []
         tasks, results = [], []
         i = 0
         i = 0
         while steps:
         while steps:
@@ -469,21 +480,18 @@ class chain(Signature):
                 # splice the chain
                 # splice the chain
                 steps_extend(task.tasks)
                 steps_extend(task.tasks)
                 continue
                 continue
-            elif isinstance(task, group):
-                if prev_task:
-                    # automatically upgrade group(...) | s to chord(group, s)
-                    try:
-                        next_step = prev_task
-                        # for chords we freeze by pretending it's a normal
-                        # signature instead of a group.
-                        res = Signature.freeze(next_step, root_id=root_id)
-                        task = chord(
-                            task, body=next_step,
-                            task_id=res.task_id, root_id=root_id,
-                        )
-                    except IndexError:
-                        pass  # no callback, so keep as group.
 
 
+            if isinstance(task, group) and prev_task:
+                # automatically upgrade group(...) | s to chord(group, s)
+                # for chords we freeze by pretending it's a normal
+                # signature instead of a group.
+                tasks.pop()
+                results.pop()
+                prev_res = prev_prev_res
+                task = chord(
+                    task, body=prev_task,
+                    task_id=res.task_id, root_id=root_id, app=app,
+                )
             if is_last_task:
             if is_last_task:
                 # chain(task_id=id) means task id is set for the last task
                 # chain(task_id=id) means task id is set for the last task
                 # in the chain.  If the chord is part of a chord/group
                 # in the chain.  If the chord is part of a chord/group
@@ -496,26 +504,36 @@ class chain(Signature):
                 )
                 )
             else:
             else:
                 res = task.freeze(root_id=root_id)
                 res = task.freeze(root_id=root_id)
-            root_id = res.id if root_id is None else root_id
+
             i += 1
             i += 1
 
 
             if prev_task:
             if prev_task:
+                prev_task.set_parent_id(task.id)
                 if use_link:
                 if use_link:
                     # link previous task to this task.
                     # link previous task to this task.
                     task.link(prev_task)
                     task.link(prev_task)
-                    if not res.parent:
+                    if not res.parent and prev_res:
                         prev_res.parent = res.parent
                         prev_res.parent = res.parent
-                else:
+                elif prev_res:
                     prev_res.parent = res
                     prev_res.parent = res
 
 
+            if is_first_task and parent_id is not None:
+                task.set_parent_id(parent_id)
+
             if link_error:
             if link_error:
                 task.set(link_error=link_error)
                 task.set(link_error=link_error)
 
 
             tasks.append(task)
             tasks.append(task)
             results.append(res)
             results.append(res)
 
 
-            prev_task, prev_res = task, res
+            prev_prev_task, prev_task, prev_prev_res, prev_res = (
+                prev_task, task, prev_res, res,
+            )
 
 
+        if root_id is None and tasks:
+            root_id = tasks[-1].id
+            for task in reversed(tasks):
+                task.options['root_id'] = root_id
         return tasks, results
         return tasks, results
 
 
     def apply(self, args=(), kwargs={}, **options):
     def apply(self, args=(), kwargs={}, **options):
@@ -634,13 +652,16 @@ class chunks(Signature):
         return cls(task, it, n, app=app)()
         return cls(task, it, n, app=app)()
 
 
 
 
-def _maybe_group(tasks):
+def _maybe_group(tasks, app):
+    if isinstance(tasks, dict):
+        tasks = signature(tasks, app=app)
+
     if isinstance(tasks, group):
     if isinstance(tasks, group):
-        tasks = list(tasks.tasks)
+        tasks = tasks.tasks
     elif isinstance(tasks, abstract.CallableSignature):
     elif isinstance(tasks, abstract.CallableSignature):
         tasks = [tasks]
         tasks = [tasks]
     else:
     else:
-        tasks = [signature(t) for t in regen(tasks)]
+        tasks = [signature(t, app=app) for t in regen(tasks)]
     return tasks
     return tasks
 
 
 
 
@@ -649,8 +670,9 @@ class group(Signature):
     tasks = _getitem_property('kwargs.tasks')
     tasks = _getitem_property('kwargs.tasks')
 
 
     def __init__(self, *tasks, **options):
     def __init__(self, *tasks, **options):
+        app = options.get('app')
         if len(tasks) == 1:
         if len(tasks) == 1:
-            tasks = _maybe_group(tasks[0])
+            tasks = _maybe_group(tasks[0], app)
         Signature.__init__(
         Signature.__init__(
             self, 'celery.group', (), {'tasks': tasks}, **options
             self, 'celery.group', (), {'tasks': tasks}, **options
         )
         )
@@ -662,6 +684,9 @@ class group(Signature):
             d, group(d['kwargs']['tasks'], app=app, **d['options']),
             d, group(d['kwargs']['tasks'], app=app, **d['options']),
         )
         )
 
 
+    def __len__(self):
+        return len(self.tasks)
+
     def _prepared(self, tasks, partial_args, group_id, root_id, app, dict=dict,
     def _prepared(self, tasks, partial_args, group_id, root_id, app, dict=dict,
                   CallableSignature=abstract.CallableSignature,
                   CallableSignature=abstract.CallableSignature,
                   from_dict=Signature.from_dict):
                   from_dict=Signature.from_dict):
@@ -703,6 +728,10 @@ class group(Signature):
             options.pop('task_id', uuid()))
             options.pop('task_id', uuid()))
         return options, group_id, options.get('root_id')
         return options, group_id, options.get('root_id')
 
 
+    def set_parent_id(self, parent_id):
+        for task in self.tasks:
+            task.set_parent_id(parent_id)
+
     def apply_async(self, args=(), kwargs=None, add_to_parent=True,
     def apply_async(self, args=(), kwargs=None, add_to_parent=True,
                     producer=None, **options):
                     producer=None, **options):
         app = self.app
         app = self.app
@@ -757,7 +786,7 @@ class group(Signature):
     def __call__(self, *partial_args, **options):
     def __call__(self, *partial_args, **options):
         return self.apply_async(partial_args, **options)
         return self.apply_async(partial_args, **options)
 
 
-    def _freeze_unroll(self, new_tasks, group_id, chord, root_id):
+    def _freeze_unroll(self, new_tasks, group_id, chord, root_id, parent_id):
         stack = deque(self.tasks)
         stack = deque(self.tasks)
         while stack:
         while stack:
             task = maybe_signature(stack.popleft(), app=self._app).clone()
             task = maybe_signature(stack.popleft(), app=self._app).clone()
@@ -766,9 +795,11 @@ class group(Signature):
             else:
             else:
                 new_tasks.append(task)
                 new_tasks.append(task)
                 yield task.freeze(group_id=group_id,
                 yield task.freeze(group_id=group_id,
-                                  chord=chord, root_id=root_id)
+                                  chord=chord, root_id=root_id,
+                                  parent_id=parent_id)
 
 
-    def freeze(self, _id=None, group_id=None, chord=None, root_id=None):
+    def freeze(self, _id=None, group_id=None, chord=None,
+               root_id=None, parent_id=None):
         opts = self.options
         opts = self.options
         try:
         try:
             gid = opts['task_id']
             gid = opts['task_id']
@@ -779,11 +810,12 @@ class group(Signature):
         if chord:
         if chord:
             opts['chord'] = chord
             opts['chord'] = chord
         root_id = opts.setdefault('root_id', root_id)
         root_id = opts.setdefault('root_id', root_id)
+        parent_id = opts.setdefault('parent_id', parent_id)
         new_tasks = []
         new_tasks = []
         # Need to unroll subgroups early so that chord gets the
         # Need to unroll subgroups early so that chord gets the
         # right result instance for chord_unlock etc.
         # right result instance for chord_unlock etc.
         results = list(self._freeze_unroll(
         results = list(self._freeze_unroll(
-            new_tasks, group_id, chord, root_id,
+            new_tasks, group_id, chord, root_id, parent_id,
         ))
         ))
         if isinstance(self.tasks, MutableSequence):
         if isinstance(self.tasks, MutableSequence):
             self.tasks[:] = new_tasks
             self.tasks[:] = new_tasks
@@ -819,16 +851,29 @@ class group(Signature):
 class chord(Signature):
 class chord(Signature):
 
 
     def __init__(self, header, body=None, task='celery.chord',
     def __init__(self, header, body=None, task='celery.chord',
-                 args=(), kwargs={}, **options):
+                 args=(), kwargs={}, app=None, **options):
         Signature.__init__(
         Signature.__init__(
             self, task, args,
             self, task, args,
-            dict(kwargs, header=_maybe_group(header),
+            dict(kwargs, header=_maybe_group(header, app),
                  body=maybe_signature(body, app=self._app)), **options
                  body=maybe_signature(body, app=self._app)), **options
         )
         )
         self.subtask_type = 'chord'
         self.subtask_type = 'chord'
 
 
-    def freeze(self, *args, **kwargs):
-        return self.body.freeze(*args, **kwargs)
+    def freeze(self, _id=None, group_id=None, chord=None,
+               root_id=None, parent_id=None):
+        if not isinstance(self.tasks, group):
+            self.tasks = group(self.tasks)
+        self.tasks.freeze(parent_id=parent_id, root_id=root_id)
+        self.id = self.tasks.id
+        return self.body.freeze(_id, parent_id=self.id, root_id=root_id)
+
+    def set_parent_id(self, parent_id):
+        tasks = self.tasks
+        if isinstance(tasks, group):
+            tasks = tasks.tasks
+        for task in tasks:
+            task.set_parent_id(parent_id)
+        self.parent_id = parent_id
 
 
     @classmethod
     @classmethod
     def from_dict(self, d, app=None):
     def from_dict(self, d, app=None):
@@ -848,7 +893,11 @@ class chord(Signature):
     def _get_app(self, body=None):
     def _get_app(self, body=None):
         app = self._app
         app = self._app
         if app is None:
         if app is None:
-            app = self.tasks[0]._app
+            try:
+                tasks = self.tasks.tasks  # is a group
+            except AttributeError:
+                tasks = self.tasks
+            app = tasks[0]._app
             if app is None and body is not None:
             if app is None and body is not None:
                 app = body._app
                 app = body._app
         return app if app is not None else current_app
         return app if app is not None else current_app
@@ -900,6 +949,7 @@ class chord(Signature):
         body.chord_size = self.__length_hint__()
         body.chord_size = self.__length_hint__()
         options = dict(self.options, **options) if options else self.options
         options = dict(self.options, **options) if options else self.options
         if options:
         if options:
+            options.pop('task_id', None)
             body.options.update(options)
             body.options.update(options)
 
 
         results = header.freeze(
         results = header.freeze(

+ 18 - 9
celery/events/state.py

@@ -233,11 +233,13 @@ class Task(object):
     state = states.PENDING
     state = states.PENDING
     clock = 0
     clock = 0
 
 
-    _fields = ('uuid', 'name', 'state', 'received', 'sent', 'started',
-               'succeeded', 'failed', 'retried', 'revoked', 'args', 'kwargs',
-               'eta', 'expires', 'retries', 'worker', 'result', 'exception',
-               'timestamp', 'runtime', 'traceback', 'exchange', 'routing_key',
-               'clock', 'client')
+    _fields = (
+        'uuid', 'name', 'state', 'received', 'sent', 'started',
+        'succeeded', 'failed', 'retried', 'revoked', 'args', 'kwargs',
+        'eta', 'expires', 'retries', 'worker', 'result', 'exception',
+        'timestamp', 'runtime', 'traceback', 'exchange', 'routing_key',
+        'clock', 'client', 'root_id', 'parent_id',
+    )
     if not PYPY:
     if not PYPY:
         __slots__ = ('__dict__', '__weakref__')
         __slots__ = ('__dict__', '__weakref__')
 
 
@@ -249,12 +251,19 @@ class Task(object):
     #: that state. ``(RECEIVED, ('name', 'args')``, means the name and args
     #: that state. ``(RECEIVED, ('name', 'args')``, means the name and args
     #: fields are always taken from the RECEIVED state, and any values for
     #: fields are always taken from the RECEIVED state, and any values for
     #: these fields received before or after is simply ignored.
     #: these fields received before or after is simply ignored.
-    merge_rules = {states.RECEIVED: ('name', 'args', 'kwargs',
-                                     'retries', 'eta', 'expires')}
+    merge_rules = {
+        states.RECEIVED: (
+            'name', 'args', 'kwargs', 'parent_id',
+            'root_id' 'retries', 'eta', 'expires',
+        ),
+    }
 
 
     #: meth:`info` displays these fields by default.
     #: meth:`info` displays these fields by default.
-    _info_fields = ('args', 'kwargs', 'retries', 'result', 'eta', 'runtime',
-                    'expires', 'exception', 'exchange', 'routing_key')
+    _info_fields = (
+        'args', 'kwargs', 'retries', 'result', 'eta', 'runtime',
+        'expires', 'exception', 'exchange', 'routing_key',
+        'root_id', 'parent_id',
+    )
 
 
     def __init__(self, uuid=None, **kwargs):
     def __init__(self, uuid=None, **kwargs):
         self.uuid = uuid
         self.uuid = uuid

+ 3 - 1
celery/result.py

@@ -122,7 +122,7 @@ class AsyncResult(ResultBase):
                                 reply=wait, timeout=timeout)
                                 reply=wait, timeout=timeout)
 
 
     def get(self, timeout=None, propagate=True, interval=0.5,
     def get(self, timeout=None, propagate=True, interval=0.5,
-            no_ack=True, follow_parents=True,
+            no_ack=True, follow_parents=True, callback=None,
             EXCEPTION_STATES=states.EXCEPTION_STATES,
             EXCEPTION_STATES=states.EXCEPTION_STATES,
             PROPAGATE_STATES=states.PROPAGATE_STATES):
             PROPAGATE_STATES=states.PROPAGATE_STATES):
         """Wait until task is ready, and return its result.
         """Wait until task is ready, and return its result.
@@ -174,6 +174,8 @@ class AsyncResult(ResultBase):
             status = meta['status']
             status = meta['status']
             if status in PROPAGATE_STATES and propagate:
             if status in PROPAGATE_STATES and propagate:
                 raise meta['result']
                 raise meta['result']
+            if callback is not None:
+                callback(self.id, meta['result'])
             return meta['result']
             return meta['result']
     wait = get  # deprecated alias to :meth:`get`.
     wait = get  # deprecated alias to :meth:`get`.
 
 

+ 58 - 4
celery/tests/app/test_builtins.py

@@ -133,18 +133,72 @@ class test_chain(BuiltinsCase):
         self.assertTrue(result.parent.parent)
         self.assertTrue(result.parent.parent)
         self.assertIsNone(result.parent.parent.parent)
         self.assertIsNone(result.parent.parent.parent)
 
 
+    def test_group_to_chord__freeze_parent_id(self):
+        def using_freeze(c):
+            c.freeze(parent_id='foo', root_id='root')
+            return c._frozen[0]
+        self.assert_group_to_chord_parent_ids(using_freeze)
+
+    def assert_group_to_chord_parent_ids(self, freezefun):
+        c = (
+            self.add.s(5, 5) |
+            group([self.add.s(i, i) for i in range(5)], app=self.app) |
+            self.add.si(10, 10) |
+            self.add.si(20, 20) |
+            self.add.si(30, 30)
+        )
+        tasks = freezefun(c)
+        self.assertEqual(tasks[-1].parent_id, 'foo')
+        self.assertEqual(tasks[-1].root_id, 'root')
+        self.assertEqual(tasks[-2].parent_id, tasks[-1].id)
+        self.assertEqual(tasks[-2].root_id, 'root')
+        self.assertEqual(tasks[-2].body.parent_id, tasks[-2].tasks.id)
+        self.assertEqual(tasks[-2].body.parent_id, tasks[-2].id)
+        self.assertEqual(tasks[-2].body.root_id, 'root')
+        self.assertEqual(tasks[-2].tasks.tasks[0].parent_id, tasks[-1].id)
+        self.assertEqual(tasks[-2].tasks.tasks[0].root_id, 'root')
+        self.assertEqual(tasks[-2].tasks.tasks[1].parent_id, tasks[-1].id)
+        self.assertEqual(tasks[-2].tasks.tasks[1].root_id, 'root')
+        self.assertEqual(tasks[-2].tasks.tasks[2].parent_id, tasks[-1].id)
+        self.assertEqual(tasks[-2].tasks.tasks[2].root_id, 'root')
+        self.assertEqual(tasks[-2].tasks.tasks[3].parent_id, tasks[-1].id)
+        self.assertEqual(tasks[-2].tasks.tasks[3].root_id, 'root')
+        self.assertEqual(tasks[-2].tasks.tasks[4].parent_id, tasks[-1].id)
+        self.assertEqual(tasks[-2].tasks.tasks[4].root_id, 'root')
+        self.assertEqual(tasks[-3].parent_id, tasks[-2].body.id)
+        self.assertEqual(tasks[-3].root_id, 'root')
+        self.assertEqual(tasks[-4].parent_id, tasks[-3].id)
+        self.assertEqual(tasks[-4].root_id, 'root')
+
     def test_group_to_chord(self):
     def test_group_to_chord(self):
         c = (
         c = (
+            self.add.s(5) |
             group([self.add.s(i, i) for i in range(5)], app=self.app) |
             group([self.add.s(i, i) for i in range(5)], app=self.app) |
             self.add.s(10) |
             self.add.s(10) |
             self.add.s(20) |
             self.add.s(20) |
             self.add.s(30)
             self.add.s(30)
         )
         )
         c._use_link = True
         c._use_link = True
-        tasks, _ = c.prepare_steps((), c.tasks)
-        self.assertIsInstance(tasks[-1], chord)
-        self.assertTrue(tasks[-1].body.options['link'])
-        self.assertTrue(tasks[-1].body.options['link'][0].options['link'])
+        tasks, results = c.prepare_steps((), c.tasks)
+
+        self.assertEqual(tasks[-1].args[0], 5)
+        self.assertIsInstance(tasks[-2], chord)
+        self.assertEqual(len(tasks[-2].tasks), 5)
+        self.assertEqual(tasks[-2].parent_id, tasks[-1].id)
+        self.assertEqual(tasks[-2].root_id, tasks[-1].id)
+        self.assertEqual(tasks[-2].body.args[0], 10)
+        self.assertEqual(tasks[-2].body.parent_id, tasks[-2].id)
+
+        self.assertEqual(tasks[-3].args[0], 20)
+        self.assertEqual(tasks[-3].root_id, tasks[-1].id)
+        self.assertEqual(tasks[-3].parent_id, tasks[-2].body.id)
+
+        self.assertEqual(tasks[-4].args[0], 30)
+        self.assertEqual(tasks[-4].parent_id, tasks[-3].id)
+        self.assertEqual(tasks[-4].root_id, tasks[-1].id)
+
+        self.assertTrue(tasks[-2].body.options['link'])
+        self.assertTrue(tasks[-2].body.options['link'][0].options['link'])
 
 
         c2 = self.add.s(2, 2) | group(self.add.s(i, i) for i in range(10))
         c2 = self.add.s(2, 2) | group(self.add.s(i, i) for i in range(10))
         c2._use_link = True
         c2._use_link = True

+ 35 - 1
celery/tests/tasks/test_canvas.py

@@ -14,7 +14,7 @@ from celery.canvas import (
 )
 )
 from celery.result import EagerResult
 from celery.result import EagerResult
 
 
-from celery.tests.case import AppCase, Mock
+from celery.tests.case import AppCase, ContextMock, Mock
 
 
 SIG = Signature({'task': 'TASK',
 SIG = Signature({'task': 'TASK',
                  'args': ('A1',),
                  'args': ('A1',),
@@ -233,6 +233,40 @@ class test_chain(CanvasCase):
         self.assertIsNone(chain(app=self.app)())
         self.assertIsNone(chain(app=self.app)())
         self.assertIsNone(chain(app=self.app).apply_async())
         self.assertIsNone(chain(app=self.app).apply_async())
 
 
+    def test_root_id_parent_id(self):
+        self.app.conf.task_protocol = 2
+        c = chain(self.add.si(i, i) for i in range(4))
+        c.freeze()
+        tasks, _ = c._frozen
+        for i, task in enumerate(tasks):
+            self.assertEqual(task.root_id, tasks[-1].id)
+            try:
+                self.assertEqual(task.parent_id, tasks[i + 1].id)
+            except IndexError:
+                assert i == len(tasks) - 1
+            else:
+                valid_parents = i
+        self.assertEqual(valid_parents, len(tasks) - 2)
+
+        self.assert_sent_with_ids(tasks[-1], tasks[-1].id, 'foo',
+                                  parent_id='foo')
+        self.assertTrue(tasks[-2].options['parent_id'])
+        self.assert_sent_with_ids(tasks[-2], tasks[-1].id, tasks[-1].id)
+        self.assert_sent_with_ids(tasks[-3], tasks[-1].id, tasks[-2].id)
+        self.assert_sent_with_ids(tasks[-4], tasks[-1].id, tasks[-3].id)
+
+
+    def assert_sent_with_ids(self, task, rid, pid, **options):
+        self.app.amqp.send_task_message = Mock(name='send_task_message')
+        self.app.backend = Mock()
+        self.app.producer_or_acquire = ContextMock()
+
+        res = task.apply_async(**options)
+        self.assertTrue(self.app.amqp.send_task_message.called)
+        message = self.app.amqp.send_task_message.call_args[0][2]
+        self.assertEqual(message.headers['parent_id'], pid)
+        self.assertEqual(message.headers['root_id'], rid)
+
     def test_call_no_tasks(self):
     def test_call_no_tasks(self):
         x = chain()
         x = chain()
         self.assertFalse(x())
         self.assertFalse(x())

+ 0 - 1
celery/worker/consumer.py

@@ -458,7 +458,6 @@ class Consumer(object):
         callbacks = self.on_task_message
         callbacks = self.on_task_message
 
 
         def on_task_received(message):
         def on_task_received(message):
-
             # payload will only be set for v1 protocol, since v2
             # payload will only be set for v1 protocol, since v2
             # will defer deserializing the message body to the pool.
             # will defer deserializing the message body to the pool.
             payload = None
             payload = None

+ 5 - 3
celery/worker/request.py

@@ -77,9 +77,9 @@ class Request(object):
 
 
     if not IS_PYPY:  # pragma: no cover
     if not IS_PYPY:  # pragma: no cover
         __slots__ = (
         __slots__ = (
-            'app', 'type', 'name', 'id', 'on_ack', 'body',
-            'hostname', 'eventer', 'connection_errors', 'task', 'eta',
-            'expires', 'request_dict', 'on_reject', 'utc',
+            'app', 'type', 'name', 'id', 'root_id', 'parent_id',
+            'on_ack', 'body', 'hostname', 'eventer', 'connection_errors',
+            'task', 'eta', 'expires', 'request_dict', 'on_reject', 'utc',
             'content_type', 'content_encoding', 'argsrepr', 'kwargsrepr',
             'content_type', 'content_encoding', 'argsrepr', 'kwargsrepr',
             '__weakref__', '__dict__',
             '__weakref__', '__dict__',
         )
         )
@@ -108,6 +108,8 @@ class Request(object):
 
 
         self.id = headers['id']
         self.id = headers['id']
         type = self.type = self.name = headers['task']
         type = self.type = self.name = headers['task']
+        self.root_id = headers.get('root_id')
+        self.parent_id = headers.get('parent_id')
         if 'shadow' in headers:
         if 'shadow' in headers:
             self.name = headers['shadow']
             self.name = headers['shadow']
         if 'timelimit' in headers:
         if 'timelimit' in headers:

+ 2 - 2
docs/userguide/monitoring.rst

@@ -650,7 +650,7 @@ task-sent
 ~~~~~~~~~
 ~~~~~~~~~
 
 
 :signature: ``task-sent(uuid, name, args, kwargs, retries, eta, expires,
 :signature: ``task-sent(uuid, name, args, kwargs, retries, eta, expires,
-              queue, exchange, routing_key)``
+              queue, exchange, routing_key, root_id, parent_id)``
 
 
 Sent when a task message is published and
 Sent when a task message is published and
 the :setting:`task_send_sent_event` setting is enabled.
 the :setting:`task_send_sent_event` setting is enabled.
@@ -661,7 +661,7 @@ task-received
 ~~~~~~~~~~~~~
 ~~~~~~~~~~~~~
 
 
 :signature: ``task-received(uuid, name, args, kwargs, retries, eta, hostname,
 :signature: ``task-received(uuid, name, args, kwargs, retries, eta, hostname,
-              timestamp)``
+              timestamp, root_id, parent_id)``
 
 
 Sent when the worker receives a task.
 Sent when the worker receives a task.
 
 

+ 10 - 0
funtests/stress/stress/app.py

@@ -63,6 +63,16 @@ def add(x, y):
     return x + y
     return x + y
 
 
 
 
+@app.task(bind=True)
+def ids(self, i):
+    return (self.request.root_id, self.request.parent_id, i)
+
+
+@app.task(bind=True)
+def collect_ids(self, ids, i):
+    return ids, (self.request.root_id, self.request.parent_id, i)
+
+
 @app.task
 @app.task
 def xsum(x):
 def xsum(x):
     return sum(x)
     return sum(x)

+ 107 - 4
funtests/stress/stress/suite.py

@@ -10,7 +10,7 @@ from collections import OrderedDict, defaultdict, namedtuple
 from itertools import count
 from itertools import count
 from time import sleep
 from time import sleep
 
 
-from celery import group, VERSION_BANNER
+from celery import VERSION_BANNER, chain, group, uuid
 from celery.exceptions import TimeoutError
 from celery.exceptions import TimeoutError
 from celery.five import items, monotonic, range, values
 from celery.five import items, monotonic, range, values
 from celery.utils.debug import blockdetection
 from celery.utils.debug import blockdetection
@@ -18,12 +18,13 @@ from celery.utils.text import pluralize, truncate
 from celery.utils.timeutils import humanize_seconds
 from celery.utils.timeutils import humanize_seconds
 
 
 from .app import (
 from .app import (
-    marker, _marker, add, any_, exiting, kill, sleeping,
+    marker, _marker, add, any_, collect_ids, exiting, ids, kill, sleeping,
     sleeping_ignore_limits, any_returning, print_unicode,
     sleeping_ignore_limits, any_returning, print_unicode,
 )
 )
 from .data import BIG, SMALL
 from .data import BIG, SMALL
 from .fbi import FBI
 from .fbi import FBI
 
 
+
 BANNER = """\
 BANNER = """\
 Celery stress-suite v{version}
 Celery stress-suite v{version}
 
 
@@ -50,6 +51,10 @@ Progress = namedtuple('Progress', (
 Inf = float('Inf')
 Inf = float('Inf')
 
 
 
 
+def assert_equal(a, b):
+    assert a == b, '{0!r} != {1!r}'.format(a, b)
+
+
 class StopSuite(Exception):
 class StopSuite(Exception):
     pass
     pass
 
 
@@ -163,6 +168,7 @@ class BaseSuite(object):
         )
         )
 
 
     def runtest(self, fun, n=50, index=0, repeats=1):
     def runtest(self, fun, n=50, index=0, repeats=1):
+        n = getattr(fun, '__iterations__', None) or n
         print('{0}: [[[{1}({2})]]]'.format(repeats, fun.__name__, n))
         print('{0}: [[[{1}({2})]]]'.format(repeats, fun.__name__, n))
         with blockdetection(self.block_timeout):
         with blockdetection(self.block_timeout):
             with self.fbi.investigation():
             with self.fbi.investigation():
@@ -185,6 +191,8 @@ class BaseSuite(object):
                             raise
                             raise
                         except Exception as exc:
                         except Exception as exc:
                             print('-> {0!r}'.format(exc))
                             print('-> {0!r}'.format(exc))
+                            import traceback
+                            print(traceback.format_exc())
                             print(pstatus(self.progress))
                             print(pstatus(self.progress))
                         else:
                         else:
                             print(pstatus(self.progress))
                             print(pstatus(self.progress))
@@ -238,13 +246,14 @@ class BaseSuite(object):
 _creation_counter = count(0)
 _creation_counter = count(0)
 
 
 
 
-def testcase(*groups):
+def testcase(*groups, **kwargs):
     if not groups:
     if not groups:
         raise ValueError('@testcase requires at least one group name')
         raise ValueError('@testcase requires at least one group name')
 
 
     def _mark_as_case(fun):
     def _mark_as_case(fun):
         fun.__testgroup__ = groups
         fun.__testgroup__ = groups
         fun.__testsort__ = next(_creation_counter)
         fun.__testsort__ = next(_creation_counter)
+        fun.__iterations__ = kwargs.get('iterations')
         return fun
         return fun
 
 
     return _mark_as_case
     return _mark_as_case
@@ -262,12 +271,106 @@ def _is_descriptor(obj, attr):
 
 
 class Suite(BaseSuite):
 class Suite(BaseSuite):
 
 
+    @testcase('all', 'green', iterations=1)
+    def chain(self):
+        c = add.s(4, 4) | add.s(8) | add.s(16)
+        assert_equal(self.join(c()), 32)
+
+    @testcase('all', 'green', iterations=1)
+    def chaincomplex(self):
+        c = (
+            add.s(2, 2) | (
+                add.s(4) | add.s(8) | add.s(16)
+            ) |
+            group(add.s(i) for i in range(4))
+        )
+        res = c()
+        assert_equal(res.get(), [32, 33, 34, 35])
+
+    @testcase('all', 'green', iterations=1)
+    def parentids_chain(self):
+        c = chain(ids.si(i) for i in range(248))
+        c.freeze()
+        res = c()
+        res.get(timeout=5)
+        self.assert_ids(res, len(c.tasks) - 1)
+
+    @testcase('all', 'green', iterations=1)
+    def parentids_group(self):
+        g = ids.si(1) | ids.si(2) | group(ids.si(i) for i in range(2, 50))
+        res = g()
+        expected_root_id = res.parent.parent.id
+        expected_parent_id = res.parent.id
+        values = res.get(timeout=5)
+
+        for i, r in enumerate(values):
+            root_id, parent_id, value = r
+            assert_equal(root_id, expected_root_id)
+            assert_equal(parent_id, expected_parent_id)
+            assert_equal(value, i + 2)
+
+    def assert_ids(self, res, len):
+        i, root = len, res
+        while root.parent:
+            root = root.parent
+        node = res
+        while node:
+            root_id, parent_id, value = node.get(timeout=5)
+            assert_equal(value, i)
+            assert_equal(root_id, root.id)
+            if node.parent:
+                assert_equal(parent_id, node.parent.id)
+            node = node.parent
+            i -= 1
+
+    @testcase('redis', iterations=1)
+    def parentids_chord(self):
+        self.assert_parentids_chord()
+        self.assert_parentids_chord(uuid(), uuid())
+
+    def assert_parentids_chord(self, base_root=None, base_parent=None):
+        g = (
+            ids.si(1) |
+            ids.si(2) |
+            group(ids.si(i) for i in range(3, 50)) |
+            collect_ids.s(i=50) |
+            ids.si(51)
+        )
+        g.freeze(root_id=base_root, parent_id=base_parent)
+        res = g.apply_async(root_id=base_root, parent_id=base_parent)
+        expected_root_id = base_root or res.parent.parent.parent.id
+
+        root_id, parent_id, value = res.get(timeout=5)
+        assert_equal(value, 51)
+        assert_equal(root_id, expected_root_id)
+        assert_equal(parent_id, res.parent.id)
+
+        prev, (root_id, parent_id, value) = res.parent.get(timeout=5)
+        assert_equal(value, 50)
+        assert_equal(root_id, expected_root_id)
+        assert_equal(parent_id, res.parent.parent.id)
+
+        for i, p in enumerate(prev):
+            root_id, parent_id, value = p
+            assert_equal(root_id, expected_root_id)
+            assert_equal(parent_id, res.parent.parent.id)
+
+        root_id, parent_id, value = res.parent.parent.get(timeout=5)
+        assert_equal(value, 2)
+        assert_equal(parent_id, res.parent.parent.parent.id)
+        assert_equal(root_id, expected_root_id)
+
+        root_id, parent_id, value = res.parent.parent.parent.get(timeout=5)
+        assert_equal(value, 1)
+        assert_equal(root_id, expected_root_id)
+        assert_equal(parent_id, base_parent)
+
     @testcase('all', 'green')
     @testcase('all', 'green')
     def manyshort(self):
     def manyshort(self):
         self.join(group(add.s(i, i) for i in range(1000))(),
         self.join(group(add.s(i, i) for i in range(1000))(),
                   timeout=10, propagate=True)
                   timeout=10, propagate=True)
 
 
-    @testcase('all', 'green')
+    @testcase('all', 'green', iterations=1)
     def unicodetask(self):
     def unicodetask(self):
         self.join(group(print_unicode.s() for _ in range(5))(),
         self.join(group(print_unicode.s() for _ in range(5))(),
                   timeout=1, propagate=True)
                   timeout=1, propagate=True)

+ 1 - 1
funtests/stress/stress/templates.py

@@ -57,7 +57,7 @@ class default(object):
     result_serializer = 'json'
     result_serializer = 'json'
     result_persistent = True
     result_persistent = True
     result_expires = 300
     result_expires = 300
-    result_cache_max = -1
+    result_cache_max = 100
     task_default_queue = CSTRESS_QUEUE
     task_default_queue = CSTRESS_QUEUE
     task_queues = [
     task_queues = [
         Queue(CSTRESS_QUEUE,
         Queue(CSTRESS_QUEUE,