Browse Source

Coverage one point up, long way to go

Ask Solem 9 years ago
parent
commit
b1deab39aa

+ 2 - 1
.coveragerc

@@ -1,6 +1,7 @@
 [run]
 branch = 1
 cover_pylib = 0
-omit = celery.utils.debug,celery.tests.*,celery.bin.graph
+include=*celery/*
+omit = celery.utils.debug,celery.tests.*,celery.bin.graph;
 [report]
 omit = */python?.?/*,*/site-packages/*,*/pypy/*

+ 18 - 18
celery/app/amqp.py

@@ -26,7 +26,7 @@ from celery.five import items, string_t
 from celery.local import try_import
 from celery.utils.saferepr import saferepr
 from celery.utils.text import indent as textindent
-from celery.utils.timeutils import to_utc
+from celery.utils.timeutils import maybe_make_aware, to_utc
 
 from . import routes as _routes
 
@@ -300,7 +300,6 @@ class AMQP(object):
                    shadow=None, chain=None, now=None, timezone=None):
         args = args or ()
         kwargs = kwargs or {}
-        utc = self.utc
         if not isinstance(args, (list, tuple)):
             raise TypeError('task args must be a list or tuple')
         if not isinstance(kwargs, Mapping):
@@ -308,22 +307,22 @@ class AMQP(object):
         if countdown:  # convert countdown to ETA
             now = now or self.app.now()
             timezone = timezone or self.app.timezone
-            eta = now + timedelta(seconds=countdown)
-            if utc:
-                eta = to_utc(eta).astimezone(timezone)
+            eta = maybe_make_aware(
+                now + timedelta(seconds=countdown), tz=timezone,
+            )
         if isinstance(expires, numbers.Real):
             now = now or self.app.now()
             timezone = timezone or self.app.timezone
-            expires = now + timedelta(seconds=expires)
-            if utc:
-                expires = to_utc(expires).astimezone(timezone)
+            expires = maybe_make_aware(
+                now + timedelta(seconds=expires), tz=timezone,
+            )
         eta = eta and eta.isoformat()
         expires = expires and expires.isoformat()
 
         argsrepr = saferepr(args)
         kwargsrepr = saferepr(kwargs)
 
-        if JSON_NEEDS_UNICODE_KEYS:
+        if JSON_NEEDS_UNICODE_KEYS:  # pragma: no cover
             if callbacks:
                 callbacks = [utf8dict(callback) for callback in callbacks]
             if errbacks:
@@ -400,7 +399,7 @@ class AMQP(object):
         eta = eta and eta.isoformat()
         expires = expires and expires.isoformat()
 
-        if JSON_NEEDS_UNICODE_KEYS:
+        if JSON_NEEDS_UNICODE_KEYS:  # pragma: no cover
             if callbacks:
                 callbacks = [utf8dict(callback) for callback in callbacks]
             if errbacks:
@@ -462,12 +461,13 @@ class AMQP(object):
         default_serializer = self.app.conf.task_serializer
         default_compressor = self.app.conf.result_compression
 
-        def publish_task(producer, name, message,
-                         exchange=None, routing_key=None, queue=None,
-                         event_dispatcher=None, retry=None, retry_policy=None,
-                         serializer=None, delivery_mode=None,
-                         compression=None, declare=None,
-                         headers=None, **kwargs):
+        def send_task_message(producer, name, message,
+                              exchange=None, routing_key=None, queue=None,
+                              event_dispatcher=None,
+                              retry=None, retry_policy=None,
+                              serializer=None, delivery_mode=None,
+                              compression=None, declare=None,
+                              headers=None, **kwargs):
             retry = default_retry if retry is None else retry
             headers2, properties, body, sent_event = message
             if headers:
@@ -527,7 +527,7 @@ class AMQP(object):
             if sent_event:
                 evd = event_dispatcher or default_evd
                 exname = exchange or self.exchange
-                if isinstance(name, Exchange):
+                if isinstance(exname, Exchange):
                     exname = exname.name
                 sent_event.update({
                     'queue': qname,
@@ -537,7 +537,7 @@ class AMQP(object):
                 evd.publish('task-sent', sent_event,
                             self, retry=retry, retry_policy=retry_policy)
             return ret
-        return publish_task
+        return send_task_message
 
     @cached_property
     def default_queue(self):

+ 6 - 6
celery/backends/base.py

@@ -110,13 +110,13 @@ class BaseBackend(object):
 
     def mark_as_started(self, task_id, **meta):
         """Mark a task as started"""
-        return self.store_result(task_id, meta, status=states.STARTED)
+        return self.store_result(task_id, meta, states.STARTED)
 
     def mark_as_done(self, task_id, result,
                      request=None, store_result=True, state=states.SUCCESS):
         """Mark task as successfully executed."""
         if store_result:
-            self.store_result(task_id, result, status=state, request=request)
+            self.store_result(task_id, result, state, request=request)
         if request and request.chord:
             self.on_chord_part_return(request, state, result)
 
@@ -125,7 +125,7 @@ class BaseBackend(object):
                         state=states.FAILURE):
         """Mark task as executed with failure. Stores the exception."""
         if store_result:
-            self.store_result(task_id, exc, status=state,
+            self.store_result(task_id, exc, state,
                               traceback=traceback, request=request)
         if request and request.chord:
             self.on_chord_part_return(request, state, exc)
@@ -134,8 +134,8 @@ class BaseBackend(object):
                         request=None, store_result=True, state=states.REVOKED):
         exc = TaskRevokedError(reason)
         if store_result:
-            self.store_result(task_id, exc,
-                              status=state, traceback=None, request=request)
+            self.store_result(task_id, exc, state,
+                              traceback=None, request=request)
         if request and request.chord:
             self.on_chord_part_return(request, state, exc)
 
@@ -143,7 +143,7 @@ class BaseBackend(object):
                       request=None, store_result=True, state=states.RETRY):
         """Mark task as being retries. Stores the current
         exception (if any)."""
-        return self.store_result(task_id, exc, status=state,
+        return self.store_result(task_id, exc, state,
                                  traceback=traceback, request=request)
 
     def chord_error_from_stack(self, callback, exc=None):

+ 1 - 1
celery/backends/redis.py

@@ -17,7 +17,7 @@ from celery import states
 from celery.canvas import maybe_signature
 from celery.exceptions import ChordError, ImproperlyConfigured
 from celery.five import string_t
-from celery.utils import deprecated_property, strtobool
+from celery.utils import deprecated_property
 from celery.utils.functional import dictfilter
 from celery.utils.log import get_logger
 from celery.utils.timeutils import humanize_seconds

+ 6 - 6
celery/canvas.py

@@ -27,7 +27,7 @@ from celery.local import try_import
 from celery.result import GroupResult
 from celery.utils import abstract
 from celery.utils.functional import (
-    maybe_list, is_list, noop, regen, chunks as _chunks,
+    maybe_list, is_list, regen, chunks as _chunks,
 )
 from celery.utils.text import truncate
 
@@ -457,7 +457,7 @@ class chain(Signature):
         steps_pop = steps.pop
         steps_extend = steps.extend
 
-        next_step = prev_task = prev_prev_task = None
+        prev_task = None
         prev_res = prev_prev_res = None
         tasks, results = [], []
         i = 0
@@ -490,7 +490,7 @@ class chain(Signature):
                 prev_res = prev_prev_res
                 task = chord(
                     task, body=prev_task,
-                    task_id=res.task_id, root_id=root_id, app=app,
+                    task_id=prev_res.task_id, root_id=root_id, app=app,
                 )
             if is_last_task:
                 # chain(task_id=id) means task id is set for the last task
@@ -526,8 +526,8 @@ class chain(Signature):
             tasks.append(task)
             results.append(res)
 
-            prev_prev_task, prev_task, prev_prev_res, prev_res = (
-                prev_task, task, prev_res, res,
+            prev_task, prev_prev_res, prev_res = (
+                task, prev_res, res,
             )
 
         if root_id is None and tasks:
@@ -701,7 +701,7 @@ class group(Signature):
                     task = from_dict(task, app=app)
                 if isinstance(task, group):
                     # needs yield_from :(
-                    unroll = task_prepared(
+                    unroll = task._prepared(
                         task.tasks, partial_args, group_id, root_id, app,
                     )
                     for taskN, resN in unroll:

+ 1 - 1
celery/concurrency/asynpool.py

@@ -33,7 +33,7 @@ from pickle import HIGHEST_PROTOCOL
 from time import sleep
 from weakref import WeakValueDictionary, ref
 
-from amqp.utils import promise
+from amqp import promise
 from billiard.pool import RUN, TERMINATE, ACK, NACK, WorkersJoined
 from billiard import pool as _pool
 from billiard.compat import buf_t, setblocking, isblocking

+ 132 - 2
celery/tests/app/test_amqp.py

@@ -1,10 +1,15 @@
 from __future__ import absolute_import
 
+from datetime import datetime, timedelta
+
 from kombu import Exchange, Queue
 
-from celery.app.amqp import Queues
+from celery import uuid
+from celery.app.amqp import Queues, utf8dict
 from celery.five import keys
-from celery.tests.case import AppCase
+from celery.utils.timeutils import to_utc
+
+from celery.tests.case import AppCase, Mock
 
 
 class test_TaskConsumer(AppCase):
@@ -146,6 +151,12 @@ class test_Queues(AppCase):
             'x-max-priority': 3,
         })
 
+        q1 = Queue('moo', queue_arguments=None)
+        qs1.add(q1)
+        self.assertEqual(qs1['moo'].queue_arguments, {
+            'x-max-priority': 10,
+        })
+
         qs2 = Queues(ha_policy='all', max_priority=5)
         qs2.add('bar')
         self.assertEqual(qs2['bar'].queue_arguments, {
@@ -169,3 +180,122 @@ class test_Queues(AppCase):
         self.assertEqual(qs3['xyx3'].queue_arguments, {
             'x-max-priority': 7,
         })
+
+
+class test_AMQP(AppCase):
+
+    def setup(self):
+        self.simple_message = self.app.amqp.as_task_v2(
+            uuid(), 'foo', create_sent_event=True,
+        )
+
+    def test_Queues__with_ha_policy(self):
+        x = self.app.amqp.Queues({}, ha_policy='all')
+        self.assertEqual(x.ha_policy, 'all')
+
+    def test_Queues__with_max_priority(self):
+        x = self.app.amqp.Queues({}, max_priority=23)
+        self.assertEqual(x.max_priority, 23)
+
+    def test_send_task_message__no_kwargs(self):
+        self.app.amqp.send_task_message(Mock(), 'foo', self.simple_message)
+
+    def test_send_task_message__properties(self):
+        prod = Mock(name='producer')
+        self.app.amqp.send_task_message(
+            prod, 'foo', self.simple_message, foo=1, retry=False,
+        )
+        self.assertEqual(prod.publish.call_args[1]['foo'], 1)
+
+    def test_send_task_message__headers(self):
+        prod = Mock(name='producer')
+        self.app.amqp.send_task_message(
+            prod, 'foo', self.simple_message, headers={'x1x': 'y2x'},
+            retry=False,
+        )
+        self.assertEqual(prod.publish.call_args[1]['headers']['x1x'], 'y2x')
+
+    def test_send_task_message__queue_string(self):
+        prod = Mock(name='producer')
+        self.app.amqp.send_task_message(
+            prod, 'foo', self.simple_message, queue='foo', retry=False,
+        )
+        kwargs = prod.publish.call_args[1]
+        self.assertEqual(kwargs['routing_key'], 'foo')
+        self.assertEqual(kwargs['exchange'], 'foo')
+
+    def test_send_event_exchange_string(self):
+        evd = Mock(name="evd")
+        self.app.amqp.send_task_message(
+            Mock(), 'foo', self.simple_message, retry=False,
+            exchange='xyz', routing_key='xyb',
+            event_dispatcher=evd,
+        )
+        self.assertTrue(evd.publish.called)
+        event = evd.publish.call_args[0][1]
+        self.assertEqual(event['routing_key'], 'xyb')
+        self.assertEqual(event['exchange'], 'xyz')
+
+    def test_send_task_message__with_delivery_mode(self):
+        prod = Mock(name='producer')
+        self.app.amqp.send_task_message(
+            prod, 'foo', self.simple_message, delivery_mode=33, retry=False,
+        )
+        self.assertEqual(prod.publish.call_args[1]['delivery_mode'], 33)
+
+    def test_routes(self):
+        r1 = self.app.amqp.routes
+        r2 = self.app.amqp.routes
+        self.assertIs(r1, r2)
+
+
+class test_as_task_v2(AppCase):
+
+    def test_raises_if_args_is_not_tuple(self):
+        with self.assertRaises(TypeError):
+            self.app.amqp.as_task_v2(uuid(), 'foo', args='123')
+
+    def test_raises_if_kwargs_is_not_mapping(self):
+        with self.assertRaises(TypeError):
+            self.app.amqp.as_task_v2(uuid(), 'foo', kwargs=(1, 2, 3))
+
+    def test_countdown_to_eta(self):
+        now = to_utc(datetime.utcnow()).astimezone(self.app.timezone)
+        m = self.app.amqp.as_task_v2(
+            uuid(), 'foo', countdown=10, now=now,
+        )
+        self.assertEqual(
+            m.headers['eta'],
+            (now + timedelta(seconds=10)).isoformat(),
+        )
+
+    def test_expires_to_datetime(self):
+        now = to_utc(datetime.utcnow()).astimezone(self.app.timezone)
+        m = self.app.amqp.as_task_v2(
+            uuid(), 'foo', expires=30, now=now,
+        )
+        self.assertEqual(
+            m.headers['expires'],
+            (now + timedelta(seconds=30)).isoformat(),
+        )
+
+    def test_callbacks_errbacks_chord(self):
+
+        @self.app.task
+        def t(i):
+            pass
+
+        m = self.app.amqp.as_task_v2(
+            uuid(), 'foo',
+            callbacks=[t.s(1), t.s(2)],
+            errbacks=[t.s(3), t.s(4)],
+            chord=t.s(5),
+        )
+        _, _, embed = m.body
+        self.assertListEqual(
+            embed['callbacks'], [utf8dict(t.s(1)), utf8dict(t.s(2))],
+        )
+        self.assertListEqual(
+            embed['errbacks'], [utf8dict(t.s(3)), utf8dict(t.s(4))],
+        )
+        self.assertEqual(embed['chord'], utf8dict(t.s(5)))

+ 14 - 0
celery/tests/app/test_app.py

@@ -24,6 +24,7 @@ from celery.tests.case import (
     CELERY_TEST_CONFIG,
     AppCase,
     Mock,
+    Case,
     depends_on_current_app,
     mask_modules,
     patch,
@@ -75,6 +76,19 @@ class test_module(AppCase):
         self.assertTrue(_app.bugreport(app=self.app))
 
 
+class test_task_join_will_block(Case):
+
+    def test_task_join_will_block(self):
+        prev, _state._task_join_will_block = _state._task_join_will_block, 0
+        try:
+            self.assertEqual(_state._task_join_will_block, 0)
+            _state._set_task_join_will_block(True)
+            print(_state.task_join_will_block)
+            self.assertTrue(_state.task_join_will_block())
+        finally:
+            _state._task_join_will_block = prev
+
+
 class test_App(AppCase):
 
     def setup(self):

+ 46 - 155
celery/tests/app/test_builtins.py

@@ -2,10 +2,10 @@ from __future__ import absolute_import
 
 from celery import group, chord
 from celery.app import builtins
-from celery.canvas import Signature
 from celery.five import range
-from celery._state import _task_stack
-from celery.tests.case import AppCase, Mock, patch
+from celery.utils.functional import pass1
+
+from celery.tests.case import AppCase, ContextMock, Mock, patch
 
 
 class BuiltinsCase(AppCase):
@@ -32,6 +32,18 @@ class test_backend_cleanup(BuiltinsCase):
         self.assertTrue(self.app.backend.cleanup.called)
 
 
+class test_accumulate(BuiltinsCase):
+
+    def setup(self):
+        self.accumulate = self.app.tasks['celery.accumulate']
+
+    def test_with_index(self):
+        self.assertEqual(self.accumulate(1, 2, 3, 4, index=0), 1)
+
+    def test_no_index(self):
+        self.assertEqual(self.accumulate(1, 2, 3, 4), (1, 2, 3, 4))
+
+
 class test_map(BuiltinsCase):
 
     def test_run(self):
@@ -78,46 +90,42 @@ class test_chunks(BuiltinsCase):
 class test_group(BuiltinsCase):
 
     def setup(self):
+        self.maybe_signature = self.patch('celery.canvas.maybe_signature')
+        self.maybe_signature.side_effect = pass1
+        self.app.producer_or_acquire = Mock()
+        self.app.producer_or_acquire.attach_mock(ContextMock(), 'return_value')
+        self.app.conf.task_always_eager = True
         self.task = builtins.add_group_task(self.app)
         super(test_group, self).setup()
 
     def test_apply_async_eager(self):
-        self.task.apply = Mock()
-        self.app.conf.task_always_eager = True
+        self.task.apply = Mock(name='apply')
         self.task.apply_async((1, 2, 3, 4, 5))
         self.assertTrue(self.task.apply.called)
 
-    def test_apply(self):
-        x = group([self.add.s(4, 4), self.add.s(8, 8)])
-        res = x.apply()
-        self.assertEqual(res.get(), [8, 16])
+    def mock_group(self, *tasks):
+        g = group(*tasks, app=self.app)
+        result = g.freeze()
+        for task in g.tasks:
+            task.clone = Mock(name='clone')
+            task.clone.attach_mock(Mock(), 'apply_async')
+        return g, result
+
+    @patch('celery.app.builtins.get_current_worker_task')
+    def test_task(self, get_current_worker_task):
+        g, result = self.mock_group(self.add.s(2), self.add.s(4))
+        self.task(g.tasks, result, result.id, (2,)).results
+        g.tasks[0].clone().apply_async.assert_called_with(
+            group_id=result.id, producer=self.app.producer_or_acquire(),
+            add_to_parent=False,
+        )
+        get_current_worker_task().add_trail.assert_called_with(result)
 
-    def test_apply_async(self):
-        x = group([self.add.s(4, 4), self.add.s(8, 8)])
-        x.apply_async()
-
-    def test_apply_empty(self):
-        x = group(app=self.app)
-        x.apply()
-        res = x.apply_async()
-        self.assertFalse(res)
-        self.assertFalse(res.results)
-
-    def test_apply_async_with_parent(self):
-        _task_stack.push(self.add)
-        try:
-            self.add.push_request(called_directly=False)
-            try:
-                assert not self.add.request.children
-                x = group([self.add.s(4, 4), self.add.s(8, 8)])
-                res = x()
-                self.assertTrue(self.add.request.children)
-                self.assertIn(res, self.add.request.children)
-                self.assertEqual(len(self.add.request.children), 1)
-            finally:
-                self.add.pop_request()
-        finally:
-            _task_stack.pop()
+    @patch('celery.app.builtins.get_current_worker_task')
+    def test_task__disable_add_to_parent(self, get_current_worker_task):
+        g, result = self.mock_group(self.add.s(2, 2), self.add.s(4, 4))
+        self.task(g.tasks, result, result.id, None, add_to_parent=False)
+        self.assertFalse(get_current_worker_task().add_trail.called)
 
 
 class test_chain(BuiltinsCase):
@@ -126,126 +134,9 @@ class test_chain(BuiltinsCase):
         BuiltinsCase.setup(self)
         self.task = builtins.add_chain_task(self.app)
 
-    def test_apply_async(self):
-        c = self.add.s(2, 2) | self.add.s(4) | self.add.s(8)
-        result = c.apply_async()
-        self.assertTrue(result.parent)
-        self.assertTrue(result.parent.parent)
-        self.assertIsNone(result.parent.parent.parent)
-
-    def test_group_to_chord__freeze_parent_id(self):
-        def using_freeze(c):
-            c.freeze(parent_id='foo', root_id='root')
-            return c._frozen[0]
-        self.assert_group_to_chord_parent_ids(using_freeze)
-
-    def assert_group_to_chord_parent_ids(self, freezefun):
-        c = (
-            self.add.s(5, 5) |
-            group([self.add.s(i, i) for i in range(5)], app=self.app) |
-            self.add.si(10, 10) |
-            self.add.si(20, 20) |
-            self.add.si(30, 30)
-        )
-        tasks = freezefun(c)
-        self.assertEqual(tasks[-1].parent_id, 'foo')
-        self.assertEqual(tasks[-1].root_id, 'root')
-        self.assertEqual(tasks[-2].parent_id, tasks[-1].id)
-        self.assertEqual(tasks[-2].root_id, 'root')
-        self.assertEqual(tasks[-2].body.parent_id, tasks[-2].tasks.id)
-        self.assertEqual(tasks[-2].body.parent_id, tasks[-2].id)
-        self.assertEqual(tasks[-2].body.root_id, 'root')
-        self.assertEqual(tasks[-2].tasks.tasks[0].parent_id, tasks[-1].id)
-        self.assertEqual(tasks[-2].tasks.tasks[0].root_id, 'root')
-        self.assertEqual(tasks[-2].tasks.tasks[1].parent_id, tasks[-1].id)
-        self.assertEqual(tasks[-2].tasks.tasks[1].root_id, 'root')
-        self.assertEqual(tasks[-2].tasks.tasks[2].parent_id, tasks[-1].id)
-        self.assertEqual(tasks[-2].tasks.tasks[2].root_id, 'root')
-        self.assertEqual(tasks[-2].tasks.tasks[3].parent_id, tasks[-1].id)
-        self.assertEqual(tasks[-2].tasks.tasks[3].root_id, 'root')
-        self.assertEqual(tasks[-2].tasks.tasks[4].parent_id, tasks[-1].id)
-        self.assertEqual(tasks[-2].tasks.tasks[4].root_id, 'root')
-        self.assertEqual(tasks[-3].parent_id, tasks[-2].body.id)
-        self.assertEqual(tasks[-3].root_id, 'root')
-        self.assertEqual(tasks[-4].parent_id, tasks[-3].id)
-        self.assertEqual(tasks[-4].root_id, 'root')
-
-    def test_group_to_chord(self):
-        c = (
-            self.add.s(5) |
-            group([self.add.s(i, i) for i in range(5)], app=self.app) |
-            self.add.s(10) |
-            self.add.s(20) |
-            self.add.s(30)
-        )
-        c._use_link = True
-        tasks, results = c.prepare_steps((), c.tasks)
-
-        self.assertEqual(tasks[-1].args[0], 5)
-        self.assertIsInstance(tasks[-2], chord)
-        self.assertEqual(len(tasks[-2].tasks), 5)
-        self.assertEqual(tasks[-2].parent_id, tasks[-1].id)
-        self.assertEqual(tasks[-2].root_id, tasks[-1].id)
-        self.assertEqual(tasks[-2].body.args[0], 10)
-        self.assertEqual(tasks[-2].body.parent_id, tasks[-2].id)
-
-        self.assertEqual(tasks[-3].args[0], 20)
-        self.assertEqual(tasks[-3].root_id, tasks[-1].id)
-        self.assertEqual(tasks[-3].parent_id, tasks[-2].body.id)
-
-        self.assertEqual(tasks[-4].args[0], 30)
-        self.assertEqual(tasks[-4].parent_id, tasks[-3].id)
-        self.assertEqual(tasks[-4].root_id, tasks[-1].id)
-
-        self.assertTrue(tasks[-2].body.options['link'])
-        self.assertTrue(tasks[-2].body.options['link'][0].options['link'])
-
-        c2 = self.add.s(2, 2) | group(self.add.s(i, i) for i in range(10))
-        c2._use_link = True
-        tasks2, _ = c2.prepare_steps((), c2.tasks)
-        self.assertIsInstance(tasks2[0], group)
-
-    def test_group_to_chord__protocol_2(self):
-        c = (
-            group([self.add.s(i, i) for i in range(5)], app=self.app) |
-            self.add.s(10) |
-            self.add.s(20) |
-            self.add.s(30)
-        )
-        c._use_link = False
-        tasks, _ = c.prepare_steps((), c.tasks)
-        self.assertIsInstance(tasks[-1], chord)
-
-        c2 = self.add.s(2, 2) | group(self.add.s(i, i) for i in range(10))
-        c2._use_link = False
-        tasks2, _ = c2.prepare_steps((), c2.tasks)
-        self.assertIsInstance(tasks2[0], group)
-    def test_apply_options(self):
-
-        class static(Signature):
-
-            def clone(self, *args, **kwargs):
-                return self
-
-        def s(*args, **kwargs):
-            return static(self.add, args, kwargs, type=self.add, app=self.app)
-
-        c = s(2, 2) | s(4, 4) | s(8, 8)
-        r1 = c.apply_async(task_id='some_id')
-        self.assertEqual(r1.id, 'some_id')
-
-        c.apply_async(group_id='some_group_id')
-        self.assertEqual(c.tasks[-1].options['group_id'], 'some_group_id')
-
-        c.apply_async(chord='some_chord_id')
-        self.assertEqual(c.tasks[-1].options['chord'], 'some_chord_id')
-
-        c.apply_async(link=[s(32)])
-        self.assertListEqual(c.tasks[-1].options['link'], [s(32)])
-
-        c.apply_async(link_error=[s('error')])
-        for task in c.tasks:
-            self.assertListEqual(task.options['link_error'], [s('error')])
+    def test_not_implemented(self):
+        with self.assertRaises(NotImplementedError):
+            self.task()
 
 
 class test_chord(BuiltinsCase):

+ 11 - 0
celery/tests/bin/test_celery.py

@@ -41,6 +41,17 @@ class test__main__(AppCase):
                 mpc.assert_called_with()
                 main.assert_called_with()
 
+    def test_main__multi(self):
+        with patch('celery.__main__.maybe_patch_concurrency') as mpc:
+            with patch('celery.bin.celery.main') as main:
+                prev, sys.argv = sys.argv, ['foo', 'multi']
+                try:
+                    __main__.main()
+                    self.assertFalse(mpc.called)
+                    main.assert_called_with()
+                finally:
+                    sys.argv = prev
+
 
 class test_Command(AppCase):
 

+ 57 - 20
celery/tests/case.py

@@ -309,6 +309,12 @@ def alive_threads():
 
 class Case(unittest.TestCase):
 
+    def patch(self, *path, **options):
+        manager = patch(".".join(path), **options)
+        patched = manager.start()
+        self.addCleanup(manager.stop)
+        return patched
+
     def assertWarns(self, expected_warning):
         return _AssertWarnsContext(expected_warning, self, None)
 
@@ -420,6 +426,8 @@ class AppCase(Case):
         self._threads_at_setup = self.threads_at_startup()
         from celery import _state
         from celery import result
+        self._prev_res_join_block = result.task_join_will_block
+        self._prev_state_join_block = _state.task_join_will_block
         result.task_join_will_block = \
             _state.task_join_will_block = lambda: False
         self._current_app = current_app()
@@ -446,12 +454,16 @@ class AppCase(Case):
             raise
 
     def _teardown_app(self):
+        from celery import _state
+        from celery import result
         from celery.utils.log import LoggingProxy
         assert sys.stdout
         assert sys.stderr
         assert sys.__stdout__
         assert sys.__stderr__
         this = self._get_test_name()
+        result.task_join_will_block = self._prev_res_join_block
+        _state.task_join_will_block = self._prev_state_join_block
         if isinstance(sys.stdout, (LoggingProxy, Mock)) or \
                 isinstance(sys.__stdout__, (LoggingProxy, Mock)):
             raise RuntimeError(CASE_LOG_REDIRECT_EFFECT.format(this, 'stdout'))
@@ -839,7 +851,49 @@ def skip_if_jython(fun):
     return _inner
 
 
-def task_message_from_sig(app, sig, utc=True):
+def TaskMessage(name, id=None, args=(), kwargs={}, callbacks=None,
+                errbacks=None, chain=None, shadow=None, utc=None, **options):
+    from celery import uuid
+    from kombu.serialization import dumps
+    id = id or uuid()
+    message = Mock(name='TaskMessage-{0}'.format(id))
+    message.headers = {
+        'id': id,
+        'task': name,
+        'shadow': shadow,
+    }
+    embed = {'callbacks': callbacks, 'errbacks': errbacks, 'chain': chain}
+    message.headers.update(options)
+    message.content_type, message.content_encoding, message.body = dumps(
+        (args, kwargs, embed), serializer='json',
+    )
+    message.payload = (args, kwargs, embed)
+    return message
+
+
+def TaskMessage1(name, id=None, args=(), kwargs={}, callbacks=None,
+                 errbacks=None, chain=None, **options):
+    from celery import uuid
+    from kombu.serialization import dumps
+    id = id or uuid()
+    message = Mock(name='TaskMessage-{0}'.format(id))
+    message.headers = {}
+    message.payload = {
+        'task': name,
+        'id': id,
+        'args': args,
+        'kwargs': kwargs,
+        'callbacks': callbacks,
+        'errbacks': errbacks,
+    }
+    message.payload.update(options)
+    message.content_type, message.content_encoding, message.body = dumps(
+        message.payload,
+    )
+    return message
+
+
+def task_message_from_sig(app, sig, utc=True, TaskMessage=TaskMessage):
     sig.freeze()
     callbacks = sig.options.pop('link', None)
     errbacks = sig.options.pop('link_error', None)
@@ -862,6 +916,8 @@ def task_message_from_sig(app, sig, utc=True):
         errbacks=[dict(s) for s in errbacks] if errbacks else None,
         eta=eta,
         expires=expires,
+        utc=utc,
+        **sig.options
     )
 
 
@@ -878,22 +934,3 @@ def restore_logging():
         sys.stdout, sys.stderr, sys.__stdout__, sys.__stderr__ = outs
         root.level = level
         root.handlers[:] = handlers
-
-
-def TaskMessage(name, id=None, args=(), kwargs={}, callbacks=None,
-                errbacks=None, chain=None, **options):
-    from celery import uuid
-    from kombu.serialization import dumps
-    id = id or uuid()
-    message = Mock(name='TaskMessage-{0}'.format(id))
-    message.headers = {
-        'id': id,
-        'task': name,
-    }
-    embed = {'callbacks': callbacks, 'errbacks': errbacks, 'chain': chain}
-    message.headers.update(options)
-    message.content_type, message.content_encoding, message.body = dumps(
-        (args, kwargs, embed), serializer='json',
-    )
-    message.payload = (args, kwargs, embed)
-    return message

+ 13 - 0
celery/tests/concurrency/test_eventlet.py

@@ -1,5 +1,6 @@
 from __future__ import absolute_import
 
+import os
 import sys
 
 from celery.app.defaults import is_pypy
@@ -43,6 +44,18 @@ class test_aaa_eventlet_patch(EventletCase):
             maybe_patch_concurrency(['x', '-P', 'eventlet'])
             monkey_patch.assert_called_with()
 
+    @patch('eventlet.debug.hub_blocking_detection', create=True)
+    @patch('eventlet.monkey_patch', create=True)
+    def test_aaa_blockdetecet(self, monkey_patch, hub_blocking_detection):
+        os.environ['EVENTLET_NOBLOCK'] = "10.3"
+        try:
+            from celery import maybe_patch_concurrency
+            maybe_patch_concurrency(['x', '-P', 'eventlet'])
+            monkey_patch.assert_called_with()
+            hub_blocking_detection.assert_called_with(10.3, 10.3)
+        finally:
+            os.environ.pop('EVENTLET_NOBLOCK', None)
+
 
 eventlet_modules = (
     'eventlet',

+ 5 - 5
celery/tests/contrib/test_rdb.py

@@ -74,22 +74,22 @@ class test_Rdb(AppCase):
     def test_get_avail_port(self, sock):
         out = WhateverIO()
         sock.return_value.accept.return_value = (Mock(), ['helu'])
-        with Rdb(out=out) as rdb:
+        with Rdb(out=out):
             pass
 
         with patch('celery.contrib.rdb.current_process') as curproc:
             curproc.return_value.name = 'PoolWorker-10'
-            with Rdb(out=out) as rdb:
+            with Rdb(out=out):
                 pass
 
         err = sock.return_value.bind.side_effect = SockErr()
         err.errno = errno.ENOENT
         with self.assertRaises(SockErr):
-            with Rdb(out=out) as rdb:
+            with Rdb(out=out):
                 pass
         err.errno = errno.EADDRINUSE
         with self.assertRaises(Exception):
-            with Rdb(out=out) as rdb:
+            with Rdb(out=out):
                 pass
         called = [0]
 
@@ -101,5 +101,5 @@ class test_Rdb(AppCase):
             finally:
                 called[0] += 1
         sock.return_value.bind.side_effect = effect
-        with Rdb(out=out) as rdb:
+        with Rdb(out=out):
             pass

+ 156 - 2
celery/tests/tasks/test_canvas.py

@@ -1,5 +1,6 @@
 from __future__ import absolute_import
 
+from celery._state import _task_stack
 from celery.canvas import (
     Signature,
     chain,
@@ -210,6 +211,128 @@ class test_chain(CanvasCase):
             repr(x), '%s(2, 2) | %s(2)' % (self.add.name, self.add.name),
         )
 
+    def test_apply_async(self):
+        c = self.add.s(2, 2) | self.add.s(4) | self.add.s(8)
+        result = c.apply_async()
+        self.assertTrue(result.parent)
+        self.assertTrue(result.parent.parent)
+        self.assertIsNone(result.parent.parent.parent)
+
+    def test_group_to_chord__freeze_parent_id(self):
+        def using_freeze(c):
+            c.freeze(parent_id='foo', root_id='root')
+            return c._frozen[0]
+        self.assert_group_to_chord_parent_ids(using_freeze)
+
+    def assert_group_to_chord_parent_ids(self, freezefun):
+        c = (
+            self.add.s(5, 5) |
+            group([self.add.s(i, i) for i in range(5)], app=self.app) |
+            self.add.si(10, 10) |
+            self.add.si(20, 20) |
+            self.add.si(30, 30)
+        )
+        tasks = freezefun(c)
+        self.assertEqual(tasks[-1].parent_id, 'foo')
+        self.assertEqual(tasks[-1].root_id, 'root')
+        self.assertEqual(tasks[-2].parent_id, tasks[-1].id)
+        self.assertEqual(tasks[-2].root_id, 'root')
+        self.assertEqual(tasks[-2].body.parent_id, tasks[-2].tasks.id)
+        self.assertEqual(tasks[-2].body.parent_id, tasks[-2].id)
+        self.assertEqual(tasks[-2].body.root_id, 'root')
+        self.assertEqual(tasks[-2].tasks.tasks[0].parent_id, tasks[-1].id)
+        self.assertEqual(tasks[-2].tasks.tasks[0].root_id, 'root')
+        self.assertEqual(tasks[-2].tasks.tasks[1].parent_id, tasks[-1].id)
+        self.assertEqual(tasks[-2].tasks.tasks[1].root_id, 'root')
+        self.assertEqual(tasks[-2].tasks.tasks[2].parent_id, tasks[-1].id)
+        self.assertEqual(tasks[-2].tasks.tasks[2].root_id, 'root')
+        self.assertEqual(tasks[-2].tasks.tasks[3].parent_id, tasks[-1].id)
+        self.assertEqual(tasks[-2].tasks.tasks[3].root_id, 'root')
+        self.assertEqual(tasks[-2].tasks.tasks[4].parent_id, tasks[-1].id)
+        self.assertEqual(tasks[-2].tasks.tasks[4].root_id, 'root')
+        self.assertEqual(tasks[-3].parent_id, tasks[-2].body.id)
+        self.assertEqual(tasks[-3].root_id, 'root')
+        self.assertEqual(tasks[-4].parent_id, tasks[-3].id)
+        self.assertEqual(tasks[-4].root_id, 'root')
+
+    def test_group_to_chord(self):
+        c = (
+            self.add.s(5) |
+            group([self.add.s(i, i) for i in range(5)], app=self.app) |
+            self.add.s(10) |
+            self.add.s(20) |
+            self.add.s(30)
+        )
+        c._use_link = True
+        tasks, results = c.prepare_steps((), c.tasks)
+
+        self.assertEqual(tasks[-1].args[0], 5)
+        self.assertIsInstance(tasks[-2], chord)
+        self.assertEqual(len(tasks[-2].tasks), 5)
+        self.assertEqual(tasks[-2].parent_id, tasks[-1].id)
+        self.assertEqual(tasks[-2].root_id, tasks[-1].id)
+        self.assertEqual(tasks[-2].body.args[0], 10)
+        self.assertEqual(tasks[-2].body.parent_id, tasks[-2].id)
+
+        self.assertEqual(tasks[-3].args[0], 20)
+        self.assertEqual(tasks[-3].root_id, tasks[-1].id)
+        self.assertEqual(tasks[-3].parent_id, tasks[-2].body.id)
+
+        self.assertEqual(tasks[-4].args[0], 30)
+        self.assertEqual(tasks[-4].parent_id, tasks[-3].id)
+        self.assertEqual(tasks[-4].root_id, tasks[-1].id)
+
+        self.assertTrue(tasks[-2].body.options['link'])
+        self.assertTrue(tasks[-2].body.options['link'][0].options['link'])
+
+        c2 = self.add.s(2, 2) | group(self.add.s(i, i) for i in range(10))
+        c2._use_link = True
+        tasks2, _ = c2.prepare_steps((), c2.tasks)
+        self.assertIsInstance(tasks2[0], group)
+
+    def test_group_to_chord__protocol_2(self):
+        c = (
+            group([self.add.s(i, i) for i in range(5)], app=self.app) |
+            self.add.s(10) |
+            self.add.s(20) |
+            self.add.s(30)
+        )
+        c._use_link = False
+        tasks, _ = c.prepare_steps((), c.tasks)
+        self.assertIsInstance(tasks[-1], chord)
+
+        c2 = self.add.s(2, 2) | group(self.add.s(i, i) for i in range(10))
+        c2._use_link = False
+        tasks2, _ = c2.prepare_steps((), c2.tasks)
+        self.assertIsInstance(tasks2[0], group)
+
+    def test_apply_options(self):
+
+        class static(Signature):
+
+            def clone(self, *args, **kwargs):
+                return self
+
+        def s(*args, **kwargs):
+            return static(self.add, args, kwargs, type=self.add, app=self.app)
+
+        c = s(2, 2) | s(4, 4) | s(8, 8)
+        r1 = c.apply_async(task_id='some_id')
+        self.assertEqual(r1.id, 'some_id')
+
+        c.apply_async(group_id='some_group_id')
+        self.assertEqual(c.tasks[-1].options['group_id'], 'some_group_id')
+
+        c.apply_async(chord='some_chord_id')
+        self.assertEqual(c.tasks[-1].options['chord'], 'some_chord_id')
+
+        c.apply_async(link=[s(32)])
+        self.assertListEqual(c.tasks[-1].options['link'], [s(32)])
+
+        c.apply_async(link_error=[s('error')])
+        for task in c.tasks:
+            self.assertListEqual(task.options['link_error'], [s('error')])
+
     def test_reverse(self):
         x = self.add.s(2, 2) | self.add.s(2)
         self.assertIsInstance(signature(x), chain)
@@ -255,13 +378,12 @@ class test_chain(CanvasCase):
         self.assert_sent_with_ids(tasks[-3], tasks[-1].id, tasks[-2].id)
         self.assert_sent_with_ids(tasks[-4], tasks[-1].id, tasks[-3].id)
 
-
     def assert_sent_with_ids(self, task, rid, pid, **options):
         self.app.amqp.send_task_message = Mock(name='send_task_message')
         self.app.backend = Mock()
         self.app.producer_or_acquire = ContextMock()
 
-        res = task.apply_async(**options)
+        task.apply_async(**options)
         self.assertTrue(self.app.amqp.send_task_message.called)
         message = self.app.amqp.send_task_message.call_args[0][2]
         self.assertEqual(message.headers['parent_id'], pid)
@@ -306,6 +428,38 @@ class test_group(CanvasCase):
             _maybe_group(self.add.s(2, 2), self.app), [self.add.s(2, 2)],
         )
 
+    def test_apply(self):
+        x = group([self.add.s(4, 4), self.add.s(8, 8)])
+        res = x.apply()
+        self.assertEqual(res.get(), [8, 16])
+
+    def test_apply_async(self):
+        x = group([self.add.s(4, 4), self.add.s(8, 8)])
+        x.apply_async()
+
+    def test_apply_empty(self):
+        x = group(app=self.app)
+        x.apply()
+        res = x.apply_async()
+        self.assertFalse(res)
+        self.assertFalse(res.results)
+
+    def test_apply_async_with_parent(self):
+        _task_stack.push(self.add)
+        try:
+            self.add.push_request(called_directly=False)
+            try:
+                assert not self.add.request.children
+                x = group([self.add.s(4, 4), self.add.s(8, 8)])
+                res = x()
+                self.assertTrue(self.add.request.children)
+                self.assertIn(res, self.add.request.children)
+                self.assertEqual(len(self.add.request.children), 1)
+            finally:
+                self.add.pop_request()
+        finally:
+            _task_stack.pop()
+
     def test_from_dict(self):
         x = group([self.add.s(2, 2), self.add.s(4, 4)])
         x['args'] = (2, 2)

+ 9 - 0
celery/tests/tasks/test_chord.py

@@ -79,6 +79,15 @@ class test_unlock_chord_task(ChordCase):
             # did not retry
             self.assertFalse(retry.call_count)
 
+    def test_deps_ready_fails(self):
+        GroupResult = Mock(name='GroupResult')
+        GroupResult.return_value.ready.side_effect = KeyError('foo')
+        unlock_chord = self.app.tasks['celery.chord_unlock']
+
+        with self.assertRaises(KeyError):
+            unlock_chord('groupid', Mock(), result=[Mock()],
+                         GroupResult=GroupResult, result_from_tuple=Mock())
+
     def test_callback_fails(self):
 
         class AlwaysReady(TSR):

+ 216 - 39
celery/tests/worker/test_request.py

@@ -36,7 +36,9 @@ from celery.five import monotonic
 from celery.signals import task_revoked
 from celery.utils import uuid
 from celery.worker import request as module
-from celery.worker.request import Request, logger as req_logger
+from celery.worker.request import (
+    Request, create_request_cls, logger as req_logger,
+)
 from celery.worker.state import revoked
 
 from celery.tests.case import (
@@ -51,6 +53,39 @@ from celery.tests.case import (
 )
 
 
+class RequestCase(AppCase):
+
+    def setup(self):
+        self.app.conf.result_serializer = 'pickle'
+
+        @self.app.task(shared=False)
+        def add(x, y, **kw_):
+            return x + y
+        self.add = add
+
+        @self.app.task(shared=False)
+        def mytask(i, **kwargs):
+            return i ** i
+        self.mytask = mytask
+
+        @self.app.task(shared=False)
+        def mytask_raising(i):
+            raise KeyError(i)
+        self.mytask_raising = mytask_raising
+
+    def xRequest(self, name=None, id=None, args=None, kwargs=None,
+                 on_ack=None, on_reject=None, Request=Request, **head):
+        args = [1] if args is None else args
+        kwargs = {'f': 'x'} if kwargs is None else kwargs
+        on_ack = on_ack or Mock(name='on_ack')
+        on_reject = on_reject or Mock(name='on_reject')
+        message = TaskMessage(
+            name or self.mytask.name, id, args=args, kwargs=kwargs, **head
+        )
+        return Request(message, app=self.app,
+                       on_ack=on_ack, on_reject=on_reject)
+
+
 class test_mro_lookup(Case):
 
     def test_order(self):
@@ -125,7 +160,7 @@ class test_Retry(AppCase):
             self.assertEqual(ret.exc, exc)
 
 
-class test_trace_task(AppCase):
+class test_trace_task(RequestCase):
 
     def setup(self):
 
@@ -162,7 +197,7 @@ class test_trace_task(AppCase):
     def test_marked_as_started(self):
         _started = []
 
-        def store_result(tid, meta, state, **kwars):
+        def store_result(tid, meta, state, **kwargs):
             if state == states.STARTED:
                 _started.append(tid)
         self.mytask.backend.store_result = Mock(name='store_result')
@@ -207,25 +242,7 @@ class MockEventDispatcher(object):
         self.sent.append(event)
 
 
-class test_Request(AppCase):
-
-    def setup(self):
-        self.app.conf.result_serializer = 'pickle'
-
-        @self.app.task(shared=False)
-        def add(x, y, **kw_):
-            return x + y
-        self.add = add
-
-        @self.app.task(shared=False)
-        def mytask(i, **kwargs):
-            return i ** i
-        self.mytask = mytask
-
-        @self.app.task(shared=False)
-        def mytask_raising(i):
-            raise KeyError(i)
-        self.mytask_raising = mytask_raising
+class test_Request(RequestCase):
 
     def get_request(self, sig, Request=Request, **kwargs):
         return Request(
@@ -239,6 +256,12 @@ class test_Request(AppCase):
             **kwargs
         )
 
+    def test_shadow(self):
+        self.assertEqual(
+            self.get_request(self.add.s(2, 2).set(shadow='fooxyz')).name,
+            'fooxyz',
+        )
+
     def test_invalid_eta_raises_InvalidTaskError(self):
         with self.assertRaises(InvalidTaskError):
             self.get_request(self.add.s(2, 2).set(eta='12345'))
@@ -358,18 +381,6 @@ class test_Request(AppCase):
         req._tzlocal = 'foo'
         self.assertEqual(req.tzlocal, 'foo')
 
-    def xRequest(self, name=None, id=None, args=None, kwargs=None,
-                 on_ack=None, on_reject=None, **head):
-        args = [1] if args is None else args
-        kwargs = {'f': 'x'} if kwargs is None else kwargs
-        on_ack = on_ack or Mock(name='on_ack')
-        on_reject = on_reject or Mock(name='on_reject')
-        message = TaskMessage(
-            name or self.mytask.name, id, args=args, kwargs=kwargs, **head
-        )
-        return Request(message, app=self.app,
-                       on_ack=on_ack, on_reject=on_reject)
-
     def test_task_wrapper_repr(self):
         self.assertTrue(repr(self.xRequest()))
 
@@ -414,6 +425,23 @@ class test_Request(AppCase):
         job.task_name = 'NAME'
         self.assertEqual(job.name, 'NAME')
 
+    def test_terminate__pool_ref(self):
+        pool = Mock()
+        signum = signal.SIGTERM
+        job = self.get_request(self.mytask.s(1, f='x'))
+        job._apply_result = Mock(name='_apply_result')
+        with assert_signal_called(
+                task_revoked, sender=job.task, request=job,
+                terminated=True, expired=False, signum=signum):
+            job.time_start = monotonic()
+            job.worker_pid = 314
+            job.terminate(pool, signal='TERM')
+            job._apply_result().terminate.assert_called_with(signum)
+
+            job._apply_result = Mock(name='_apply_result2')
+            job._apply_result.return_value = None
+            job.terminate(pool, signal='TERM')
+
     def test_terminate__task_started(self):
         pool = Mock()
         signum = signal.SIGTERM
@@ -627,6 +655,8 @@ class test_Request(AppCase):
     def test_on_timeout(self, warn, error):
 
         job = self.xRequest()
+        job.acknowledge = Mock(name='ack')
+        job.task.acks_late = True
         job.on_timeout(soft=True, timeout=1337)
         self.assertIn('Soft time limit', warn.call_args[0][0])
         job.on_timeout(soft=False, timeout=1337)
@@ -634,6 +664,7 @@ class test_Request(AppCase):
         self.assertEqual(
             self.mytask.backend.get_status(job.id), states.FAILURE,
         )
+        job.acknowledge.assert_called_with()
 
         self.mytask.ignore_result = True
         job = self.xRequest()
@@ -642,6 +673,12 @@ class test_Request(AppCase):
             self.mytask.backend.get_status(job.id), states.PENDING,
         )
 
+        job = self.xRequest()
+        job.acknowledge = Mock(name='ack')
+        job.task.acks_late = False
+        job.on_timeout(soft=True, timeout=1335)
+        self.assertFalse(job.acknowledge.called)
+
     def test_fast_trace_task(self):
         from celery.app import trace
         setup_worker_optimizations(self.app)
@@ -874,23 +911,163 @@ class test_Request(AppCase):
         self.assertEqual(p.args[1], tid)
         self.assertEqual(p.args[3], job.message.body)
 
-    def _test_on_failure(self, exception):
+    def _test_on_failure(self, exception, **kwargs):
         tid = uuid()
         job = self.xRequest(id=tid, args=[4])
         job.send_event = Mock(name='send_event')
+        job.task.backend.mark_as_failure = Mock(name='mark_as_failure')
         try:
             raise exception
-        except Exception:
+        except type(exception):
             exc_info = ExceptionInfo()
-            job.on_failure(exc_info)
+            job.on_failure(exc_info, **kwargs)
             self.assertTrue(job.send_event.called)
+        return job
 
     def test_on_failure(self):
         self._test_on_failure(Exception('Inside unit tests'))
 
-    def test_on_failure_unicode_exception(self):
+    def test_on_failure__unicode_exception(self):
         self._test_on_failure(Exception('Бобры атакуют'))
 
-    def test_on_failure_utf8_exception(self):
+    def test_on_failure__utf8_exception(self):
         self._test_on_failure(Exception(
             from_utf8('Бобры атакуют')))
+
+    def test_on_failure__WorkerLostError(self):
+        exc = WorkerLostError()
+        job = self._test_on_failure(exc)
+        job.task.backend.mark_as_failure.assert_called_with(
+            job.id, exc, request=job, store_result=True,
+        )
+
+    def test_on_failure__return_ok(self):
+        self._test_on_failure(KeyError(), return_ok=True)
+
+    def test_reject(self):
+        job = self.xRequest(id=uuid())
+        job.on_reject = Mock(name='on_reject')
+        job.acknowleged = False
+        job.reject(requeue=True)
+        job.on_reject.assert_called_with(
+            req_logger, job.connection_errors, True,
+        )
+        self.assertTrue(job.acknowledged)
+        job.on_reject.reset_mock()
+        job.reject(requeue=True)
+        self.assertFalse(job.on_reject.called)
+
+    def test_group(self):
+        gid = uuid()
+        job = self.xRequest(id=uuid(), group=gid)
+        self.assertEqual(job.group, gid)
+
+
+class test_create_request_class(RequestCase):
+
+    def setup(self):
+        RequestCase.setup(self)
+        self.task = Mock(name='task')
+        self.pool = Mock(name='pool')
+        self.eventer = Mock(name='eventer')
+
+    def create_request_cls(self, **kwargs):
+        return create_request_cls(
+            Request, self.task, self.pool, 'foo', self.eventer, **kwargs
+        )
+
+    def zRequest(self, Request=None, revoked_tasks=None, ref=None, **kwargs):
+        return self.xRequest(
+            Request=Request or self.create_request_cls(
+                ref=ref,
+                revoked_tasks=revoked_tasks,
+            ),
+            **kwargs)
+
+    def test_on_success(self):
+        self.zRequest(id=uuid()).on_success((False, "hey", 3.1222))
+
+    def test_on_success__SystemExit(self,
+                                    errors=(SystemExit, KeyboardInterrupt)):
+        for exc in errors:
+            einfo = None
+            try:
+                raise exc()
+            except exc:
+                einfo = ExceptionInfo()
+            with self.assertRaises(exc):
+                self.zRequest(id=uuid()).on_success((True, einfo, 1.0))
+
+    def test_on_success__calls_failure(self):
+        job = self.zRequest(id=uuid())
+        einfo = Mock(name='einfo')
+        job.on_failure = Mock(name='on_failure')
+        job.on_success((True, einfo, 1.0))
+        job.on_failure.assert_called_with(einfo, return_ok=True)
+
+    def test_on_success__acks_late_enabled(self):
+        self.task.acks_late = True
+        job = self.zRequest(id=uuid())
+        job.acknowledge = Mock(name='ack')
+        job.on_success((False, 'foo', 1.0))
+        job.acknowledge.assert_called_with()
+
+    def test_on_success__acks_late_disabled(self):
+        self.task.acks_late = False
+        job = self.zRequest(id=uuid())
+        job.acknowledge = Mock(name='ack')
+        job.on_success((False, 'foo', 1.0))
+        self.assertFalse(job.acknowledge.called)
+
+    def test_on_success__no_events(self):
+        self.eventer = None
+        job = self.zRequest(id=uuid())
+        job.send_event = Mock(name='send_event')
+        job.on_success((False, 'foo', 1.0))
+        self.assertFalse(job.send_event.called)
+
+    def test_on_success__with_events(self):
+        job = self.zRequest(id=uuid())
+        job.send_event = Mock(name='send_event')
+        job.on_success((False, 'foo', 1.0))
+        job.send_event.assert_called_with(
+            'task-succeeded', result='foo', runtime=1.0,
+        )
+
+    def test_execute_using_pool__revoked(self):
+        tid = uuid()
+        job = self.zRequest(id=tid, revoked_tasks={tid})
+        job.revoked = Mock()
+        job.revoked.return_value = True
+        with self.assertRaises(TaskRevokedError):
+            job.execute_using_pool(self.pool)
+
+    def test_execute_using_pool__expired(self):
+        tid = uuid()
+        job = self.zRequest(id=tid, revoked_tasks=set())
+        job.expires = 1232133
+        job.revoked = Mock()
+        job.revoked.return_value = True
+        with self.assertRaises(TaskRevokedError):
+            job.execute_using_pool(self.pool)
+
+    def test_execute_using_pool(self):
+        from celery.app.trace import trace_task_ret as trace
+        weakref_ref = Mock(name='weakref.ref')
+        job = self.zRequest(id=uuid(), revoked_tasks=set(), ref=weakref_ref)
+        job.execute_using_pool(self.pool)
+        self.pool.apply_async.assert_called_with(
+            trace,
+            args=(job.type, job.id, job.request_dict, job.body,
+                  job.content_type, job.content_encoding),
+            accept_callback=job.on_accepted,
+            timeout_callback=job.on_timeout,
+            callback=job.on_success,
+            error_callback=job.on_failure,
+            soft_timeout=self.task.soft_time_limit,
+            timeout=self.task.time_limit,
+            correlation_id=job.id,
+        )
+        self.assertTrue(job._apply_result)
+        weakref_ref.assert_called_with(self.pool.apply_async())
+        self.assertIs(job._apply_result, weakref_ref())

+ 85 - 6
celery/tests/worker/test_strategy.py

@@ -5,13 +5,57 @@ from contextlib import contextmanager
 
 from kombu.utils.limits import TokenBucket
 
+from celery.exceptions import InvalidTaskError
 from celery.worker import state
+from celery.worker.strategy import proto1_to_proto2
 from celery.utils.timeutils import rate
 
-from celery.tests.case import AppCase, Mock, patch, task_message_from_sig
+from celery.tests.case import (
+    AppCase, Mock, TaskMessage, TaskMessage1, patch, task_message_from_sig,
+)
 
 
-class test_default_strategy(AppCase):
+class test_proto1_to_proto2(AppCase):
+
+    def setup(self):
+        self.message = Mock(name='message')
+        self.body = {
+            'args': (1,),
+            'kwargs': {'foo': 'baz'},
+            'utc': False,
+            'taskset': '123',
+        }
+
+    def test_message_without_args(self):
+        self.body.pop('args')
+        with self.assertRaises(InvalidTaskError):
+            proto1_to_proto2(self.message, self.body)
+
+    def test_message_without_kwargs(self):
+        self.body.pop('kwargs')
+        with self.assertRaises(InvalidTaskError):
+            proto1_to_proto2(self.message, self.body)
+
+    def test_message_kwargs_not_mapping(self):
+        self.body['kwargs'] = (2,)
+        with self.assertRaises(InvalidTaskError):
+            proto1_to_proto2(self.message, self.body)
+
+    def test_message_no_taskset_id(self):
+        self.body.pop('taskset')
+        self.assertTrue(proto1_to_proto2(self.message, self.body))
+
+    def test_message(self):
+        body, headers, decoded, utc = proto1_to_proto2(self.message, self.body)
+        self.assertTupleEqual(body, ((1,), {'foo': 'baz'}, {
+            'callbacks': None, 'errbacks': None, 'chord': None, 'chain': None,
+        }))
+        self.assertDictEqual(headers, dict(self.body, group='123'))
+        self.assertTrue(decoded)
+        self.assertFalse(utc)
+
+
+class test_default_strategy_proto2(AppCase):
 
     def setup(self):
         @self.app.task(shared=False)
@@ -20,6 +64,12 @@ class test_default_strategy(AppCase):
 
         self.add = add
 
+    def get_message_class(self):
+        return TaskMessage
+
+    def prepare_message(self, message):
+        return message
+
     class Context(object):
 
         def __init__(self, sig, s, reserved, consumer, message):
@@ -29,10 +79,12 @@ class test_default_strategy(AppCase):
             self.consumer = consumer
             self.message = message
 
-        def __call__(self, **kwargs):
+        def __call__(self, callbacks=[], **kwargs):
             return self.s(
-                self.message, None,
-                self.message.ack, self.message.reject, [], **kwargs
+                self.message,
+                (self.message.payload
+                    if not self.message.headers.get('id') else None),
+                self.message.ack, self.message.reject, callbacks, **kwargs
             )
 
         def was_reserved(self):
@@ -76,7 +128,10 @@ class test_default_strategy(AppCase):
         s = sig.type.start_strategy(self.app, consumer, task_reserved=reserved)
         self.assertTrue(s)
 
-        message = task_message_from_sig(self.app, sig, utc=utc)
+        message = task_message_from_sig(
+            self.app, sig, utc=utc, TaskMessage=self.get_message_class(),
+        )
+        message = self.prepare_message(message)
         yield self.Context(sig, s, reserved, consumer, message)
 
     def test_when_logging_disabled(self):
@@ -94,6 +149,14 @@ class test_default_strategy(AppCase):
             C.consumer.on_task_request.assert_called_with(req)
             self.assertTrue(C.event_sent())
 
+    def test_callbacks(self):
+        with self._context(self.add.s(2, 2)) as C:
+            callbacks = [Mock(name='cb1'), Mock(name='cb2')]
+            C(callbacks=callbacks)
+            req = C.get_request()
+            for callback in callbacks:
+                callback.assert_called_with(req)
+
     def test_when_events_disabled(self):
         with self._context(self.add.s(2, 2), events=False) as C:
             C()
@@ -136,3 +199,19 @@ class test_default_strategy(AppCase):
                     C.get_request()
         finally:
             state.revoked.discard(task.id)
+
+
+class test_default_strategy_proto1(test_default_strategy_proto2):
+
+    def get_message_class(self):
+        return TaskMessage1
+
+
+class test_default_strategy_proto1__no_utc(test_default_strategy_proto2):
+
+    def get_message_class(self):
+        return TaskMessage1
+
+    def prepare_message(self, message):
+        message.payload['utc'] = False
+        return message

+ 4 - 0
celery/utils/functional.py

@@ -210,6 +210,10 @@ def noop(*args, **kwargs):
     pass
 
 
+def pass1(arg, *args, **kwargs):
+    return arg
+
+
 def evaluate_promises(it):
     for value in it:
         if isinstance(value, promise):

+ 4 - 2
celery/worker/request.py

@@ -81,6 +81,7 @@ class Request(object):
             'on_ack', 'body', 'hostname', 'eventer', 'connection_errors',
             'task', 'eta', 'expires', 'request_dict', 'on_reject', 'utc',
             'content_type', 'content_encoding', 'argsrepr', 'kwargsrepr',
+            '_decoded',
             '__weakref__', '__dict__',
         )
 
@@ -99,6 +100,7 @@ class Request(object):
         self.message = message
         self.body = body
         self.utc = utc
+        self._decoded = decoded
         if decoded:
             self.content_type = self.content_encoding = None
         else:
@@ -111,7 +113,7 @@ class Request(object):
         self.root_id = headers.get('root_id')
         self.parent_id = headers.get('parent_id')
         if 'shadow' in headers:
-            self.name = headers['shadow']
+            self.name = headers['shadow'] or self.name
         if 'timelimit' in headers:
             self.time_limits = headers['timelimit']
         self.argsrepr = headers.get('argsrepr', '')
@@ -460,7 +462,7 @@ class Request(object):
 
     @cached_property
     def _payload(self):
-        return self.message.payload
+        return self.body if self._decoded else self.message.payload
 
     @cached_property
     def chord(self):

+ 7 - 1
celery/worker/strategy.py

@@ -50,7 +50,13 @@ def proto1_to_proto2(message, body):
         body['group'] = body['taskset']
     except KeyError:
         pass
-    return (args, kwargs), body, True, body.get('utc', True)
+    embed = {
+        'callbacks': body.get('callbacks'),
+        'errbacks': body.get('errbacks'),
+        'chord': body.get('chord'),
+        'chain': None,
+    }
+    return (args, kwargs, embed), body, True, body.get('utc', True)
 
 
 def default(task, app, consumer,