Browse Source

Tests passing, more refactoring

Ask Solem 11 years ago
parent
commit
e442df61b2

+ 1 - 1
celery/app/amqp.py

@@ -297,7 +297,7 @@ class AMQP(object):
             headers={
                 'lang': 'py',
                 'c_type': name,
-                'task_id': task_id,
+                'id': task_id,
                 'eta': eta,
                 'expires': expires,
                 'callbacks': callbacks,

+ 64 - 45
celery/app/trace.py

@@ -25,7 +25,7 @@ from warnings import warn
 
 from billiard.einfo import ExceptionInfo
 from kombu.exceptions import EncodeError
-from kombu.serialization import decode as decode_message
+from kombu.serialization import loads as loads_message, prepare_accept_content
 from kombu.utils.encoding import safe_repr, safe_str
 
 from celery import current_app, group
@@ -78,6 +78,22 @@ LOG_RETRY = """\
 Task %(name)s[%(id)s] retry: %(exc)s\
 """
 
+log_policy_t = namedtuple(
+    'log_policy_t', ('format', 'description', 'severity', 'traceback', 'mail'),
+)
+
+log_policy_reject = log_policy_t(LOG_REJECTED, 'rejected', logging.WARN, 1, 1)
+log_policy_ignore = log_policy_t(LOG_IGNORED, 'ignored', logging.INFO, 0, 0)
+log_policy_internal = log_policy_t(
+    LOG_INTERNAL_ERROR, 'INTERNAL ERROR', logging.CRITICAL, 1, 1,
+)
+log_policy_expected = log_policy_t(
+    LOG_FAILURE, 'raised expected', logging.INFO, 0, 0,
+)
+log_policy_unexpected = log_policy_t(
+    LOG_FAILURE, 'raised unexpected', logging.ERROR, 1, 1,
+)
+
 send_prerun = signals.task_prerun.send
 send_postrun = signals.task_postrun.send
 send_success = signals.task_success.send
@@ -91,7 +107,7 @@ EXCEPTION_STATES = states.EXCEPTION_STATES
 IGNORE_STATES = frozenset([IGNORED, RETRY, REJECTED])
 
 #: set by :func:`setup_worker_optimizations`
-_tasks = None
+_localized = []
 _patched = {}
 
 trace_ok_t = namedtuple('trace_ok_t', ('retval', 'info', 'runtime', 'retstr'))
@@ -104,6 +120,19 @@ def task_has_custom(task, attr):
                       monkey_patched=['celery.app.task'])
 
 
+def get_log_policy(task, einfo, exc):
+    if isinstance(exc, Reject):
+        return log_policy_reject
+    elif isinstance(exc, Ignore):
+        return log_policy_ignore
+    elif einfo.internal:
+        return log_policy_internal
+    else:
+        if task.throws and isinstance(exc, task.throws):
+            return log_policy_expected
+        return log_policy_unexpected
+
+
 class TraceInfo(object):
     __slots__ = ('state', 'retval')
 
@@ -172,39 +201,14 @@ class TraceInfo(object):
     def _log_error(self, task, einfo):
         req = task.request
         eobj = einfo.exception = get_pickled_exception(einfo.exception)
