Browse Source

Adds parent_id + root_id task message fields, and to events. Closes #1318

Ask Solem 9 năm trước cách đây
mục cha
commit
6066a45700

+ 2 - 2
celery/app/amqp.py

@@ -360,8 +360,8 @@ class AMQP(object):
             ),
             sent_event={
                 'uuid': task_id,
-                'root': root_id,
-                'parent': parent_id,
+                'root_id': root_id,
+                'parent_id': parent_id,
                 'name': name,
                 'args': argsrepr,
                 'kwargs': kwargsrepr,

+ 13 - 1
celery/app/base.py

@@ -622,6 +622,7 @@ class Celery(object):
         Otherwise supports the same arguments as :meth:`@-Task.apply_async`.
 
         """
+        parent = have_parent = None
         amqp = self.amqp
         task_id = task_id or uuid()
         producer = producer or publisher  # XXX compat
@@ -633,6 +634,16 @@ class Celery(object):
             ), stacklevel=2)
         options = router.route(options, route_name or name, args, kwargs)
 
+        if root_id is None:
+            parent, have_parent = get_current_worker_task(), True
+            if parent:
+                root_id = parent.request.root_id or parent.request.id
+        if parent_id is None:
+            if not have_parent:
+                parent, have_parent = get_current_worker_task(), True
+            if parent:
+                parent_id = parent.request.id
+
         message = amqp.create_task_message(
             task_id, name, args, kwargs, countdown, eta, group_id,
             expires, retries, chord,
@@ -649,7 +660,8 @@ class Celery(object):
             amqp.send_task_message(P, name, message, **options)
         result = (result_cls or self.AsyncResult)(task_id)
         if add_to_parent:
-            parent = get_current_worker_task()
+            if not have_parent:
+                parent, have_parent = get_current_worker_task(), True
             if parent:
                 parent.add_trail(result)
         return result

+ 19 - 8
celery/app/trace.py

@@ -306,10 +306,11 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
         I = Info(state, exc)
         R = I.handle_error_state(task, request, eager=eager)
         if call_errbacks:
+            root_id = request.root_id or uuid
             group(
                 [signature(errback, app=app)
                  for errback in request.errbacks or []], app=app,
-            ).apply_async((uuid,))
+            ).apply_async((uuid,), parent_id=uuid, root_id=root_id)
         return I, R, I.state, I.retval
 
     def trace_task(uuid, args, kwargs, request=None):
@@ -336,6 +337,7 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
             push_task(task)
             task_request = Context(request or {}, args=args,
                                    called_directly=False, kwargs=kwargs)
+            root_id = task_request.root_id or uuid
             push_request(task_request)
             try:
                 # -*- PRE -*-
@@ -363,8 +365,7 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
                     I.handle_ignore(task, task_request)
                 except Retry as exc:
                     I, R, state, retval = on_error(
-                        task_request, exc, uuid, RETRY, call_errbacks=False,
-                    )
+                        task_request, exc, uuid, RETRY, call_errbacks=False)
                 except Exception as exc:
                     I, R, state, retval = on_error(task_request, exc, uuid)
                 except BaseException as exc:
@@ -389,17 +390,27 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
                                     else:
                                         sigs.append(sig)
                                 for group_ in groups:
-                                    group.apply_async((retval,))
+                                    group.apply_async(
+                                        (retval,),
+                                        parent_id=uuid, root_id=root_id,
+                                    )
                                 if sigs:
-                                    group(sigs).apply_async((retval,))
+                                    group(sigs).apply_async(
+                                        (retval,),
+                                        parent_id=uuid, root_id=root_id,
+                                    )
                             else:
-                                signature(callbacks[0], app=app).delay(retval)
+                                signature(callbacks[0], app=app).apply_async(
+                                    (retval,), parent_id=uuid, root_id=root_id,
+                                )
 
                         # execute first task in chain
-                        chain = task.request.chain
+                        chain = task_request.chain
                         if chain:
                             signature(chain.pop(), app=app).apply_async(
-                                    (retval,), chain=chain)
+                                (retval,), chain=chain,
+                                parent_id=uuid, root_id=root_id,
+                            )
                         mark_as_done(
                             uuid, retval, task_request, publish_result,
                         )

+ 90 - 40
celery/canvas.py

@@ -216,13 +216,17 @@ class Signature(dict):
         return s
     partial = clone
 
-    def freeze(self, _id=None, group_id=None, chord=None, root_id=None):
+    def freeze(self, _id=None, group_id=None, chord=None,
+               root_id=None, parent_id=None):
         opts = self.options
         try:
             tid = opts['task_id']
         except KeyError:
             tid = opts['task_id'] = _id or uuid()
-        root_id = opts.setdefault('root_id', root_id)
+        if root_id:
+            opts['root_id'] = root_id
+        if parent_id:
+            opts['parent_id'] = parent_id
         if 'reply_to' not in opts:
             opts['reply_to'] = self.app.oid
         if group_id:
@@ -251,6 +255,9 @@ class Signature(dict):
     def set_immutable(self, immutable):
         self.immutable = immutable
 
+    def set_parent_id(self, parent_id):
+        self.parent_id = parent_id
+
     def apply_async(self, args=(), kwargs={}, route_name=None, **options):
         try:
             _apply = self._apply_async
@@ -362,6 +369,8 @@ class Signature(dict):
         except KeyError:
             return _partial(self.app.send_task, self['task'])
     id = _getitem_property('options.task_id')
+    parent_id = _getitem_property('options.parent_id')
+    root_id = _getitem_property('options.root_id')
     task = _getitem_property('task')
     args = _getitem_property('args')
     kwargs = _getitem_property('kwargs')
@@ -399,8 +408,8 @@ class chain(Signature):
             dict(self.options, **options) if options else self.options))
 
     def run(self, args=(), kwargs={}, group_id=None, chord=None,
-            task_id=None, link=None, link_error=None,
-            publisher=None, producer=None, root_id=None, app=None, **options):
+            task_id=None, link=None, link_error=None, publisher=None,
+            producer=None, root_id=None, parent_id=None, app=None, **options):
         app = app or self.app
         use_link = self._use_link
         args = (tuple(args) + tuple(self.args)
@@ -410,7 +419,7 @@ class chain(Signature):
             tasks, results = self._frozen
         else:
             tasks, results = self.prepare_steps(
-                args, self.tasks, root_id, link_error, app,
+                args, self.tasks, root_id, parent_id, link_error, app,
                 task_id, group_id, chord,
             )
 
@@ -422,15 +431,16 @@ class chain(Signature):
                 chain=tasks if not use_link else None, **options)
             return results[0]
 
-    def freeze(self, _id=None, group_id=None, chord=None, root_id=None):
+    def freeze(self, _id=None, group_id=None, chord=None,
+               root_id=None, parent_id=None):
         _, results = self._frozen = self.prepare_steps(
-            self.args, self.tasks, root_id, None,
+            self.args, self.tasks, root_id, parent_id, None,
             self.app, _id, group_id, chord, clone=False,
         )
         return results[-1]
 
     def prepare_steps(self, args, tasks,
-                      root_id=None, link_error=None, app=None,
+                      root_id=None, parent_id=None, link_error=None, app=None,
                       last_task_id=None, group_id=None, chord_body=None,
                       clone=True, from_dict=Signature.from_dict):
         app = app or self.app
@@ -447,7 +457,8 @@ class chain(Signature):
         steps_pop = steps.pop
         steps_extend = steps.extend
 
-        next_step = prev_task = prev_res = None
+        next_step = prev_task = prev_prev_task = None
+        prev_res = prev_prev_res = None
         tasks, results = [], []
         i = 0
         while steps:
@@ -469,21 +480,18 @@ class chain(Signature):
                 # splice the chain
                 steps_extend(task.tasks)
                 continue
-            elif isinstance(task, group):
-                if prev_task:
-                    # automatically upgrade group(...) | s to chord(group, s)
-                    try:
-                        next_step = prev_task
-                        # for chords we freeze by pretending it's a normal
-                        # signature instead of a group.
-                        res = Signature.freeze(next_step, root_id=root_id)
-                        task = chord(
-                            task, body=next_step,
-                            task_id=res.task_id, root_id=root_id,
-                        )
-                    except IndexError:
-                        pass  # no callback, so keep as group.
 
+            if isinstance(task, group) and prev_task:
+                # automatically upgrade group(...) | s to chord(group, s)
+                # for chords we freeze by pretending it's a normal
+                # signature instead of a group.
+                tasks.pop()
+                results.pop()
+                prev_res = prev_prev_res
+                task = chord(
+                    task, body=prev_task,
+                    task_id=res.task_id, root_id=root_id, app=app,
+                )
             if is_last_task:
                 # chain(task_id=id) means task id is set for the last task
                 # in the chain.  If the chord is part of a chord/group
@@ -496,26 +504,36 @@ class chain(Signature):
                 )
             else:
                 res = task.freeze(root_id=root_id)
-            root_id = res.id if root_id is None else root_id
+
             i += 1
 
             if prev_task:
+                prev_task.set_parent_id(task.id)
                 if use_link:
                     # link previous task to this task.
                     task.link(prev_task)
-                    if not res.parent:
+                    if not res.parent and prev_res:
                         prev_res.parent = res.parent
-                else:
+                elif prev_res:
                     prev_res.parent = res
 
+            if is_first_task and parent_id is not None:
+                task.set_parent_id(parent_id)
+
             if link_error:
                 task.set(link_error=link_error)
 
             tasks.append(task)
             results.append(res)
 
-            prev_task, prev_res = task, res
+            prev_prev_task, prev_task, prev_prev_res, prev_res = (
+                prev_task, task, prev_res, res,
+            )
 
+        if root_id is None and tasks:
+            root_id = tasks[-1].id
+            for task in reversed(tasks):
+                task.options['root_id'] = root_id
         return tasks, results
 
     def apply(self, args=(), kwargs={}, **options):
@@ -634,13 +652,16 @@ class chunks(Signature):
         return cls(task, it, n, app=app)()
 
 
-def _maybe_group(tasks):
+def _maybe_group(tasks, app):
+    if isinstance(tasks, dict):
+        tasks = signature(tasks, app=app)
+
     if isinstance(tasks, group):
-        tasks = list(tasks.tasks)
+        tasks = tasks.tasks
     elif isinstance(tasks, abstract.CallableSignature):
         tasks = [tasks]
     else:
-        tasks = [signature(t) for t in regen(tasks)]
+        tasks = [signature(t, app=app) for t in regen(tasks)]
     return tasks
 
 
@@ -649,8 +670,9 @@ class group(Signature):
     tasks = _getitem_property('kwargs.tasks')
 
     def __init__(self, *tasks, **options):
+        app = options.get('app')
         if len(tasks) == 1:
-            tasks = _maybe_group(tasks[0])
+            tasks = _maybe_group(tasks[0], app)
         Signature.__init__(
             self, 'celery.group', (), {'tasks': tasks}, **options
         )
@@ -662,6 +684,9 @@ class group(Signature):
             d, group(d['kwargs']['tasks'], app=app, **d['options']),
         )
 
+    def __len__(self):
+        return len(self.tasks)
+
     def _prepared(self, tasks, partial_args, group_id, root_id, app, dict=dict,
                   CallableSignature=abstract.CallableSignature,
                   from_dict=Signature.from_dict):
@@ -703,6 +728,10 @@ class group(Signature):
             options.pop('task_id', uuid()))
         return options, group_id, options.get('root_id')
 
+    def set_parent_id(self, parent_id):
+        for task in self.tasks:
+            task.set_parent_id(parent_id)
+
     def apply_async(self, args=(), kwargs=None, add_to_parent=True,
                     producer=None, **options):
         app = self.app
@@ -757,7 +786,7 @@ class group(Signature):
     def __call__(self, *partial_args, **options):
         return self.apply_async(partial_args, **options)
 
-    def _freeze_unroll(self, new_tasks, group_id, chord, root_id):
+    def _freeze_unroll(self, new_tasks, group_id, chord, root_id, parent_id):
         stack = deque(self.tasks)
         while stack:
             task = maybe_signature(stack.popleft(), app=self._app).clone()
@@ -766,9 +795,11 @@ class group(Signature):
             else:
                 new_tasks.append(task)
                 yield task.freeze(group_id=group_id,
-                                  chord=chord, root_id=root_id)
+                                  chord=chord, root_id=root_id,
+                                  parent_id=parent_id)
 
-    def freeze(self, _id=None, group_id=None, chord=None, root_id=None):
+    def freeze(self, _id=None, group_id=None, chord=None,
+               root_id=None, parent_id=None):
         opts = self.options
         try:
             gid = opts['task_id']
@@ -779,11 +810,12 @@ class group(Signature):
         if chord:
             opts['chord'] = chord
         root_id = opts.setdefault('root_id', root_id)
+        parent_id = opts.setdefault('parent_id', parent_id)
         new_tasks = []
         # Need to unroll subgroups early so that chord gets the
         # right result instance for chord_unlock etc.
         results = list(self._freeze_unroll(
-            new_tasks, group_id, chord, root_id,
+            new_tasks, group_id, chord, root_id, parent_id,
         ))
         if isinstance(self.tasks, MutableSequence):
             self.tasks[:] = new_tasks
@@ -819,16 +851,29 @@ class group(Signature):
 class chord(Signature):
 
     def __init__(self, header, body=None, task='celery.chord',
-                 args=(), kwargs={}, **options):
+                 args=(), kwargs={}, app=None, **options):
         Signature.__init__(
             self, task, args,
-            dict(kwargs, header=_maybe_group(header),
+            dict(kwargs, header=_maybe_group(header, app),
                  body=maybe_signature(body, app=self._app)), **options
         )
         self.subtask_type = 'chord'
 
-    def freeze(self, *args, **kwargs):
-        return self.body.freeze(*args, **kwargs)
+    def freeze(self, _id=None, group_id=None, chord=None,
+               root_id=None, parent_id=None):
+        if not isinstance(self.tasks, group):
+            self.tasks = group(self.tasks)
+        self.tasks.freeze(parent_id=parent_id, root_id=root_id)
+        self.id = self.tasks.id
+        return self.body.freeze(_id, parent_id=self.id, root_id=root_id)
+
+    def set_parent_id(self, parent_id):
+        tasks = self.tasks
+        if isinstance(tasks, group):
+            tasks = tasks.tasks
+        for task in tasks:
+            task.set_parent_id(parent_id)
+        self.parent_id = parent_id
 
     @classmethod
     def from_dict(self, d, app=None):
@@ -848,7 +893,11 @@ class chord(Signature):
     def _get_app(self, body=None):
         app = self._app
         if app is None:
-            app = self.tasks[0]._app
+            try:
+                tasks = self.tasks.tasks  # is a group
+            except AttributeError:
+                tasks = self.tasks
+            app = tasks[0]._app
             if app is None and body is not None:
                 app = body._app
         return app if app is not None else current_app
@@ -900,6 +949,7 @@ class chord(Signature):
         body.chord_size = self.__length_hint__()
         options = dict(self.options, **options) if options else self.options
         if options:
+            options.pop('task_id', None)
             body.options.update(options)
 
         results = header.freeze(

+ 18 - 9
celery/events/state.py

@@ -233,11 +233,13 @@ class Task(object):
     state = states.PENDING
     clock = 0
 
-    _fields = ('uuid', 'name', 'state', 'received', 'sent', 'started',
-               'succeeded', 'failed', 'retried', 'revoked', 'args', 'kwargs',
-               'eta', 'expires', 'retries', 'worker', 'result', 'exception',
-               'timestamp', 'runtime', 'traceback', 'exchange', 'routing_key',
-               'clock', 'client')
+    _fields = (
+        'uuid', 'name', 'state', 'received', 'sent', 'started',
+        'succeeded', 'failed', 'retried', 'revoked', 'args', 'kwargs',
+        'eta', 'expires', 'retries', 'worker', 'result', 'exception',
+        'timestamp', 'runtime', 'traceback', 'exchange', 'routing_key',
+        'clock', 'client', 'root_id', 'parent_id',
+    )
     if not PYPY:
         __slots__ = ('__dict__', '__weakref__')
 
@@ -249,12 +251,19 @@ class Task(object):
     #: that state. ``(RECEIVED, ('name', 'args')``, means the name and args
     #: fields are always taken from the RECEIVED state, and any values for
     #: these fields received before or after is simply ignored.
-    merge_rules = {states.RECEIVED: ('name', 'args', 'kwargs',
-                                     'retries', 'eta', 'expires')}
+    merge_rules = {
+        states.RECEIVED: (
+            'name', 'args', 'kwargs', 'parent_id',
+            'root_id' 'retries', 'eta', 'expires',
+        ),
+    }
 
     #: meth:`info` displays these fields by default.
-    _info_fields = ('args', 'kwargs', 'retries', 'result', 'eta', 'runtime',
-                    'expires', 'exception', 'exchange', 'routing_key')
+    _info_fields = (
+        'args', 'kwargs', 'retries', 'result', 'eta', 'runtime',
+        'expires', 'exception', 'exchange', 'routing_key',
+        'root_id', 'parent_id',
+    )
 
     def __init__(self, uuid=None, **kwargs):
         self.uuid = uuid

+ 3 - 1
celery/result.py

@@ -122,7 +122,7 @@ class AsyncResult(ResultBase):
                                 reply=wait, timeout=timeout)
 
     def get(self, timeout=None, propagate=True, interval=0.5,
-            no_ack=True, follow_parents=True,
+            no_ack=True, follow_parents=True, callback=None,
             EXCEPTION_STATES=states.EXCEPTION_STATES,
             PROPAGATE_STATES=states.PROPAGATE_STATES):
         """Wait until task is ready, and return its result.
@@ -174,6 +174,8 @@ class AsyncResult(ResultBase):
             status = meta['status']
             if status in PROPAGATE_STATES and propagate:
                 raise meta['result']
+            if callback is not None:
+                callback(self.id, meta['result'])
             return meta['result']
     wait = get  # deprecated alias to :meth:`get`.
 

+ 58 - 4
celery/tests/app/test_builtins.py

@@ -133,18 +133,72 @@ class test_chain(BuiltinsCase):
         self.assertTrue(result.parent.parent)
         self.assertIsNone(result.parent.parent.parent)
 
+    def test_group_to_chord__freeze_parent_id(self):
+        def using_freeze(c):
+            c.freeze(parent_id='foo', root_id='root')
+            return c._frozen[0]
+        self.assert_group_to_chord_parent_ids(using_freeze)
+
+    def assert_group_to_chord_parent_ids(self, freezefun):
+        c = (
+            self.add.s(5, 5) |
+            group([self.add.s(i, i) for i in range(5)], app=self.app) |
+            self.add.si(10, 10) |
+            self.add.si(20, 20) |
+            self.add.si(30, 30)
+        )
+        tasks = freezefun(c)
+        self.assertEqual(tasks[-1].parent_id, 'foo')
+        self.assertEqual(tasks[-1].root_id, 'root')
+        self.assertEqual(tasks[-2].parent_id, tasks[-1].id)
+        self.assertEqual(tasks[-2].root_id, 'root')
+        self.assertEqual(tasks[-2].body.parent_id, tasks[-2].tasks.id)
+        self.assertEqual(tasks[-2].body.parent_id, tasks[-2].id)
+        self.assertEqual(tasks[-2].body.root_id, 'root')
+        self.assertEqual(tasks[-2].tasks.tasks[0].parent_id, tasks[-1].id)
+        self.assertEqual(tasks[-2].tasks.tasks[0].root_id, 'root')
+        self.assertEqual(tasks[-2].tasks.tasks[1].parent_id, tasks[-1].id)
+        self.assertEqual(tasks[-2].tasks.tasks[1].root_id, 'root')
+        self.assertEqual(tasks[-2].tasks.tasks[2].parent_id, tasks[-1].id)
+        self.assertEqual(tasks[-2].tasks.tasks[2].root_id, 'root')
+        self.assertEqual(tasks[-2].tasks.tasks[3].parent_id, tasks[-1].id)
+        self.assertEqual(tasks[-2].tasks.tasks[3].root_id, 'root')
+        self.assertEqual(tasks[-2].tasks.tasks[4].parent_id, tasks[-1].id)
+        self.assertEqual(tasks[-2].tasks.tasks[4].root_id, 'root')
+        self.assertEqual(tasks[-3].parent_id, tasks[-2].body.id)
+        self.assertEqual(tasks[-3].root_id, 'root')
+        self.assertEqual(tasks[-4].parent_id, tasks[-3].id)
+        self.assertEqual(tasks[-4].root_id, 'root')
+
     def test_group_to_chord(self):
         c = (
+            self.add.s(5) |
             group([self.add.s(i, i) for i in range(5)], app=self.app) |
             self.add.s(10) |
             self.add.s(20) |
             self.add.s(30)
         )
         c._use_link = True
-        tasks, _ = c.prepare_steps((), c.tasks)
-        self.assertIsInstance(tasks[-1], chord)
-        self.assertTrue(tasks[-1].body.options['link'])
-        self.assertTrue(tasks[-1].body.options['link'][0].options['link'])
+        tasks, results = c.prepare_steps((), c.tasks)
+
+        self.assertEqual(tasks[-1].args[0], 5)
+        self.assertIsInstance(tasks[-2], chord)
+        self.assertEqual(len(tasks[-2].tasks), 5)
+        self.assertEqual(tasks[-2].parent_id, tasks[-1].id)
+        self.assertEqual(tasks[-2].root_id, tasks[-1].id)
+        self.assertEqual(tasks[-2].body.args[0], 10)
+        self.assertEqual(tasks[-2].body.parent_id, tasks[-2].id)
+
+        self.assertEqual(tasks[-3].args[0], 20)
+        self.assertEqual(tasks[-3].root_id, tasks[-1].id)
+        self.assertEqual(tasks[-3].parent_id, tasks[-2].body.id)
+
+        self.assertEqual(tasks[-4].args[0], 30)
+        self.assertEqual(tasks[-4].parent_id, tasks[-3].id)
+        self.assertEqual(tasks[-4].root_id, tasks[-1].id)
+
+        self.assertTrue(tasks[-2].body.options['link'])
+        self.assertTrue(tasks[-2].body.options['link'][0].options['link'])
 
         c2 = self.add.s(2, 2) | group(self.add.s(i, i) for i in range(10))
         c2._use_link = True

+ 35 - 1
celery/tests/tasks/test_canvas.py

@@ -14,7 +14,7 @@ from celery.canvas import (
 )
 from celery.result import EagerResult
 
-from celery.tests.case import AppCase, Mock
+from celery.tests.case import AppCase, ContextMock, Mock
 
 SIG = Signature({'task': 'TASK',
                  'args': ('A1',),
@@ -233,6 +233,40 @@ class test_chain(CanvasCase):
         self.assertIsNone(chain(app=self.app)())
         self.assertIsNone(chain(app=self.app).apply_async())
 
+    def test_root_id_parent_id(self):
+        self.app.conf.task_protocol = 2
+        c = chain(self.add.si(i, i) for i in range(4))
+        c.freeze()
+        tasks, _ = c._frozen
+        for i, task in enumerate(tasks):
+            self.assertEqual(task.root_id, tasks[-1].id)
+            try:
+                self.assertEqual(task.parent_id, tasks[i + 1].id)
+            except IndexError:
+                assert i == len(tasks) - 1
+            else:
+                valid_parents = i
+        self.assertEqual(valid_parents, len(tasks) - 2)
+
+        self.assert_sent_with_ids(tasks[-1], tasks[-1].id, 'foo',
+                                  parent_id='foo')
+        self.assertTrue(tasks[-2].options['parent_id'])
+        self.assert_sent_with_ids(tasks[-2], tasks[-1].id, tasks[-1].id)
+        self.assert_sent_with_ids(tasks[-3], tasks[-1].id, tasks[-2].id)
+        self.assert_sent_with_ids(tasks[-4], tasks[-1].id, tasks[-3].id)
+
+
+    def assert_sent_with_ids(self, task, rid, pid, **options):
+        self.app.amqp.send_task_message = Mock(name='send_task_message')
+        self.app.backend = Mock()
+        self.app.producer_or_acquire = ContextMock()
+
+        res = task.apply_async(**options)
+        self.assertTrue(self.app.amqp.send_task_message.called)
+        message = self.app.amqp.send_task_message.call_args[0][2]
+        self.assertEqual(message.headers['parent_id'], pid)
+        self.assertEqual(message.headers['root_id'], rid)
+
     def test_call_no_tasks(self):
         x = chain()
         self.assertFalse(x())

+ 0 - 1
celery/worker/consumer.py

@@ -458,7 +458,6 @@ class Consumer(object):
         callbacks = self.on_task_message
 
         def on_task_received(message):
-
             # payload will only be set for v1 protocol, since v2
             # will defer deserializing the message body to the pool.
             payload = None

+ 5 - 3
celery/worker/request.py

@@ -77,9 +77,9 @@ class Request(object):
 
     if not IS_PYPY:  # pragma: no cover
         __slots__ = (
-            'app', 'type', 'name', 'id', 'on_ack', 'body',
-            'hostname', 'eventer', 'connection_errors', 'task', 'eta',
-            'expires', 'request_dict', 'on_reject', 'utc',
+            'app', 'type', 'name', 'id', 'root_id', 'parent_id',
+            'on_ack', 'body', 'hostname', 'eventer', 'connection_errors',
+            'task', 'eta', 'expires', 'request_dict', 'on_reject', 'utc',
             'content_type', 'content_encoding', 'argsrepr', 'kwargsrepr',
             '__weakref__', '__dict__',
         )
@@ -108,6 +108,8 @@ class Request(object):
 
         self.id = headers['id']
         type = self.type = self.name = headers['task']
+        self.root_id = headers.get('root_id')
+        self.parent_id = headers.get('parent_id')
         if 'shadow' in headers:
             self.name = headers['shadow']
         if 'timelimit' in headers:

+ 2 - 2
docs/userguide/monitoring.rst

@@ -650,7 +650,7 @@ task-sent
 ~~~~~~~~~
 
 :signature: ``task-sent(uuid, name, args, kwargs, retries, eta, expires,
-              queue, exchange, routing_key)``
+              queue, exchange, routing_key, root_id, parent_id)``
 
 Sent when a task message is published and
 the :setting:`task_send_sent_event` setting is enabled.
@@ -661,7 +661,7 @@ task-received
 ~~~~~~~~~~~~~
 
 :signature: ``task-received(uuid, name, args, kwargs, retries, eta, hostname,
-              timestamp)``
+              timestamp, root_id, parent_id)``
 
 Sent when the worker receives a task.
 

+ 10 - 0
funtests/stress/stress/app.py

@@ -63,6 +63,16 @@ def add(x, y):
     return x + y
 
 
+@app.task(bind=True)
+def ids(self, i):
+    return (self.request.root_id, self.request.parent_id, i)
+
+
+@app.task(bind=True)
+def collect_ids(self, ids, i):
+    return ids, (self.request.root_id, self.request.parent_id, i)
+
+
 @app.task
 def xsum(x):
     return sum(x)

+ 107 - 4
funtests/stress/stress/suite.py

@@ -10,7 +10,7 @@ from collections import OrderedDict, defaultdict, namedtuple
 from itertools import count
 from time import sleep
 
-from celery import group, VERSION_BANNER
+from celery import VERSION_BANNER, chain, group, uuid
 from celery.exceptions import TimeoutError
 from celery.five import items, monotonic, range, values
 from celery.utils.debug import blockdetection
@@ -18,12 +18,13 @@ from celery.utils.text import pluralize, truncate
 from celery.utils.timeutils import humanize_seconds
 
 from .app import (
-    marker, _marker, add, any_, exiting, kill, sleeping,
+    marker, _marker, add, any_, collect_ids, exiting, ids, kill, sleeping,
     sleeping_ignore_limits, any_returning, print_unicode,
 )
 from .data import BIG, SMALL
 from .fbi import FBI
 
+
 BANNER = """\
 Celery stress-suite v{version}
 
@@ -50,6 +51,10 @@ Progress = namedtuple('Progress', (
 Inf = float('Inf')
 
 
+def assert_equal(a, b):
+    assert a == b, '{0!r} != {1!r}'.format(a, b)
+
+
 class StopSuite(Exception):
     pass
 
@@ -163,6 +168,7 @@ class BaseSuite(object):
         )
 
     def runtest(self, fun, n=50, index=0, repeats=1):
+        n = getattr(fun, '__iterations__', None) or n
         print('{0}: [[[{1}({2})]]]'.format(repeats, fun.__name__, n))
         with blockdetection(self.block_timeout):
             with self.fbi.investigation():
@@ -185,6 +191,8 @@ class BaseSuite(object):
                             raise
                         except Exception as exc:
                             print('-> {0!r}'.format(exc))
+                            import traceback
+                            print(traceback.format_exc())
                             print(pstatus(self.progress))
                         else:
                             print(pstatus(self.progress))
@@ -238,13 +246,14 @@ class BaseSuite(object):
 _creation_counter = count(0)
 
 
-def testcase(*groups):
+def testcase(*groups, **kwargs):
     if not groups:
         raise ValueError('@testcase requires at least one group name')
 
     def _mark_as_case(fun):
         fun.__testgroup__ = groups
         fun.__testsort__ = next(_creation_counter)
+        fun.__iterations__ = kwargs.get('iterations')
         return fun
 
     return _mark_as_case
@@ -262,12 +271,106 @@ def _is_descriptor(obj, attr):
 
 class Suite(BaseSuite):
 
+    @testcase('all', 'green', iterations=1)
+    def chain(self):
+        c = add.s(4, 4) | add.s(8) | add.s(16)
+        assert_equal(self.join(c()), 32)
+
+    @testcase('all', 'green', iterations=1)
+    def chaincomplex(self):
+        c = (
+            add.s(2, 2) | (
+                add.s(4) | add.s(8) | add.s(16)
+            ) |
+            group(add.s(i) for i in range(4))
+        )
+        res = c()
+        assert_equal(res.get(), [32, 33, 34, 35])
+
+    @testcase('all', 'green', iterations=1)
+    def parentids_chain(self):
+        c = chain(ids.si(i) for i in range(248))
+        c.freeze()
+        res = c()
+        res.get(timeout=5)
+        self.assert_ids(res, len(c.tasks) - 1)
+
+    @testcase('all', 'green', iterations=1)
+    def parentids_group(self):
+        g = ids.si(1) | ids.si(2) | group(ids.si(i) for i in range(2, 50))
+        res = g()
+        expected_root_id = res.parent.parent.id
+        expected_parent_id = res.parent.id
+        values = res.get(timeout=5)
+
+        for i, r in enumerate(values):
+            root_id, parent_id, value = r
+            assert_equal(root_id, expected_root_id)
+            assert_equal(parent_id, expected_parent_id)
+            assert_equal(value, i + 2)
+
+    def assert_ids(self, res, len):
+        i, root = len, res
+        while root.parent:
+            root = root.parent
+        node = res
+        while node:
+            root_id, parent_id, value = node.get(timeout=5)
+            assert_equal(value, i)
+            assert_equal(root_id, root.id)
+            if node.parent:
+                assert_equal(parent_id, node.parent.id)
+            node = node.parent
+            i -= 1
+
+    @testcase('redis', iterations=1)
+    def parentids_chord(self):
+        self.assert_parentids_chord()
+        self.assert_parentids_chord(uuid(), uuid())
+
+    def assert_parentids_chord(self, base_root=None, base_parent=None):
+        g = (
+            ids.si(1) |
+            ids.si(2) |
+            group(ids.si(i) for i in range(3, 50)) |
+            collect_ids.s(i=50) |
+            ids.si(51)
+        )
+        g.freeze(root_id=base_root, parent_id=base_parent)
+        res = g.apply_async(root_id=base_root, parent_id=base_parent)
+        expected_root_id = base_root or res.parent.parent.parent.id
+
+        root_id, parent_id, value = res.get(timeout=5)
+        assert_equal(value, 51)
+        assert_equal(root_id, expected_root_id)
+        assert_equal(parent_id, res.parent.id)
+
+        prev, (root_id, parent_id, value) = res.parent.get(timeout=5)
+        assert_equal(value, 50)
+        assert_equal(root_id, expected_root_id)
+        assert_equal(parent_id, res.parent.parent.id)
+
+        for i, p in enumerate(prev):
+            root_id, parent_id, value = p
+            assert_equal(root_id, expected_root_id)
+            assert_equal(parent_id, res.parent.parent.id)
+
+        root_id, parent_id, value = res.parent.parent.get(timeout=5)
+        assert_equal(value, 2)
+        assert_equal(parent_id, res.parent.parent.parent.id)
+        assert_equal(root_id, expected_root_id)
+
+        root_id, parent_id, value = res.parent.parent.parent.get(timeout=5)
+        assert_equal(value, 1)
+        assert_equal(root_id, expected_root_id)
+        assert_equal(parent_id, base_parent)
+
     @testcase('all', 'green')
     def manyshort(self):
         self.join(group(add.s(i, i) for i in range(1000))(),
                   timeout=10, propagate=True)
 
-    @testcase('all', 'green')
+    @testcase('all', 'green', iterations=1)
     def unicodetask(self):
         self.join(group(print_unicode.s() for _ in range(5))(),
                   timeout=1, propagate=True)

+ 1 - 1
funtests/stress/stress/templates.py

@@ -57,7 +57,7 @@ class default(object):
     result_serializer = 'json'
     result_persistent = True
     result_expires = 300
-    result_cache_max = -1
+    result_cache_max = 100
     task_default_queue = CSTRESS_QUEUE
     task_queues = [
         Queue(CSTRESS_QUEUE,