-        exception, traceback, exc_info, internal, sargs, skwargs = (
+        exception, traceback, exc_info, sargs, skwargs = (
             safe_repr(eobj),
             safe_str(einfo.traceback),
             einfo.exc_info,
-            einfo.internal,
             safe_repr(req.args),
             safe_repr(req.kwargs),
         )
-        if task.throws and isinstance(eobj, task.throws):
-            do_send_mail, severity, exc_info, description = (
-                False, logging.INFO, None, 'raised expected',
-            )
-        else:
-            do_send_mail, severity, description = (
-                True, logging.ERROR, 'raised unexpected',
-            )
-        format = LOG_FAILURE
-
-        if internal:
-            if isinstance(einfo.exception, Reject):
-                format = LOG_REJECTED
-                description = 'rejected'
-                severity = logging.WARN
-                exc_info = einfo
-            elif isinstance(einfo.exception, Ignore):
-                format = LOG_IGNORED
-                description = 'ignored'
-                severity = logging.INFO
-                exc_info = None
-            else:
-                format = LOG_INTERNAL_ERROR
-                description = 'INTERNAL ERROR'
-                severity = logging.CRITICAL
+        policy = get_log_policy(task, einfo, eobj)
 
         context = {
             'hostname': req.hostname,
@@ -214,15 +218,16 @@ class TraceInfo(object):
             'traceback': traceback,
             'args': sargs,
             'kwargs': skwargs,
-            'description': description,
-            'internal': internal,
+            'description': policy.description,
+            'internal': einfo.internal,
         }
 
-        logger.log(severity, format.strip(), context,
-                   exc_info=exc_info,
+        logger.log(policy.severity, policy.format.strip(), context,
+                   exc_info=exc_info if policy.traceback else None,
                    extra={'data': context})
 
-        task.send_error_email(context, einfo.exception)
+        if policy.mail:
+            task.send_error_email(context, einfo.exception)
 
 
 def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
@@ -444,14 +449,21 @@ def trace_task(task, uuid, args, kwargs, request={}, **opts):
     try:
         if task.__trace__ is None:
             task.__trace__ = build_tracer(task.name, task, **opts)
-        return task.__trace__(uuid, args, kwargs, request)[0]
+        return task.__trace__(uuid, args, kwargs, request)
     except Exception as exc:
         return report_internal_error(task, exc)
 
 
-def _trace_task_ret(name, uuid, args, kwargs, request={}, app=None, **opts):
-    return trace_task((app or current_app).tasks[name],
-                      uuid, args, kwargs, request, app=app, **opts)
+def _trace_task_ret(name, uuid, request, body, content_type,
+                    content_encoding, loads=loads_message, app=None,
+                    **extra_request):
+    app = app or current_app._get_current_object()
+    accept = prepare_accept_content(app.conf.CELERY_ACCEPT_CONTENT)
+    args, kwargs = loads(body, content_type, content_encoding, accept=accept)
+    request.update(args=args, kwargs=kwargs, **extra_request)
+    R, I, T, Rstr = trace_task(app.tasks[name],
+                               uuid, args, kwargs, request, app=app)
+    return (1, R, T) if I else (0, Rstr, T)
 trace_task_ret = _trace_task_ret
 
 
@@ -460,18 +472,23 @@ def _fast_trace_task_v1(task, uuid, args, kwargs, request={}):
     # so this is the function used in the worker.
     R, I, T, Rstr = _tasks[task].__trace__(uuid, args, kwargs, request)[0]
     # exception instance if error, else result text
-    return (R if I else Rstr), T
+    return (1, R, T) if I else (0, Rstr, T)
 
 
 def _fast_trace_task(task, uuid, request, body, content_type,
-                     content_encoding, decode_message=decode_message,
+                     content_encoding, loads=loads_message, _loc=_localized,
                      **extra_request):
-    args, kwargs = decode_message(body, content_type, content_encoding)
+    tasks, accept = _loc
+    try:
+        args, kwargs = loads(body, content_type, content_encoding,
+                             accept=accept)
+    except Exception as exc:
+        print('OH NOEEES: %r' % (exc, ))
     request.update(args=args, kwargs=kwargs, **extra_request)
-    R, I, T, Rstr = _tasks[task].__trace__(
+    R, I, T, Rstr = tasks[task].__trace__(
         uuid, args, kwargs, request,
     )
-    return (R if I else Rstr), T
+    return (1, R, T) if I else (0, Rstr, T)
 
 
 def report_internal_error(task, exc):
@@ -488,7 +505,6 @@ def report_internal_error(task, exc):
 
 
 def setup_worker_optimizations(app):
-    global _tasks
     global trace_task_ret
 
     # make sure custom Task.__call__ methods that calls super
@@ -508,7 +524,10 @@ def setup_worker_optimizations(app):
     app.finalize()
 
     # set fast shortcut to task registry
-    _tasks = app._tasks
+    _localized[:] = [
+        app._tasks,
+        prepare_accept_content(app.conf.CELERY_ACCEPT_CONTENT),
+    ]
 
     trace_task_ret = _fast_trace_task
     from celery.worker import job as job_module

+ 2 - 4
celery/events/state.py

@@ -30,7 +30,7 @@ from time import time
 from weakref import ref
 
 from kombu.clocks import timetuple
-from kombu.utils import cached_property, kwdict
+from kombu.utils import cached_property
 
 from celery import states
 from celery.five import class_property, items, values
@@ -54,8 +54,6 @@ Substantial drift from %s may mean clocks are out of sync.  Current drift is
 %s seconds.  [orig: %s recv: %s]
 """
 
-CAN_KWDICT = sys.version_info >= (2, 6, 5)
-
 logger = get_logger(__name__)
 warn = logger.warning
 
@@ -86,7 +84,7 @@ def heartbeat_expires(timestamp, freq=60,
 
 
 def _depickle_task(cls, fields):
-    return cls(**(fields if CAN_KWDICT else kwdict(fields)))
+    return cls(**fields)
 
 
 def with_unique_field(attr):

+ 28 - 14
celery/tests/case.py

@@ -48,7 +48,7 @@ from celery.utils.functional import noop
 from celery.utils.imports import qualname
 
 __all__ = [
-    'Case', 'AppCase', 'Mock', 'MagicMock', 'ANY',
+    'Case', 'AppCase', 'Mock', 'MagicMock', 'ANY', 'TaskMessage',
     'patch', 'call', 'sentinel', 'skip_unless_module',
     'wrap_logger', 'with_environ', 'sleepdeprived',
     'skip_if_environ', 'todo', 'skip', 'skip_if',
@@ -56,7 +56,7 @@ __all__ = [
     'replace_module_value', 'sys_platform', 'reset_modules',
     'patch_modules', 'mock_context', 'mock_open', 'patch_many',
     'assert_signal_called', 'skip_if_pypy',
-    'skip_if_jython', 'body_from_sig', 'restore_logging',
+    'skip_if_jython', 'task_message_from_sig', 'restore_logging',
 ]
 patch = mock.patch
 call = mock.call
@@ -819,7 +819,7 @@ def skip_if_jython(fun):
     return _inner
 
 
-def body_from_sig(app, sig, utc=True):
+def task_message_from_sig(app, sig, utc=True):
     sig.freeze()
     callbacks = sig.options.pop('link', None)
     errbacks = sig.options.pop('link_error', None)
@@ -835,17 +835,14 @@ def body_from_sig(app, sig, utc=True):
         expires = app.now() + timedelta(seconds=expires)
     if expires and isinstance(expires, datetime):
         expires = expires.isoformat()
-    return {
-        'task': sig.task,
-        'id': sig.id,
-        'args': sig.args,
-        'kwargs': sig.kwargs,
-        'callbacks': [dict(s) for s in callbacks] if callbacks else None,
-        'errbacks': [dict(s) for s in errbacks] if errbacks else None,
-        'eta': eta,
-        'utc': utc,
-        'expires': expires,
-    }
+    return TaskMessage(
+        sig.task, id=sig.id, args=sig.args,
+        kwargs=sig.kwargs,
+        callbacks=[dict(s) for s in callbacks] if callbacks else None,
+        errbacks=[dict(s) for s in errbacks] if errbacks else None,
+        eta=eta,
+        expires=expires,
+    )
 
 
 @contextmanager
@@ -861,3 +858,20 @@ def restore_logging():
         sys.stdout, sys.stderr, sys.__stdout__, sys.__stderr__ = outs
         root.level = level
         root.handlers[:] = handlers
+
+
+def TaskMessage(name, id=None, args=(), kwargs={}, **options):
+    from celery import uuid
+    from kombu.serialization import dumps
+    id = id or uuid()
+    message = Mock(name='TaskMessage-{0}'.format(id))
+    message.headers = {
+        'id': id,
+        'c_type': name,
+    }
+    message.headers.update(options)
+    message.content_type, message.content_encoding, message.body = dumps(
+        (args, kwargs), serializer='json',
+    )
+    message.payload = (args, kwargs)
+    return message

+ 4 - 3
celery/tests/tasks/test_trace.py

@@ -14,10 +14,11 @@ from celery.app.trace import (
 from celery.tests.case import AppCase, Mock, patch
 
 
-def trace(app, task, args=(), kwargs={}, propagate=False, **opts):
+def trace(app, task, args=(), kwargs={},
+          propagate=False, eager=True, request=None, **opts):
     t = build_tracer(task.name, task,
-                     eager=True, propagate=propagate, app=app, **opts)
-    ret = t('id-1', args, kwargs, None)
+                     eager=eager, propagate=propagate, app=app, **opts)
+    ret = t('id-1', args, kwargs, request)
     return ret.retval, ret.info
 
 

+ 6 - 21
celery/tests/worker/test_control.py

@@ -21,7 +21,7 @@ from celery.worker.state import revoked
 from celery.worker.control import Panel
 from celery.worker.pidbox import Pidbox, gPidbox
 
-from celery.tests.case import AppCase, Mock, call, patch
+from celery.tests.case import AppCase, Mock, TaskMessage, call, patch
 
 hostname = socket.gethostname()
 
@@ -250,12 +250,7 @@ class test_ControlPanel(AppCase):
         self.panel.handle('report')
 
     def test_active(self):
-        r = Request({
-            'task': self.mytask.name,
-            'id': 'do re mi',
-            'args': (),
-            'kwargs': {},
-        }, app=self.app)
+        r = Request(TaskMessage(self.mytask.name, 'do re mi'), app=self.app)
         worker_state.active_requests.add(r)
         try:
             self.assertTrue(self.panel.handle('dump_active'))
@@ -347,12 +342,7 @@ class test_ControlPanel(AppCase):
         consumer = Consumer(self.app)
         panel = self.create_panel(consumer=consumer)
         self.assertFalse(panel.handle('dump_schedule'))
-        r = Request({
-            'task': self.mytask.name,
-            'id': 'CAFEBABE',
-            'args': (),
-            'kwargs': {},
-        }, app=self.app)
+        r = Request(TaskMessage(self.mytask.name, 'CAFEBABE'), app=self.app)
         consumer.timer.schedule.enter_at(
             consumer.timer.Entry(lambda x: x, (r, )),
             datetime.now() + timedelta(seconds=10))
@@ -363,19 +353,14 @@ class test_ControlPanel(AppCase):
 
     def test_dump_reserved(self):
         consumer = Consumer(self.app)
-        worker_state.reserved_requests.add(Request({
-            'task': self.mytask.name,
-            'id': uuid(),
-            'args': (2, 2),
-            'kwargs': {},
-        }, app=self.app))
+        worker_state.reserved_requests.add(
+            Request(TaskMessage(self.mytask.name, args=(2, 2)), app=self.app),
+        )
         try:
             panel = self.create_panel(consumer=consumer)
             response = panel.handle('dump_reserved', {'safe': True})
             self.assertDictContainsSubset(
                 {'name': self.mytask.name,
-                 'args': (2, 2),
-                 'kwargs': {},
                  'hostname': socket.gethostname()},
                 response[0],
             )

+ 21 - 22
celery/tests/worker/test_loops.py

@@ -11,7 +11,7 @@ from celery.worker import state
 from celery.worker.consumer import Consumer
 from celery.worker.loops import asynloop, synloop
 
-from celery.tests.case import AppCase, Mock, body_from_sig
+from celery.tests.case import AppCase, Mock, task_message_from_sig
 
 
 class X(object):
@@ -107,7 +107,7 @@ def get_task_callback(*args, **kwargs):
     x = X(*args, **kwargs)
     x.blueprint.state = CLOSE
     asynloop(*x.args)
-    return x, x.consumer.callbacks[0]
+    return x, x.consumer.on_message
 
 
 class test_asynloop(AppCase):
@@ -132,45 +132,44 @@ class test_asynloop(AppCase):
 
     def task_context(self, sig, **kwargs):
         x, on_task = get_task_callback(self.app, **kwargs)
-        body = body_from_sig(self.app, sig)
-        message = Mock()
-        strategy = x.obj.strategies[sig.task] = Mock()
-        return x, on_task, body, message, strategy
+        message = task_message_from_sig(self.app, sig)
+        strategy = x.obj.strategies[sig.task] = Mock(name='strategy')
+        return x, on_task, message, strategy
 
     def test_on_task_received(self):
-        _, on_task, body, msg, strategy = self.task_context(self.add.s(2, 2))
-        on_task(body, msg)
+        _, on_task, msg, strategy = self.task_context(self.add.s(2, 2))
+        on_task(msg)
         strategy.assert_called_with(
-            msg, body, msg.ack_log_error, msg.reject_log_error, [],
+            msg, None, msg.ack_log_error, msg.reject_log_error, [],
         )
 
     def test_on_task_received_executes_on_task_message(self):
         cbs = [Mock(), Mock(), Mock()]
-        _, on_task, body, msg, strategy = self.task_context(
+        _, on_task, msg, strategy = self.task_context(
             self.add.s(2, 2), on_task_message=cbs,
         )
-        on_task(body, msg)
+        on_task(msg)
         strategy.assert_called_with(
-            msg, body, msg.ack_log_error, msg.reject_log_error, cbs,
+            msg, None, msg.ack_log_error, msg.reject_log_error, cbs,
         )
 
     def test_on_task_message_missing_name(self):
-        x, on_task, body, msg, strategy = self.task_context(self.add.s(2, 2))
-        body.pop('task')
-        on_task(body, msg)
-        x.on_unknown_message.assert_called_with(body, msg)
+        x, on_task, msg, strategy = self.task_context(self.add.s(2, 2))
+        msg.headers.pop('c_type')
+        on_task(msg)
+        x.on_unknown_message.assert_called_with(((2, 2), {}), msg)
 
     def test_on_task_not_registered(self):
-        x, on_task, body, msg, strategy = self.task_context(self.add.s(2, 2))
+        x, on_task, msg, strategy = self.task_context(self.add.s(2, 2))
         exc = strategy.side_effect = KeyError(self.add.name)
-        on_task(body, msg)
-        x.on_unknown_task.assert_called_with(body, msg, exc)
+        on_task(msg)
+        x.on_unknown_task.assert_called_with(None, msg, exc)
 
     def test_on_task_InvalidTaskError(self):
-        x, on_task, body, msg, strategy = self.task_context(self.add.s(2, 2))
+        x, on_task, msg, strategy = self.task_context(self.add.s(2, 2))
         exc = strategy.side_effect = InvalidTaskError()
-        on_task(body, msg)
-        x.on_invalid_task.assert_called_with(body, msg, exc)
+        on_task(msg)
+        x.on_invalid_task.assert_called_with(None, msg, exc)
 
     def test_should_terminate(self):
         x = X(self.app)

+ 149 - 251
celery/tests/worker/test_request.py

@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 from __future__ import absolute_import, unicode_literals
 
-import anyjson
+import numbers
 import os
 import signal
 import socket
@@ -10,7 +10,6 @@ import sys
 from datetime import datetime, timedelta
 
 from billiard.einfo import ExceptionInfo
-from kombu.transport.base import Message
 from kombu.utils.encoding import from_utf8, default_encode
 
 from celery import states
@@ -27,12 +26,13 @@ from celery.concurrency.base import BasePool
 from celery.exceptions import (
     Ignore,
     InvalidTaskError,
+    Reject,
     Retry,
     TaskRevokedError,
     Terminated,
     WorkerLostError,
 )
-from celery.five import keys, monotonic
+from celery.five import monotonic
 from celery.signals import task_revoked
 from celery.utils import uuid
 from celery.worker import job as module
@@ -44,8 +44,9 @@ from celery.tests.case import (
     Case,
     Mock,
     SkipTest,
+    TaskMessage,
     assert_signal_called,
-    body_from_sig,
+    task_message_from_sig,
     patch,
 )
 
@@ -85,7 +86,7 @@ def jail(app, task_id, name, args, kwargs):
     task.__trace__ = None  # rebuild
     return trace_task(
         task, task_id, args, kwargs, request=request, eager=False, app=app,
-    )
+    ).retval
 
 
 class test_default_encode(AppCase):
@@ -138,7 +139,7 @@ class test_trace_task(AppCase):
             raise KeyError(i)
         self.mytask_raising = mytask_raising
 
-    @patch('celery.app.trace._logger')
+    @patch('celery.app.trace.logger')
     def test_process_cleanup_fails(self, _logger):
         self.mytask.backend = Mock()
         self.mytask.backend.process_cleanup = Mock(side_effect=KeyError())
@@ -227,9 +228,10 @@ class test_Request(AppCase):
 
     def get_request(self, sig, Request=Request, **kwargs):
         return Request(
-            body_from_sig(self.app, sig),
-            on_ack=Mock(),
-            eventer=Mock(),
+            task_message_from_sig(self.app, sig),
+            on_ack=Mock(name='on_ack'),
+            on_reject=Mock(name='on_reject'),
+            eventer=Mock(name='eventer'),
             app=self.app,
             connection_errors=(socket.error, ),
             task=sig.type,
@@ -273,7 +275,7 @@ class test_Request(AppCase):
             uuid=req.id, terminated=True, signum='9', expired=False,
         )
 
-    def test_log_error_propagates_MemoryError(self):
+    def test_on_failure_propagates_MemoryError(self):
         einfo = None
         try:
             raise MemoryError()
@@ -282,9 +284,9 @@ class test_Request(AppCase):
         self.assertIsNotNone(einfo)
         req = self.get_request(self.add.s(2, 2))
         with self.assertRaises(MemoryError):
-            req._log_error(einfo)
+            req.on_failure(einfo)
 
-    def test_log_error_when_Ignore(self):
+    def test_on_failure_Ignore_acknowledges(self):
         einfo = None
         try:
             raise Ignore()
@@ -292,48 +294,55 @@ class test_Request(AppCase):
             einfo = ExceptionInfo(internal=True)
         self.assertIsNotNone(einfo)
         req = self.get_request(self.add.s(2, 2))
-        req._log_error(einfo)
+        req.on_failure(einfo)
         req.on_ack.assert_called_with(req_logger, req.connection_errors)
 
+    def test_on_failure_Reject_rejects(self):
+        einfo = None
+        try:
+            raise Reject()
+        except Reject:
+            einfo = ExceptionInfo(internal=True)
+        self.assertIsNotNone(einfo)
+        req = self.get_request(self.add.s(2, 2))
+        req.on_failure(einfo)
+        req.on_reject.assert_called_with(
+            req_logger, req.connection_errors, False,
+        )
+
+    def test_on_failure_Reject_rejects_with_requeue(self):
+        einfo = None
+        try:
+            raise Reject(requeue=True)
+        except Reject:
+            einfo = ExceptionInfo(internal=True)
+        self.assertIsNotNone(einfo)
+        req = self.get_request(self.add.s(2, 2))
+        req.on_failure(einfo)
+        req.on_reject.assert_called_with(
+            req_logger, req.connection_errors, True,
+        )
+
     def test_tzlocal_is_cached(self):
         req = self.get_request(self.add.s(2, 2))
         req._tzlocal = 'foo'
         self.assertEqual(req.tzlocal, 'foo')
 
-    def test_execute_magic_kwargs(self):
-        task = self.add.s(2, 2)
-        task.freeze()
-        req = self.get_request(task)
-        self.add.accept_magic_kwargs = True
-        pool = Mock()
-        req.execute_using_pool(pool)
-        self.assertTrue(pool.apply_async.called)
-        args = pool.apply_async.call_args[1]['args']
-        self.assertEqual(args[0], task.task)
-        self.assertEqual(args[1], task.id)
-        self.assertEqual(args[2], task.args)
-        kwargs = args[3]
-        self.assertEqual(kwargs.get('task_name'), task.task)
-
-    def xRequest(self, body=None, **kwargs):
-        body = dict({'task': self.mytask.name,
-                     'id': uuid(),
-                     'args': [1],
-                     'kwargs': {'f': 'x'}}, **body or {})
-        return Request(body, app=self.app, **kwargs)
+    def xRequest(self, name=None, id=None, args=None, kwargs=None,
+                 on_ack=None, on_reject=None, **head):
+        args = [1] if args is None else args
+        kwargs = {'f': 'x'} if kwargs is None else kwargs
+        on_ack = on_ack or Mock(name='on_ack')
+        on_reject = on_reject or Mock(name='on_reject')
+        message = TaskMessage(
+            name or self.mytask.name, id, args=args, kwargs=kwargs, **head
+        )
+        return Request(message, app=self.app,
+                       on_ack=on_ack, on_reject=on_reject)
 
     def test_task_wrapper_repr(self):
         self.assertTrue(repr(self.xRequest()))
 
-    @patch('celery.worker.job.kwdict')
-    def test_kwdict(self, kwdict):
-        prev, module.NEEDS_KWDICT = module.NEEDS_KWDICT, True
-        try:
-            self.xRequest()
-            self.assertTrue(kwdict.called)
-        finally:
-            module.NEEDS_KWDICT = prev
-
     def test_sets_store_errors(self):
         self.mytask.ignore_result = True
         job = self.xRequest()
@@ -350,12 +359,7 @@ class test_Request(AppCase):
         self.assertIn('task-frobulated', job.eventer.sent)
 
     def test_on_retry(self):
-        job = Request({
-            'task': self.mytask.name,
-            'id': uuid(),
-            'args': [1],
-            'kwargs': {'f': 'x'},
-        }, app=self.app)
+        job = self.get_request(self.mytask.s(1, f='x'))
         job.eventer = MockEventDispatcher()
         try:
             raise Retry('foo', KeyError('moofoobar'))
@@ -372,12 +376,7 @@ class test_Request(AppCase):
             job.on_failure(einfo)
 
     def test_compat_properties(self):
-        job = Request({
-            'task': self.mytask.name,
-            'id': uuid(),
-            'args': [1],
-            'kwargs': {'f': 'x'},
-        }, app=self.app)
+        job = self.xRequest()
         self.assertEqual(job.task_id, job.id)
         self.assertEqual(job.task_name, job.name)
         job.task_id = 'ID'
@@ -388,12 +387,7 @@ class test_Request(AppCase):
     def test_terminate__task_started(self):
         pool = Mock()
         signum = signal.SIGTERM
-        job = Request({
-            'task': self.mytask.name,
-            'id': uuid(),
-            'args': [1],
-            'kwrgs': {'f': 'x'},
-        }, app=self.app)
+        job = self.get_request(self.mytask.s(1, f='x'))
         with assert_signal_called(
                 task_revoked, sender=job.task, request=job,
                 terminated=True, expired=False, signum=signum):
@@ -404,12 +398,7 @@ class test_Request(AppCase):
 
     def test_terminate__task_reserved(self):
         pool = Mock()
-        job = Request({
-            'task': self.mytask.name,
-            'id': uuid(),
-            'args': [1],
-            'kwargs': {'f': 'x'},
-        }, app=self.app)
+        job = self.get_request(self.mytask.s(1, f='x'))
         job.time_start = None
         job.terminate(pool, signal='TERM')
         self.assertFalse(pool.terminate_job.called)
@@ -417,13 +406,9 @@ class test_Request(AppCase):
         job.terminate(pool, signal='TERM')
 
     def test_revoked_expires_expired(self):
-        job = Request({
-            'task': self.mytask.name,
-            'id': uuid(),
-            'args': [1],
-            'kwargs': {'f': 'x'},
-            'expires': datetime.utcnow() - timedelta(days=1),
-        }, app=self.app)
+        job = self.get_request(self.mytask.s(1, f='x').set(
+            expires=datetime.utcnow() - timedelta(days=1)
+        ))
         with assert_signal_called(
                 task_revoked, sender=job.task, request=job,
                 terminated=False, expired=True, signum=None):
@@ -435,9 +420,9 @@ class test_Request(AppCase):
             )
 
     def test_revoked_expires_not_expired(self):
-        job = self.xRequest({
-            'expires': datetime.utcnow() + timedelta(days=1),
-        })
+        job = self.xRequest(
+            expires=datetime.utcnow() + timedelta(days=1),
+        )
         job.revoked()
         self.assertNotIn(job.id, revoked)
         self.assertNotEqual(
@@ -447,47 +432,15 @@ class test_Request(AppCase):
 
     def test_revoked_expires_ignore_result(self):
         self.mytask.ignore_result = True
-        job = self.xRequest({
-            'expires': datetime.utcnow() - timedelta(days=1),
-        })
+        job = self.xRequest(
+            expires=datetime.utcnow() - timedelta(days=1),
+        )
         job.revoked()
         self.assertIn(job.id, revoked)
         self.assertNotEqual(
             self.mytask.backend.get_status(job.id), states.REVOKED,
         )
 
-    def test_send_email(self):
-        app = self.app
-        mail_sent = [False]
-
-        def mock_mail_admins(*args, **kwargs):
-            mail_sent[0] = True
-
-        def get_ei():
-            try:
-                raise KeyError('moofoobar')
-            except:
-                return ExceptionInfo()
-
-        app.mail_admins = mock_mail_admins
-        self.mytask.send_error_emails = True
-        job = self.xRequest()
-        einfo = get_ei()
-        job.on_failure(einfo)
-        self.assertTrue(mail_sent[0])
-
-        einfo = get_ei()
-        mail_sent[0] = False
-        self.mytask.send_error_emails = False
-        job.on_failure(einfo)
-        self.assertFalse(mail_sent[0])
-
-        einfo = get_ei()
-        mail_sent[0] = False
-        self.mytask.send_error_emails = True
-        job.on_failure(einfo)
-        self.assertTrue(mail_sent[0])
-
     def test_already_revoked(self):
         job = self.xRequest()
         job._already_revoked = True
@@ -510,10 +463,10 @@ class test_Request(AppCase):
 
     def test_execute_acks_late(self):
         self.mytask_raising.acks_late = True
-        job = self.xRequest({
-            'task': self.mytask_raising.name,
-            'kwargs': {},
-        })
+        job = self.xRequest(
+            name=self.mytask_raising.name,
+            kwargs={},
+        )
         job.execute()
         self.assertTrue(job.acknowledged)
         job.execute()
@@ -555,10 +508,10 @@ class test_Request(AppCase):
     def test_on_success_acks_early(self):
         job = self.xRequest()
         job.time_start = 1
-        job.on_success(42)
+        job.on_success((0, 42, 0.001))
         prev, module._does_info = module._does_info, False
         try:
-            job.on_success(42)
+            job.on_success((0, 42, 0.001))
             self.assertFalse(job.acknowledged)
         finally:
             module._does_info = prev
@@ -570,7 +523,7 @@ class test_Request(AppCase):
             try:
                 raise SystemExit()
             except SystemExit:
-                job.on_success(ExceptionInfo())
+                job.on_success((1, ExceptionInfo(), 0.01))
             else:
                 assert False
 
@@ -579,7 +532,7 @@ class test_Request(AppCase):
         job.time_start = 1
         job.eventer = Mock()
         job.eventer.send = Mock()
-        job.on_success(42)
+        job.on_success((0, 42, 0.001))
         self.assertTrue(job.eventer.send.called)
 
     def test_on_success_when_failure(self):
@@ -589,14 +542,14 @@ class test_Request(AppCase):
         try:
             raise KeyError('foo')
         except Exception:
-            job.on_success(ExceptionInfo())
+            job.on_success((1, ExceptionInfo(), 0.001))
             self.assertTrue(job.on_failure.called)
 
     def test_on_success_acks_late(self):
         job = self.xRequest()
         job.time_start = 1
         self.mytask.acks_late = True
-        job.on_success(42)
+        job.on_success((0, 42, 0.001))
         self.assertTrue(job.acknowledged)
 
     def test_on_failure_WorkerLostError(self):
@@ -634,9 +587,10 @@ class test_Request(AppCase):
             self.assertTrue(job.acknowledged)
 
     def test_from_message_invalid_kwargs(self):
-        body = dict(task=self.mytask.name, id=1, args=(), kwargs='foo')
+        m = TaskMessage(self.mytask.name, args=(), kwargs='foo')
+        req = Request(m, app=self.app)
         with self.assertRaises(InvalidTaskError):
-            Request(body, message=None, app=self.app)
+            raise req.execute().exception
 
     @patch('celery.worker.job.error')
     @patch('celery.worker.job.warn')
@@ -662,37 +616,60 @@ class test_Request(AppCase):
         from celery.app import trace
         setup_worker_optimizations(self.app)
         self.assertIs(trace.trace_task_ret, trace._fast_trace_task)
+        tid = uuid()
+        message = TaskMessage(self.mytask.name, tid, args=[4])
         try:
             self.mytask.__trace__ = build_tracer(
                 self.mytask.name, self.mytask, self.app.loader, 'test',
                 app=self.app,
             )
-            res = trace.trace_task_ret(self.mytask.name, uuid(), [4], {})
-            self.assertEqual(res, 4 ** 4)
+            failed, res, runtime = trace.trace_task_ret(
+                self.mytask.name, tid, message.headers, message.body,
+                message.content_type, message.content_encoding)
+            self.assertFalse(failed)
+            self.assertEqual(res, repr(4 ** 4))
+            self.assertTrue(runtime)
+            self.assertIsInstance(runtime, numbers.Real)
         finally:
             reset_worker_optimizations()
             self.assertIs(trace.trace_task_ret, trace._trace_task_ret)
         delattr(self.mytask, '__trace__')
-        res = trace.trace_task_ret(
-            self.mytask.name, uuid(), [4], {}, app=self.app,
+        failed, res, runtime = trace.trace_task_ret(
+            self.mytask.name, tid, message.headers, message.body,
+            message.content_type, message.content_encoding, app=self.app,
         )
-        self.assertEqual(res, 4 ** 4)
+        self.assertFalse(failed)
+        self.assertEqual(res, repr(4 ** 4))
+        self.assertTrue(runtime)
+        self.assertIsInstance(runtime, numbers.Real)
 
     def test_trace_task_ret(self):
         self.mytask.__trace__ = build_tracer(
             self.mytask.name, self.mytask, self.app.loader, 'test',
             app=self.app,
         )
-        res = _trace_task_ret(self.mytask.name, uuid(), [4], {}, app=self.app)
-        self.assertEqual(res, 4 ** 4)
+        tid = uuid()
+        message = TaskMessage(self.mytask.name, tid, args=[4])
+        _, R, _ = _trace_task_ret(
+            self.mytask.name, tid, message.headers,
+            message.body, message.content_type,
+            message.content_encoding, app=self.app,
+        )
+        self.assertEqual(R, repr(4 ** 4))
 
     def test_trace_task_ret__no_trace(self):
         try:
             delattr(self.mytask, '__trace__')
         except AttributeError:
             pass
-        res = _trace_task_ret(self.mytask.name, uuid(), [4], {}, app=self.app)
-        self.assertEqual(res, 4 ** 4)
+        tid = uuid()
+        message = TaskMessage(self.mytask.name, tid, args=[4])
+        _, R, _ = _trace_task_ret(
+            self.mytask.name, tid, message.headers,
+            message.body, message.content_type,
+            message.content_encoding, app=self.app,
+        )
+        self.assertEqual(R, repr(4 ** 4))
 
     def test_trace_catches_exception(self):
 
@@ -705,7 +682,7 @@ class test_Request(AppCase):
 
         with self.assertWarnsRegex(RuntimeWarning,
                                    r'Exception raised outside'):
-            res = trace_task(raising, uuid(), [], {}, app=self.app)
+            res = trace_task(raising, uuid(), [], {}, app=self.app)[0]
             self.assertIsInstance(res, ExceptionInfo)
 
     def test_worker_task_trace_handle_retry(self):
@@ -749,71 +726,39 @@ class test_Request(AppCase):
         finally:
             self.mytask.pop_request()
 
-    def test_task_wrapper_mail_attrs(self):
-        job = self.xRequest({'args': [], 'kwargs': {}})
-        x = job.success_msg % {
-            'name': job.name,
-            'id': job.id,
-            'return_value': 10,
-            'runtime': 0.3641,
-        }
-        self.assertTrue(x)
-        x = job.error_msg % {
-            'name': job.name,
-            'id': job.id,
-            'exc': 'FOOBARBAZ',
-            'description': 'raised unexpected',
-            'traceback': 'foobarbaz',
-        }
-        self.assertTrue(x)
-
     def test_from_message(self):
         us = 'æØåveéðƒeæ'
-        body = {'task': self.mytask.name, 'id': uuid(),
-                'args': [2], 'kwargs': {us: 'bar'}}
-        m = Message(None, body=anyjson.dumps(body), backend='foo',
-                    content_type='application/json',
-                    content_encoding='utf-8')
-        job = Request(m.decode(), message=m, app=self.app)
+        tid = uuid()
+        m = TaskMessage(self.mytask.name, tid, args=[2], kwargs={us: 'bar'})
+        job = Request(m, app=self.app)
         self.assertIsInstance(job, Request)
-        self.assertEqual(job.name, body['task'])
-        self.assertEqual(job.id, body['id'])
-        self.assertEqual(job.args, body['args'])
-        us = from_utf8(us)
-        if sys.version_info < (2, 6):
-            self.assertEqual(next(keys(job.kwargs)), us)
-            self.assertIsInstance(next(keys(job.kwargs)), str)
+        self.assertEqual(job.name, self.mytask.name)
+        self.assertEqual(job.id, tid)
+        self.assertIs(job.message, m)
 
     def test_from_message_empty_args(self):
-        body = {'task': self.mytask.name, 'id': uuid()}
-        m = Message(None, body=anyjson.dumps(body), backend='foo',
-                    content_type='application/json',
-                    content_encoding='utf-8')
-        job = Request(m.decode(), message=m, app=self.app)
+        tid = uuid()
+        m = TaskMessage(self.mytask.name, tid, args=[], kwargs={})
+        job = Request(m, app=self.app)
         self.assertIsInstance(job, Request)
-        self.assertEqual(job.args, [])
-        self.assertEqual(job.kwargs, {})
 
     def test_from_message_missing_required_fields(self):
-        body = {}
-        m = Message(None, body=anyjson.dumps(body), backend='foo',
-                    content_type='application/json',
-                    content_encoding='utf-8')
+        m = TaskMessage(self.mytask.name)
+        m.headers.clear()
         with self.assertRaises(KeyError):
-            Request(m.decode(), message=m, app=self.app)
+            Request(m, app=self.app)
 
     def test_from_message_nonexistant_task(self):
-        body = {'task': 'cu.mytask.doesnotexist', 'id': uuid(),
-                'args': [2], 'kwargs': {'æØåveéðƒeæ': 'bar'}}
-        m = Message(None, body=anyjson.dumps(body), backend='foo',
-                    content_type='application/json',
-                    content_encoding='utf-8')
+        m = TaskMessage(
+            'cu.mytask.doesnotexist',
+            args=[2], kwargs={'æØåveéðƒeæ': 'bar'},
+        )
         with self.assertRaises(KeyError):
-            Request(m.decode(), message=m, app=self.app)
+            Request(m, app=self.app)
 
     def test_execute(self):
         tid = uuid()
-        job = self.xRequest({'id': tid, 'args': [4], 'kwargs': {}})
+        job = self.xRequest(id=tid, args=[4], kwargs={})
         self.assertEqual(job.execute(), 256)
         meta = self.mytask.backend.get_task_meta(tid)
         self.assertEqual(meta['status'], states.SUCCESS)
@@ -826,38 +771,17 @@ class test_Request(AppCase):
             return i ** i
 
         tid = uuid()
-        job = self.xRequest({
-            'task': mytask_no_kwargs.name,
-            'id': tid,
-            'args': [4],
-            'kwargs': {},
-        })
+        job = self.xRequest(
+            name=mytask_no_kwargs.name,
+            id=tid,
+            args=[4],
+            kwargs={},
+        )
         self.assertEqual(job.execute(), 256)
         meta = mytask_no_kwargs.backend.get_task_meta(tid)
         self.assertEqual(meta['result'], 256)
         self.assertEqual(meta['status'], states.SUCCESS)
 
-    def test_execute_success_some_kwargs(self):
-        scratch = {'task_id': None}
-
-        @self.app.task(shared=False, accept_magic_kwargs=True)
-        def mytask_some_kwargs(i, task_id):
-            scratch['task_id'] = task_id
-            return i ** i
-
-        tid = uuid()
-        job = self.xRequest({
-            'task': mytask_some_kwargs.name,
-            'id': tid,
-            'args': [4],
-            'kwargs': {},
-        })
-        self.assertEqual(job.execute(), 256)
-        meta = mytask_some_kwargs.backend.get_task_meta(tid)
-        self.assertEqual(scratch.get('task_id'), tid)
-        self.assertEqual(meta['result'], 256)
-        self.assertEqual(meta['status'], states.SUCCESS)
-
     def test_execute_ack(self):
         scratch = {'ACK': False}
 
@@ -865,7 +789,7 @@ class test_Request(AppCase):
             scratch['ACK'] = True
 
         tid = uuid()
-        job = self.xRequest({'id': tid, 'args': [4]}, on_ack=on_ack)
+        job = self.xRequest(id=tid, args=[4], on_ack=on_ack)
         self.assertEqual(job.execute(), 256)
         meta = self.mytask.backend.get_task_meta(tid)
         self.assertTrue(scratch['ACK'])
@@ -874,12 +798,13 @@ class test_Request(AppCase):
 
     def test_execute_fail(self):
         tid = uuid()
-        job = self.xRequest({
-            'task': self.mytask_raising.name,
-            'id': tid,
-            'args': [4],
-            'kwargs': {},
-        })
+        job = self.xRequest(
+            name=self.mytask_raising.name,
+            id=tid,
+            args=[4],
+            kwargs={},
+        )
+        print(job.execute())
         self.assertIsInstance(job.execute(), ExceptionInfo)
         meta = self.mytask_raising.backend.get_task_meta(tid)
         self.assertEqual(meta['status'], states.FAILURE)
@@ -887,7 +812,7 @@ class test_Request(AppCase):
 
     def test_execute_using_pool(self):
         tid = uuid()
-        job = self.xRequest({'id': tid, 'args': [4]})
+        job = self.xRequest(id=tid, args=[4])
 
         class MockPool(BasePool):
             target = None
@@ -908,48 +833,21 @@ class test_Request(AppCase):
         self.assertTrue(p.target)
         self.assertEqual(p.args[0], self.mytask.name)
         self.assertEqual(p.args[1], tid)
-        self.assertEqual(p.args[2], [4])
-        self.assertIn('f', p.args[3])
-        self.assertIn([4], p.args)
+        self.assertEqual(p.args[3], job.message.body)
 
         job.task.accept_magic_kwargs = False
         job.execute_using_pool(p)
 
-    def test_default_kwargs(self):
-        self.maxDiff = 3000
-        tid = uuid()
-        job = self.xRequest({'id': tid, 'args': [4]})
-        self.assertDictEqual(
-            job.extend_with_default_kwargs(), {
-                'f': 'x',
-                'logfile': None,
-                'loglevel': None,
-                'task_id': job.id,
-                'task_retries': 0,
-                'task_is_eager': False,
-                'delivery_info': {
-                    'exchange': None,
-                    'routing_key': None,
-                    'priority': 0,
-                    'redelivered': False,
-                },
-                'task_name': job.name})
-
-    @patch('celery.worker.job.logger')
-    def _test_on_failure(self, exception, logger):
-        app = self.app
+    def _test_on_failure(self, exception):
         tid = uuid()
-        job = self.xRequest({'id': tid, 'args': [4]})
+        job = self.xRequest(id=tid, args=[4])
+        job.send_event = Mock(name='send_event')
         try:
             raise exception
         except Exception:
             exc_info = ExceptionInfo()
-            app.conf.CELERY_SEND_TASK_ERROR_EMAILS = True
             job.on_failure(exc_info)
-            self.assertTrue(logger.log.called)
-            context = logger.log.call_args[0][2]
-            self.assertEqual(self.mytask.name, context['name'])
-            self.assertIn(tid, context['id'])
+            self.assertTrue(job.send_event.called)
 
     def test_on_failure(self):
         self._test_on_failure(Exception('Inside unit tests'))

+ 5 - 8
celery/tests/worker/test_strategy.py

@@ -8,7 +8,7 @@ from kombu.utils.limits import TokenBucket
 from celery.worker import state
 from celery.utils.timeutils import rate
 
-from celery.tests.case import AppCase, Mock, patch, body_from_sig
+from celery.tests.case import AppCase, Mock, patch, task_message_from_sig
 
 
 class test_default_strategy(AppCase):
@@ -22,17 +22,16 @@ class test_default_strategy(AppCase):
 
     class Context(object):
 
-        def __init__(self, sig, s, reserved, consumer, message, body):
+        def __init__(self, sig, s, reserved, consumer, message):
             self.sig = sig
             self.s = s
             self.reserved = reserved
             self.consumer = consumer
             self.message = message
-            self.body = body
 
         def __call__(self, **kwargs):
             return self.s(
-                self.message, self.body,
+                self.message, None,
                 self.message.ack, self.message.reject, [], **kwargs
             )
 
@@ -76,10 +75,8 @@ class test_default_strategy(AppCase):
         s = sig.type.start_strategy(self.app, consumer, task_reserved=reserved)
         self.assertTrue(s)
 
-        message = Mock()
-        body = body_from_sig(self.app, sig, utc=utc)
-
-        yield self.Context(sig, s, reserved, consumer, message, body)
+        message = task_message_from_sig(self.app, sig, utc=utc)
+        yield self.Context(sig, s, reserved, consumer, message)
 
     def test_when_logging_disabled(self):
         with patch('celery.worker.strategy.logger') as logger:

+ 85 - 53
celery/tests/worker/test_worker.py

@@ -17,7 +17,7 @@ from celery.bootsteps import RUN, CLOSE, StartStopStep
 from celery.concurrency.base import BasePool
 from celery.datastructures import AttributeDict
 from celery.exceptions import (
-    WorkerShutdown, WorkerTerminate, TaskRevokedError,
+    WorkerShutdown, WorkerTerminate, TaskRevokedError, InvalidTaskError,
 )
 from celery.five import Empty, range, Queue as FastQueue
 from celery.utils import uuid
@@ -29,7 +29,9 @@ from celery.utils import worker_direct
 from celery.utils.serialization import pickle
 from celery.utils.timer2 import Timer
 
-from celery.tests.case import AppCase, Mock, SkipTest, patch, restore_logging
+from celery.tests.case import (
+    AppCase, Mock, SkipTest, TaskMessage, patch, restore_logging,
+)
 
 
 def MockStep(step=None):
@@ -123,6 +125,13 @@ def create_message(channel, **data):
     return m
 
 
+def create_task_message(channel, *args, **kwargs):
+    m = TaskMessage(*args, **kwargs)
+    m.channel = channel
+    m.delivery_info = {'consumer_tag': 'mock'}
+    return m
+
+
 class test_Consumer(AppCase):
 
     def setup(self):
@@ -207,13 +216,13 @@ class test_Consumer(AppCase):
         l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         l.blueprint.state = RUN
         l.steps.pop()
-        backend = Mock()
-        m = create_message(backend, unknown={'baz': '!!!'})
+        channel = Mock()
+        m = create_message(channel, unknown={'baz': '!!!'})
         l.event_dispatcher = mock_event_dispatcher()
         l.node = MockNode()
 
         callback = self._get_on_message(l)
-        callback(m.decode(), m)
+        callback(m)
         self.assertTrue(warn.call_count)
 
     @patch('celery.worker.strategy.to_timestamp')
@@ -222,17 +231,18 @@ class test_Consumer(AppCase):
         l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         l.blueprint.state = RUN
         l.steps.pop()
-        m = create_message(Mock(), task=self.foo_task.name,
-                           args=('2, 2'),
-                           kwargs={},
-                           eta=datetime.now().isoformat())
+        m = create_task_message(
+            Mock(), self.foo_task.name,
+            args=('2, 2'), kwargs={},
+            eta=datetime.now().isoformat(),
+        )
         l.event_dispatcher = mock_event_dispatcher()
         l.node = MockNode()
         l.update_strategies()
         l.qos = Mock()
 
         callback = self._get_on_message(l)
-        callback(m.decode(), m)
+        callback(m)
         self.assertTrue(m.acknowledged)
 
     @patch('celery.worker.consumer.error')
@@ -241,13 +251,17 @@ class test_Consumer(AppCase):
         l.blueprint.state = RUN
         l.event_dispatcher = mock_event_dispatcher()
         l.steps.pop()
-        m = create_message(Mock(), task=self.foo_task.name,
-                           args=(1, 2), kwargs='foobarbaz', id=1)
+        m = create_task_message(
+            Mock(), self.foo_task.name,
+            args=(1, 2), kwargs='foobarbaz', id=1)
         l.update_strategies()
         l.event_dispatcher = mock_event_dispatcher()
+        strat = l.strategies[self.foo_task.name] = Mock(name='strategy')
+        strat.side_effect = InvalidTaskError()
 
         callback = self._get_on_message(l)
-        callback(m.decode(), m)
+        callback(m)
+        self.assertTrue(error.called)
         self.assertIn('Received invalid task message', error.call_args[0][0])
 
     @patch('celery.worker.consumer.crit')
@@ -274,18 +288,20 @@ class test_Consumer(AppCase):
 
         with self.assertRaises(WorkerShutdown):
             l.loop(*l.loop_args())
-        self.assertTrue(l.task_consumer.register_callback.called)
-        return l.task_consumer.register_callback.call_args[0][0]
+        self.assertTrue(l.task_consumer.on_message)
+        return l.task_consumer.on_message
 
     def test_receieve_message(self):
         l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
         l.blueprint.state = RUN
         l.event_dispatcher = mock_event_dispatcher()
-        m = create_message(Mock(), task=self.foo_task.name,
-                           args=[2, 4, 8], kwargs={})
+        m = create_task_message(
+            Mock(), self.foo_task.name,
+            args=[2, 4, 8], kwargs={},
+        )
         l.update_strategies()
         callback = self._get_on_message(l)
-        callback(m.decode(), m)
+        callback(m)
 
         in_bucket = self.buffer.get_nowait()
         self.assertIsInstance(in_bucket, Request)
@@ -419,8 +435,8 @@ class test_Consumer(AppCase):
         l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         l.blueprint.state = RUN
         l.steps.pop()
-        m = create_message(
-            Mock(), task=self.foo_task.name,
+        m = create_task_message(
+            Mock(), self.foo_task.name,
             eta=(datetime.now() + timedelta(days=1)).isoformat(),
             args=[2, 4, 8], kwargs={},
         )
@@ -432,7 +448,7 @@ class test_Consumer(AppCase):
         l.enabled = False
         l.update_strategies()
         callback = self._get_on_message(l)
-        callback(m.decode(), m)
+        callback(m)
         l.timer.stop()
         l.timer.join(1)
 
@@ -469,27 +485,31 @@ class test_Consumer(AppCase):
         l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         l.blueprint.state = RUN
         l.steps.pop()
-        backend = Mock()
+        channel = Mock()
         id = uuid()
-        t = create_message(backend, task=self.foo_task.name, args=[2, 4, 8],
-                           kwargs={}, id=id)
+        t = create_task_message(
+            channel, self.foo_task.name,
+            args=[2, 4, 8], kwargs={}, id=id,
+        )
         from celery.worker.state import revoked
         revoked.add(id)
 
         callback = self._get_on_message(l)
-        callback(t.decode(), t)
+        callback(t)
         self.assertTrue(self.buffer.empty())
 
     def test_receieve_message_not_registered(self):
         l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         l.blueprint.state = RUN
         l.steps.pop()
-        backend = Mock()
-        m = create_message(backend, task='x.X.31x', args=[2, 4, 8], kwargs={})
+        channel = Mock(name='channel')
+        m = create_task_message(
+            channel, 'x.X.31x', args=[2, 4, 8], kwargs={},
+        )
 
         l.event_dispatcher = mock_event_dispatcher()
         callback = self._get_on_message(l)
-        self.assertFalse(callback(m.decode(), m))
+        self.assertFalse(callback(m))
         with self.assertRaises(Empty):
             self.buffer.get_nowait()
         self.assertTrue(self.timer.empty())
@@ -499,21 +519,25 @@ class test_Consumer(AppCase):
     def test_receieve_message_ack_raises(self, logger, warn):
         l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
         l.blueprint.state = RUN
-        backend = Mock()
-        m = create_message(backend, args=[2, 4, 8], kwargs={})
+        channel = Mock()
+        m = create_task_message(
+            channel, self.foo_task.name,
+            args=[2, 4, 8], kwargs={},
+        )
+        m.headers = None
 
         l.event_dispatcher = mock_event_dispatcher()
+        l.update_strategies()
         l.connection_errors = (socket.error, )
         m.reject = Mock()
         m.reject.side_effect = socket.error('foo')
         callback = self._get_on_message(l)
-        self.assertFalse(callback(m.decode(), m))
+        self.assertFalse(callback(m))
         self.assertTrue(warn.call_count)
         with self.assertRaises(Empty):
             self.buffer.get_nowait()
         self.assertTrue(self.timer.empty())
-        m.reject.assert_called_with(requeue=False)
-        self.assertTrue(logger.critical.call_count)
+        m.reject_log_error.assert_called_with(logger, l.connection_errors)
 
     def test_receive_message_eta(self):
         import sys
@@ -529,10 +553,10 @@ class test_Consumer(AppCase):
         pp('-CREATE MYKOMBUCONSUMER')
         l.steps.pop()
         l.event_dispatcher = mock_event_dispatcher()
-        backend = Mock()
+        channel = Mock(name='channel')
         pp('+ CREATE MESSAGE')
-        m = create_message(
-            backend, task=self.foo_task.name,
+        m = create_task_message(
+            channel, self.foo_task.name,
             args=[2, 4, 8], kwargs={},
             eta=(datetime.now() + timedelta(days=1)).isoformat(),
         )
@@ -556,7 +580,7 @@ class test_Consumer(AppCase):
             callback = self._get_on_message(l)
             pp('- GET ON MESSAGE')
             pp('+ CALLBACK')
-            callback(m.decode(), m)
+            callback(m)
             pp('- CALLBACK')
         finally:
             pp('+ STOP TIMER')
@@ -925,10 +949,12 @@ class test_WorkController(AppCase):
     def test_process_task(self):
         worker = self.worker
         worker.pool = Mock()
-        backend = Mock()
-        m = create_message(backend, task=self.foo_task.name, args=[4, 8, 10],
-                           kwargs={})
-        task = Request(m.decode(), message=m, app=self.app)
+        channel = Mock()
+        m = create_task_message(
+            channel, self.foo_task.name,
+            args=[4, 8, 10], kwargs={},
+        )
+        task = Request(m, app=self.app)
         worker._process_task(task)
         self.assertEqual(worker.pool.apply_async.call_count, 1)
         worker.pool.stop()
@@ -937,10 +963,12 @@ class test_WorkController(AppCase):
         worker = self.worker
         worker.pool = Mock()
         worker.pool.apply_async.side_effect = KeyboardInterrupt('Ctrl+C')
-        backend = Mock()
-        m = create_message(backend, task=self.foo_task.name, args=[4, 8, 10],
-                           kwargs={})
-        task = Request(m.decode(), message=m, app=self.app)
+        channel = Mock()
+        m = create_task_message(
+            channel, self.foo_task.name,
+            args=[4, 8, 10], kwargs={},
+        )
+        task = Request(m, app=self.app)
         worker.steps = []
         worker.blueprint.state = RUN
         with self.assertRaises(KeyboardInterrupt):
@@ -950,10 +978,12 @@ class test_WorkController(AppCase):
         worker = self.worker
         worker.pool = Mock()
         worker.pool.apply_async.side_effect = WorkerTerminate()
-        backend = Mock()
-        m = create_message(backend, task=self.foo_task.name, args=[4, 8, 10],
-                           kwargs={})
-        task = Request(m.decode(), message=m, app=self.app)
+        channel = Mock()
+        m = create_task_message(
+            channel, self.foo_task.name,
+            args=[4, 8, 10], kwargs={},
+        )
+        task = Request(m, app=self.app)
         worker.steps = []
         worker.blueprint.state = RUN
         with self.assertRaises(SystemExit):
@@ -963,10 +993,12 @@ class test_WorkController(AppCase):
         worker = self.worker
         worker.pool = Mock()
         worker.pool.apply_async.side_effect = KeyError('some exception')
-        backend = Mock()
-        m = create_message(backend, task=self.foo_task.name, args=[4, 8, 10],
-                           kwargs={})
-        task = Request(m.decode(), message=m, app=self.app)
+        channel = Mock()
+        m = create_task_message(
+            channel, self.foo_task.name,
+            args=[4, 8, 10], kwargs={},
+        )
+        task = Request(m, app=self.app)
         worker._process_task(task)
         worker.pool.stop()
 

+ 1 - 1
celery/worker/autoscale.py

@@ -81,7 +81,7 @@ class Autoscaler(bgThread):
             self.maybe_scale()
         sleep(1.0)
 
-    def _maybe_scale(self):
+    def _maybe_scale(self, req=None):
         procs = self.processes
         cur = min(self.qty, self.max_concurrency)
         if cur > procs:

+ 14 - 21
celery/worker/consumer.py

@@ -447,37 +447,30 @@ class Consumer(object):
         on_invalid_task = self.on_invalid_task
         callbacks = self.on_task_message
 
-        def on_v1_task_received(body, message):
-            try:
-                name = body['task']
-            except (KeyError, TypeError):
-                return on_unknown_message(body, message)
-
-            try:
-                strategies[name](message, body,
-                                 message.ack_log_error,
-                                 message.reject_log_error,
-                                 callbacks)
-            except KeyError as exc:
-                on_unknown_task(body, message, exc)
-            except InvalidTaskError as exc:
-                on_invalid_task(body, message, exc)
-
         def on_task_received(message):
-            headers = message.headers
+
+            # payload will only be set for v1 protocol, since v2
+            # will defer deserializing the message body to the pool.
+            payload = None
             try:
-                type_ = headers['c_type']
+                type_ = message.headers['c_type']   # protocol v2
+            except TypeError:
+                return on_unknown_message(None, message)
             except KeyError:
-                return on_v1_task_received(message.payload, message)
+                payload = message.payload
+                try:
+                    type_ = payload['task']         # protocol v1
+                except (TypeError, KeyError):
+                    return on_unknown_message(payload, message)
             try:
                 strategies[type_](
                     message, None,
                     message.ack_log_error, message.reject_log_error, callbacks,
                 )
             except KeyError as exc:
-                on_unknown_task(None, message, exc)
+                on_unknown_task(payload, message, exc)
             except InvalidTaskError as exc:
-                on_invalid_task(None, message, exc)
+                on_invalid_task(payload, message, exc)
 
         return on_task_received
 

+ 47 - 49
celery/worker/job.py

@@ -13,7 +13,6 @@ import logging
 import socket
 import sys
 
-from billiard.einfo import ExceptionInfo
 from datetime import datetime
 from weakref import ref
 
@@ -83,7 +82,7 @@ DEFAULT_FIELDS = {
 class RequestV1(object):
     if not IS_PYPY:
         __slots__ = (
-            'app', 'name', 'id', 'root_id', 'parent_id',
+            'app', 'message', 'name', 'id', 'root_id', 'parent_id',
             'on_ack', 'hostname', 'eventer', 'connection_errors', 'task',
             'eta', 'expires', 'request_dict', 'acknowledged', 'on_reject',
             'utc', 'time_start', 'worker_pid', '_already_revoked',
@@ -94,9 +93,10 @@ class RequestV1(object):
 
 class Request(object):
     """A request for task execution."""
+    utc = True
     if not IS_PYPY:  # pragma: no cover
         __slots__ = (
-            'app', 'name', 'id', 'on_ack', 'payload',
+            'app', 'name', 'id', 'on_ack', 'body',
             'hostname', 'eventer', 'connection_errors', 'task', 'eta',
             'expires', 'request_dict', 'acknowledged', 'on_reject',
             'utc', 'time_start', 'worker_pid', 'timeouts',
@@ -111,9 +111,10 @@ class Request(object):
                  task=None, on_reject=noop, **opts):
         headers = message.headers
         self.app = app
+        self.message = message
         name = self.name = headers['c_type']
-        self.id = headers['task_id']
-        self.payload = message.body
+        self.id = headers['id']
+        self.body = message.body
         self.content_type = message.content_type
         self.content_encoding = message.content_encoding
         eta = headers.get('eta')
@@ -185,14 +186,14 @@ class Request(object):
         if self.revoked():
             raise TaskRevokedError(task_id)
 
-        payload = self.payload
+        body = self.body
         timeout, soft_timeout = self.timeouts
         timeout = timeout or task.time_limit
         soft_timeout = soft_timeout or task.soft_time_limit
         result = pool.apply_async(
             trace_task_ret,
             args=(self.name, task_id, self.request_dict,
-                  bytes(payload) if isinstance(payload, buffer) else payload,
+                  bytes(body) if isinstance(body, buffer) else body,
                   self.content_type, self.content_encoding),
             kwargs={'hostname': self.hostname, 'is_eager': False},
             accept_callback=self.on_accepted,
@@ -221,14 +222,14 @@ class Request(object):
         if not self.task.acks_late:
             self.acknowledge()
 
-        kwargs = self.kwargs
         request = self.request_dict
+        args, kwargs = self.message.payload
         request.update({'loglevel': loglevel, 'logfile': logfile,
                         'hostname': self.hostname, 'is_eager': False,
-                        'delivery_info': self.delivery_info})
-        retval = trace_task(self.task, self.id, self.args, kwargs, request,
+                        'args': args, 'kwargs': kwargs})
+        retval = trace_task(self.task, self.id, args, kwargs, request,
                             hostname=self.hostname, loader=self.app.loader,
-                            app=self.app)
+                            app=self.app)[0]
         self.acknowledge()
         return retval
 
@@ -313,22 +314,21 @@ class Request(object):
         if self.task.acks_late:
             self.acknowledge()
 
-    def on_success(self, ret_value, **kwargs):
+    def on_success(self, failed__retval__runtime, **kwargs):
         """Handler called if the task was successfully processed."""
-        if isinstance(ret_value, ExceptionInfo):
-            if isinstance(ret_value.exception, (
-                    SystemExit, KeyboardInterrupt)):
-                raise ret_value.exception
-            return self.on_failure(ret_value)
+        failed, retval, runtime = failed__retval__runtime
+        if failed:
+            if isinstance(retval.exception, (SystemExit, KeyboardInterrupt)):
+                raise retval.exception
+            return self.on_failure(retval)
         task_ready(self)
 
         if self.task.acks_late:
             self.acknowledge()
 
         if self.eventer and self.eventer.enabled:
-            result, runtime = ret_value
             self.send_event(
-                'task-succeeded', result=ret_value, runtime=runtime,
+                'task-succeeded', result=retval, runtime=runtime,
             )
 
     def on_retry(self, exc_info):
@@ -340,38 +340,36 @@ class Request(object):
                         exception=safe_repr(exc_info.exception.exc),
                         traceback=safe_str(exc_info.traceback))
 
-    def on_failure(self, exc_info):
+    def on_failure(self, exc_info, send_failed_event=True):
         """Handler called if the task raised an exception."""
         task_ready(self)
-        send_failed_event = True
-
-        if exc_info.internal:
-            if isinstance(exc_info.exception, MemoryError):
-                raise MemoryError('Process got: %s' % (exc_info.exception, ))
-            elif isinstance(exc_info.exception, Reject):
-                self.reject(requeue=exc_info.exception.requeue)
-            elif isinstance(exc_info.exception, Ignore):
-                self.acknowledge()
-        else:
-            exc = exc_info.exception
-
-            if isinstance(exc, Retry):
-                return self.on_retry(exc_info)
-
-            # These are special cases where the process would not have had
-            # time to write the result.
-            if self.store_errors:
-                if isinstance(exc, WorkerLostError):
-                    self.task.backend.mark_as_failure(
-                        self.id, exc, request=self,
-                    )
-                elif isinstance(exc, Terminated):
-                    self._announce_revoked(
-                        'terminated', True, string(exc), False)
-                    send_failed_event = False  # already sent revoked event
-            # (acks_late) acknowledge after result stored.
-            if self.task.acks_late:
-                self.acknowledge()
+
+        if isinstance(exc_info.exception, MemoryError):
+            raise MemoryError('Process got: %s' % (exc_info.exception, ))
+        elif isinstance(exc_info.exception, Reject):
+            return self.reject(requeue=exc_info.exception.requeue)
+        elif isinstance(exc_info.exception, Ignore):
+            return self.acknowledge()
+
+        exc = exc_info.exception
+
+        if isinstance(exc, Retry):
+            return self.on_retry(exc_info)
+
+        # These are special cases where the process would not have had
+        # time to write the result.
+        if self.store_errors:
+            if isinstance(exc, WorkerLostError):
+                self.task.backend.mark_as_failure(
+                    self.id, exc, request=self,
+                )
+            elif isinstance(exc, Terminated):
+                self._announce_revoked(
+                    'terminated', True, string(exc), False)
+                send_failed_event = False  # already sent revoked event
+        # (acks_late) acknowledge after result stored.
+        if self.task.acks_late:
+            self.acknowledge()
 
         if send_failed_event:
             self.send_event(

+ 1 - 1
celery/worker/strategy.py

@@ -89,7 +89,7 @@ def default(task, app, consumer,
                     return limit_task(req, bucket, 1)
             task_reserved(req)
             if callbacks:
-                [callback() for callback in callbacks]
+                [callback(req) for callback in callbacks]
             handle(req)
 
     return task_message_handler

+ 2 - 2
docs/internals/protov2.rst

@@ -102,8 +102,8 @@ Definition
     }
     headers = {
         'lang': (string)'py'
-        'c_type': (string)task,
-        'task_id': (uuid)task_id,
+        'task': (string)task,
+        'id': (uuid)task_id,
         'root_id': (uuid)root_id,
         'parent_id': (uuid)parent_id,