Przeglądaj źródła

RPC Backend related improvements:

* reply_to is now used as a property

* correlation_id message property now set by amqp and rpc backends.

* CELERY_RESULT_PERSITENT is now None by default,
  and the actual default is decided by the result backend.

* amqp backend is persistent by default, rpc backend is not.

* rpc backend now uses the anon exchange to bypass routing.
Ask Solem 11 lat temu
rodzic
commit
77f0ee26dd

+ 2 - 1
celery/app/amqp.py

@@ -267,7 +267,6 @@ class TaskProducer(Producer):
             'utc': self.utc,
             'callbacks': callbacks,
             'errbacks': errbacks,
-            'reply_to': reply_to,
             'timelimit': (time_limit, soft_time_limit),
             'taskset': group_id or taskset_id,
             'chord': chord,
@@ -289,6 +288,8 @@ class TaskProducer(Producer):
             compression=compression or self.compression,
             headers=headers,
             retry=retry, retry_policy=_rp,
+            reply_to=reply_to,
+            correlation_id=task_id,
             delivery_mode=delivery_mode, declare=declare,
             **kwargs
         )

+ 1 - 1
celery/app/defaults.py

@@ -140,7 +140,7 @@ NAMESPACES = {
         'RESULT_EXCHANGE': Option('celeryresults'),
         'RESULT_EXCHANGE_TYPE': Option('direct'),
         'RESULT_SERIALIZER': Option('pickle'),
-        'RESULT_PERSISTENT': Option(False, type='bool'),
+        'RESULT_PERSISTENT': Option(None, type='bool'),
         'ROUTES': Option(type='any'),
         'SEND_EVENTS': Option(False, type='bool'),
         'SEND_TASK_ERROR_EMAILS': Option(False, type='bool'),

+ 2 - 0
celery/app/task.py

@@ -75,6 +75,8 @@ class Context(object):
     is_eager = False
     headers = None
     delivery_info = None
+    reply_to = None
+    correlation_id = None
     taskset = None   # compat alias to group
     group = None
     chord = None

+ 19 - 9
celery/backends/amqp.py

@@ -70,12 +70,13 @@ class AMQPBackend(BaseBackend):
         super(AMQPBackend, self).__init__(app, **kwargs)
         conf = self.app.conf
         self._connection = connection
-        self.persistent = (conf.CELERY_RESULT_PERSISTENT if persistent is None
-                           else persistent)
+        self.persistent = self.prepare_persistent(persistent)
+        self.delivery_mode = 2 if self.persistent else 1
         exchange = exchange or conf.CELERY_RESULT_EXCHANGE
         exchange_type = exchange_type or conf.CELERY_RESULT_EXCHANGE_TYPE
-        self.exchange = self._create_exchange(exchange, exchange_type,
-                                              self.persistent)
+        self.exchange = self._create_exchange(
+            exchange, exchange_type, self.delivery_mode,
+        )
         self.serializer = serializer or conf.CELERY_RESULT_SERIALIZER
         self.auto_delete = auto_delete
 
@@ -86,8 +87,7 @@ class AMQPBackend(BaseBackend):
             'x-expires': maybe_s_to_ms(self.expires),
         })
 
-    def _create_exchange(self, name, type='direct', persistent=True):
-        delivery_mode = persistent and 'persistent' or 'transient'
+    def _create_exchange(self, name, type='direct', delivery_mode=2):
         return self.Exchange(name=name,
                              type=type,
                              delivery_mode=delivery_mode,
@@ -95,7 +95,7 @@ class AMQPBackend(BaseBackend):
                              auto_delete=False)
 
     def _create_binding(self, task_id):
-        name = task_id.replace('-', '')
+        name = self.rkey(task_id)
         return self.Queue(name=name,
                           exchange=self.exchange,
                           routing_key=name,
@@ -106,12 +106,20 @@ class AMQPBackend(BaseBackend):
     def revive(self, channel):
         pass
 
-    def _routing_key(self, task_id, request):
+    def rkey(self, task_id):
         return task_id.replace('-', '')
 
+    def destination_for(self, task_id, request):
+        if request:
+            return self.rkey(task_id), request.correlation_id or task_id
+        return self.rkey(task_id), task_id
+
     def _store_result(self, task_id, result, status,
                       traceback=None, request=None, **kwargs):
         """Send task return value and status."""
+        routing_key, correlation_id = self.destination_for(task_id, request)
+        if not routing_key:
+            return
         with self.app.amqp.producer_pool.acquire(block=True) as producer:
             producer.publish(
                 {'task_id': task_id, 'status': status,
@@ -119,10 +127,12 @@ class AMQPBackend(BaseBackend):
                  'traceback': traceback,
                  'children': self.current_task_children(request)},
                 exchange=self.exchange,
-                routing_key=self._routing_key(task_id, request),
+                routing_key=routing_key,
+                correlation_id=correlation_id,
                 serializer=self.serializer,
                 retry=True, retry_policy=self.retry_policy,
                 declare=self.on_reply_declare(task_id),
+                delivery_mode=self.delivery_mode,
             )
         return result
 

+ 9 - 0
celery/backends/base.py

@@ -70,6 +70,9 @@ class BaseBackend(object):
     #: in this case.
     supports_autoexpire = False
 
+    #: Set to true if the backend is peristent by default.
+    persistent = True
+
     def __init__(self, app, serializer=None,
                  max_cached_results=None, accept=None, **kwargs):
         self.app = app
@@ -188,6 +191,12 @@ class BaseBackend(object):
             return type(value)
         return value
 
+    def prepare_persistent(self, enabled=None):
+        if enabled is not None:
+            return enabled
+        p = self.app.conf.CELERY_RESULT_PERSISTENT
+        return self.persistent if p is None else p
+
     def encode_result(self, result, status):
         if status in self.EXCEPTION_STATES and isinstance(result, Exception):
             return self.prepare_exception(result)

+ 19 - 12
celery/backends/rpc.py

@@ -8,8 +8,7 @@
 """
 from __future__ import absolute_import
 
-import kombu
-
+from kombu import Consumer, Exchange
 from kombu.common import maybe_declare
 from kombu.utils import cached_property
 
@@ -20,13 +19,14 @@ __all__ = ['RPCBackend']
 
 
 class RPCBackend(amqp.AMQPBackend):
+    persistent = False
 
-    class Consumer(kombu.Consumer):
+    class Consumer(Consumer):
         auto_declare = False
 
-    def _create_exchange(self, name, type='direct', persistent=False):
-        return self.Exchange('c.rep', type=type, delivery_mode=1,
-                             durable=False, auto_delete=False)
+    def _create_exchange(self, name, type='direct', delivery_mode=2):
+        # uses direct to queue routing (anon exchange).
+        return Exchange(None)
 
     def on_task_call(self, producer, task_id):
         maybe_declare(self.binding(producer.channel), retry=True)
@@ -37,12 +37,19 @@ class RPCBackend(amqp.AMQPBackend):
     def _many_bindings(self, ids):
         return [self.binding]
 
-    def _routing_key(self, task_id, request):
-        if request:
-            return request.reply_to
-        task = current_task._get_current_object()
-        if task is not None:
-            return task.request.reply_to
+    def rkey(self, task_id):
+        return task_id
+
+    def destination_for(self, task_id, request):
+        # Request is a new argument for backends, so must still support
+        # old code that rely on current_task
+        try:
+            request = request or current_task.request
+        except AttributeError:
+            raise RuntimeError(
+                'RPC backend missing task request for {0!r}'.format(task_id),
+            )
+        return request.reply_to, request.correlation_id or task_id
 
     def on_reply_declare(self, task_id):
         pass

+ 14 - 3
celery/tests/backends/test_rpc.py

@@ -22,18 +22,29 @@ class test_RPCBackend(AppCase):
     def test_interface(self):
         self.b.on_reply_declare('task_id')
 
-    def test_current_routing_key(self):
+    def test_destination_for(self):
         req = Mock(name='request')
         req.reply_to = 'reply_to'
-        self.assertEqual(self.b._routing_key('task_id', req), 'reply_to')
+        req.correlation_id = 'corid'
+        self.assertTupleEqual(
+            self.b.destination_for('task_id', req),
+            ('reply_to', 'corid'),
+        )
         task = Mock()
         _task_stack.push(task)
         try:
             task.request.reply_to = 'reply_to'
-            self.assertEqual(self.b._routing_key('task_id', None), 'reply_to')
+            task.request.correlation_id = 'corid'
+            self.assertTupleEqual(
+                self.b.destination_for('task_id', None),
+                ('reply_to', 'corid'),
+            )
         finally:
             _task_stack.pop()
 
+        with self.assertRaises(RuntimeError):
+            self.b.destination_for('task_id', None)
+
     def test_binding(self):
         queue = self.b.binding
         self.assertEqual(queue.name, self.b.oid)

+ 19 - 7
celery/tests/worker/test_control.py

@@ -17,7 +17,7 @@ from celery.worker import WorkController as _WC
 from celery.worker import consumer
 from celery.worker import control
 from celery.worker import state as worker_state
-from celery.worker.job import TaskRequest
+from celery.worker.job import Request
 from celery.worker.state import revoked
 from celery.worker.control import Panel
 from celery.worker.pidbox import Pidbox, gPidbox
@@ -251,7 +251,12 @@ class test_ControlPanel(AppCase):
         self.panel.handle('report')
 
     def test_active(self):
-        r = TaskRequest(self.mytask.name, 'do re mi', (), {}, app=self.app)
+        r = Request({
+            'task': self.mytask.name,
+            'id': 'do re mi',
+            'args': (),
+            'kwargs': {},
+        }, app=self.app)
         worker_state.active_requests.add(r)
         try:
             self.assertTrue(self.panel.handle('dump_active'))
@@ -339,7 +344,12 @@ class test_ControlPanel(AppCase):
         consumer = Consumer(self.app)
         panel = self.create_panel(consumer=consumer)
         self.assertFalse(panel.handle('dump_schedule'))
-        r = TaskRequest(self.mytask.name, 'CAFEBABE', (), {}, app=self.app)
+        r = Request({
+            'task': self.mytask.name,
+            'id': 'CAFEBABE',
+            'args': (),
+            'kwargs': {},
+        }, app=self.app)
         consumer.timer.schedule.enter_at(
             consumer.timer.Entry(lambda x: x, (r, )),
             datetime.now() + timedelta(seconds=10))
@@ -350,10 +360,12 @@ class test_ControlPanel(AppCase):
 
     def test_dump_reserved(self):
         consumer = Consumer(self.app)
-        worker_state.reserved_requests.add(
-            TaskRequest(self.mytask.name, uuid(), args=(2, 2), kwargs={},
-                        app=self.app),
-        )
+        worker_state.reserved_requests.add(Request({
+            'task': self.mytask.name,
+            'id': uuid(),
+            'args': (2, 2),
+            'kwargs': {},
+        }, app=self.app))
         try:
             panel = self.create_panel(consumer=consumer)
             response = panel.handle('dump_reserved', {'safe': True})

+ 107 - 119
celery/tests/worker/test_request.py

@@ -38,7 +38,7 @@ from celery.five import keys, monotonic
 from celery.signals import task_revoked
 from celery.utils import uuid
 from celery.worker import job as module
-from celery.worker.job import Request, TaskRequest, logger as req_logger
+from celery.worker.job import Request, logger as req_logger
 from celery.worker.state import revoked
 
 from celery.tests.case import (
@@ -314,48 +314,47 @@ class test_Request(AppCase):
         kwargs = args[3]
         self.assertEqual(kwargs.get('task_name'), task.task)
 
+    def xRequest(self, body=None, **kwargs):
+        body = dict({'task': self.mytask.name,
+                     'id': uuid(),
+                     'args': [1],
+                     'kwargs': {'f': 'x'}}, **body or {})
+        return Request(body, app=self.app, **kwargs)
+
     def test_task_wrapper_repr(self):
-        job = TaskRequest(
-            self.mytask.name, uuid(), [1], {'f': 'x'}, app=self.app,
-        )
-        self.assertTrue(repr(job))
+        self.assertTrue(repr(self.xRequest()))
 
     @patch('celery.worker.job.kwdict')
     def test_kwdict(self, kwdict):
         prev, module.NEEDS_KWDICT = module.NEEDS_KWDICT, True
         try:
-            TaskRequest(
-                self.mytask.name, uuid(), [1], {'f': 'x'}, app=self.app,
-            )
+            self.xRequest()
             self.assertTrue(kwdict.called)
         finally:
             module.NEEDS_KWDICT = prev
 
     def test_sets_store_errors(self):
         self.mytask.ignore_result = True
-        job = TaskRequest(
-            self.mytask.name, uuid(), [1], {'f': 'x'}, app=self.app,
-        )
+        job = self.xRequest()
         self.assertFalse(job.store_errors)
 
         self.mytask.store_errors_even_if_ignored = True
-        job = TaskRequest(
-            self.mytask.name, uuid(), [1], {'f': 'x'}, app=self.app,
-        )
+        job = self.xRequest()
         self.assertTrue(job.store_errors)
 
     def test_send_event(self):
-        job = TaskRequest(
-            self.mytask.name, uuid(), [1], {'f': 'x'}, app=self.app,
-        )
+        job = self.xRequest()
         job.eventer = MockEventDispatcher()
         job.send_event('task-frobulated')
         self.assertIn('task-frobulated', job.eventer.sent)
 
     def test_on_retry(self):
-        job = TaskRequest(
-            self.mytask.name, uuid(), [1], {'f': 'x'}, app=self.app,
-        )
+        job = Request({
+            'task': self.mytask.name,
+            'id': uuid(),
+            'args': [1],
+            'kwargs': {'f': 'x'},
+        }, app=self.app)
         job.eventer = MockEventDispatcher()
         try:
             raise Retry('foo', KeyError('moofoobar'))
@@ -372,9 +371,12 @@ class test_Request(AppCase):
             job.on_failure(einfo)
 
     def test_compat_properties(self):
-        job = TaskRequest(
-            self.mytask.name, uuid(), [1], {'f': 'x'}, app=self.app,
-        )
+        job = Request({
+            'task': self.mytask.name,
+            'id': uuid(),
+            'args': [1],
+            'kwargs': {'f': 'x'},
+        }, app=self.app)
         self.assertEqual(job.task_id, job.id)
         self.assertEqual(job.task_name, job.name)
         job.task_id = 'ID'
@@ -385,9 +387,12 @@ class test_Request(AppCase):
     def test_terminate__task_started(self):
         pool = Mock()
         signum = signal.SIGKILL
-        job = TaskRequest(
-            self.mytask.name, uuid(), [1], {'f': 'x'}, app=self.app,
-        )
+        job = Request({
+            'task': self.mytask.name,
+            'id': uuid(),
+            'args': [1],
+            'kwrgs': {'f': 'x'},
+        }, app=self.app)
         with assert_signal_called(
                 task_revoked, sender=job.task, request=job,
                 terminated=True, expired=False, signum=signum):
@@ -398,9 +403,12 @@ class test_Request(AppCase):
 
     def test_terminate__task_reserved(self):
         pool = Mock()
-        job = TaskRequest(
-            self.mytask.name, uuid(), [1], {'f': 'x'}, app=self.app,
-        )
+        job = Request({
+            'task': self.mytask.name,
+            'id': uuid(),
+            'args': [1],
+            'kwargs': {'f': 'x'},
+        }, app=self.app)
         job.time_start = None
         job.terminate(pool, signal='KILL')
         self.assertFalse(pool.terminate_job.called)
@@ -408,10 +416,13 @@ class test_Request(AppCase):
         job.terminate(pool, signal='KILL')
 
     def test_revoked_expires_expired(self):
-        job = TaskRequest(
-            self.mytask.name, uuid(), [1], {'f': 'x'}, app=self.app,
-            expires=datetime.utcnow() - timedelta(days=1),
-        )
+        job = Request({
+            'task': self.mytask.name,
+            'id': uuid(),
+            'args': [1],
+            'kwargs': {'f': 'x'},
+            'expires': datetime.utcnow() - timedelta(days=1),
+        }, app=self.app)
         with assert_signal_called(
                 task_revoked, sender=job.task, request=job,
                 terminated=False, expired=True, signum=None):
@@ -423,10 +434,9 @@ class test_Request(AppCase):
             )
 
     def test_revoked_expires_not_expired(self):
-        job = TaskRequest(
-            self.mytask.name, uuid(), [1], {'f': 'x'}, app=self.app,
-            expires=datetime.utcnow() + timedelta(days=1),
-        )
+        job = self.xRequest({
+            'expires': datetime.utcnow() + timedelta(days=1),
+        })
         job.revoked()
         self.assertNotIn(job.id, revoked)
         self.assertNotEqual(
@@ -436,10 +446,9 @@ class test_Request(AppCase):
 
     def test_revoked_expires_ignore_result(self):
         self.mytask.ignore_result = True
-        job = TaskRequest(
-            self.mytask.name, uuid(), [1], {'f': 'x'}, app=self.app,
-            expires=datetime.utcnow() - timedelta(days=1),
-        )
+        job = self.xRequest({
+            'expires': datetime.utcnow() - timedelta(days=1),
+        })
         job.revoked()
         self.assertIn(job.id, revoked)
         self.assertNotEqual(
@@ -461,10 +470,7 @@ class test_Request(AppCase):
 
         app.mail_admins = mock_mail_admins
         self.mytask.send_error_emails = True
-        job = TaskRequest(
-            self.mytask.name, uuid(), [1], {'f': 'x'}, app=self.app,
-        )
-
+        job = self.xRequest()
         einfo = get_ei()
         job.on_failure(einfo)
         self.assertTrue(mail_sent[0])
@@ -482,16 +488,12 @@ class test_Request(AppCase):
         self.assertTrue(mail_sent[0])
 
     def test_already_revoked(self):
-        job = TaskRequest(
-            self.mytask.name, uuid(), [1], {'f': 'x'}, app=self.app,
-        )
+        job = self.xRequest()
         job._already_revoked = True
         self.assertTrue(job.revoked())
 
     def test_revoked(self):
-        job = TaskRequest(
-            self.mytask.name, uuid(), [1], {'f': 'x'}, app=self.app,
-        )
+        job = self.xRequest()
         with assert_signal_called(
                 task_revoked, sender=job.task, request=job,
                 terminated=False, expired=False, signum=None):
@@ -501,31 +503,28 @@ class test_Request(AppCase):
             self.assertTrue(job.acknowledged)
 
     def test_execute_does_not_execute_revoked(self):
-        job = TaskRequest(
-            self.mytask.name, uuid(), [1], {'f': 'x'}, app=self.app,
-        )
+        job = self.xRequest()
         revoked.add(job.id)
         job.execute()
 
     def test_execute_acks_late(self):
         self.mytask_raising.acks_late = True
-        job = TaskRequest(self.mytask_raising.name, uuid(), [1], app=self.app)
+        job = self.xRequest({
+            'task': self.mytask_raising.name,
+            'kwargs': {},
+        })
         job.execute()
         self.assertTrue(job.acknowledged)
         job.execute()
 
     def test_execute_using_pool_does_not_execute_revoked(self):
-        job = TaskRequest(
-            self.mytask.name, uuid(), [1], {'f': 'x'}, app=self.app,
-        )
+        job = self.xRequest()
         revoked.add(job.id)
         with self.assertRaises(TaskRevokedError):
             job.execute_using_pool(None)
 
     def test_on_accepted_acks_early(self):
-        job = TaskRequest(
-            self.mytask.name, uuid(), [1], {'f': 'x'}, app=self.app,
-        )
+        job = self.xRequest()
         job.on_accepted(pid=os.getpid(), time_accepted=monotonic())
         self.assertTrue(job.acknowledged)
         prev, module._does_debug = module._does_debug, False
@@ -535,9 +534,7 @@ class test_Request(AppCase):
             module._does_debug = prev
 
     def test_on_accepted_acks_late(self):
-        job = TaskRequest(
-            self.mytask.name, uuid(), [1], {'f': 'x'}, app=self.app,
-        )
+        job = self.xRequest()
         self.mytask.acks_late = True
         job.on_accepted(pid=os.getpid(), time_accepted=monotonic())
         self.assertFalse(job.acknowledged)
@@ -545,9 +542,7 @@ class test_Request(AppCase):
     def test_on_accepted_terminates(self):
         signum = signal.SIGKILL
         pool = Mock()
-        job = TaskRequest(
-            self.mytask.name, uuid(), [1], {'f': 'x'}, app=self.app,
-        )
+        job = self.xRequest()
         with assert_signal_called(
                 task_revoked, sender=job.task, request=job,
                 terminated=True, expired=False, signum=signum):
@@ -557,9 +552,7 @@ class test_Request(AppCase):
             pool.terminate_job.assert_called_with(314, signum)
 
     def test_on_success_acks_early(self):
-        job = TaskRequest(
-            self.mytask.name, uuid(), [1], {'f': 'x'}, app=self.app,
-        )
+        job = self.xRequest()
         job.time_start = 1
         job.on_success(42)
         prev, module._does_info = module._does_info, False
@@ -570,9 +563,7 @@ class test_Request(AppCase):
             module._does_info = prev
 
     def test_on_success_BaseException(self):
-        job = TaskRequest(
-            self.mytask.name, uuid(), [1], {'f': 'x'}, app=self.app,
-        )
+        job = self.xRequest()
         job.time_start = 1
         with self.assertRaises(SystemExit):
             try:
@@ -583,19 +574,15 @@ class test_Request(AppCase):
                 assert False
 
     def test_on_success_eventer(self):
-        job = TaskRequest(
-            self.mytask.name, uuid(), [1], {'f': 'x'}, app=self.app,
-        )
+        job = self.xRequest()
         job.time_start = 1
         job.eventer = Mock()
-        job.send_event = Mock()
+        job.eventer.send = Mock()
         job.on_success(42)
-        self.assertTrue(job.send_event.called)
+        self.assertTrue(job.eventer.send.called)
 
     def test_on_success_when_failure(self):
-        job = TaskRequest(
-            self.mytask.name, uuid(), [1], {'f': 'x'}, app=self.app,
-        )
+        job = self.xRequest()
         job.time_start = 1
         job.on_failure = Mock()
         try:
@@ -605,9 +592,7 @@ class test_Request(AppCase):
             self.assertTrue(job.on_failure.called)
 
     def test_on_success_acks_late(self):
-        job = TaskRequest(
-            self.mytask.name, uuid(), [1], {'f': 'x'}, app=self.app,
-        )
+        job = self.xRequest()
         job.time_start = 1
         self.mytask.acks_late = True
         job.on_success(42)
@@ -621,9 +606,7 @@ class test_Request(AppCase):
             except WorkerLostError:
                 return ExceptionInfo()
 
-        job = TaskRequest(
-            self.mytask.name, uuid(), [1], {'f': 'x'}, app=self.app,
-        )
+        job = self.xRequest()
         exc_info = get_ei()
         job.on_failure(exc_info)
         self.assertEqual(
@@ -632,18 +615,14 @@ class test_Request(AppCase):
 
         self.mytask.ignore_result = True
         exc_info = get_ei()
-        job = TaskRequest(
-            self.mytask.name, uuid(), [1], {'f': 'x'}, app=self.app,
-        )
+        job = self.xRequest()
         job.on_failure(exc_info)
         self.assertEqual(
             self.mytask.backend.get_status(job.id), states.PENDING,
         )
 
     def test_on_failure_acks_late(self):
-        job = TaskRequest(
-            self.mytask.name, uuid(), [1], {'f': 'x'}, app=self.app,
-        )
+        job = self.xRequest()
         job.time_start = 1
         self.mytask.acks_late = True
         try:
@@ -656,15 +635,13 @@ class test_Request(AppCase):
     def test_from_message_invalid_kwargs(self):
         body = dict(task=self.mytask.name, id=1, args=(), kwargs='foo')
         with self.assertRaises(InvalidTaskError):
-            TaskRequest.from_message(None, body, app=self.app)
+            Request(body, message=None, app=self.app)
 
     @patch('celery.worker.job.error')
     @patch('celery.worker.job.warn')
     def test_on_timeout(self, warn, error):
 
-        job = TaskRequest(
-            self.mytask.name, uuid(), [1], {'f': 'x'}, app=self.app,
-        )
+        job = self.xRequest()
         job.on_timeout(soft=True, timeout=1337)
         self.assertIn('Soft time limit', warn.call_args[0][0])
         job.on_timeout(soft=False, timeout=1337)
@@ -674,9 +651,7 @@ class test_Request(AppCase):
         )
 
         self.mytask.ignore_result = True
-        job = TaskRequest(
-            self.mytask.name, uuid(), [1], {'f': 'x'}, app=self.app,
-        )
+        job = self.xRequest()
         job.on_timeout(soft=True, timeout=1336)
         self.assertEqual(
             self.mytask.backend.get_status(job.id), states.PENDING,
@@ -774,7 +749,7 @@ class test_Request(AppCase):
             self.mytask.pop_request()
 
     def test_task_wrapper_mail_attrs(self):
-        job = TaskRequest(self.mytask.name, uuid(), [], {}, app=self.app)
+        job = self.xRequest({'args': [], 'kwargs': {}})
         x = job.success_msg % {
             'name': job.name,
             'id': job.id,
@@ -797,7 +772,7 @@ class test_Request(AppCase):
         m = Message(None, body=anyjson.dumps(body), backend='foo',
                     content_type='application/json',
                     content_encoding='utf-8')
-        job = TaskRequest.from_message(m, m.decode(), app=self.app)
+        job = Request(m.decode(), message=m, app=self.app)
         self.assertIsInstance(job, Request)
         self.assertEqual(job.name, body['task'])
         self.assertEqual(job.id, body['id'])
@@ -812,7 +787,7 @@ class test_Request(AppCase):
         m = Message(None, body=anyjson.dumps(body), backend='foo',
                     content_type='application/json',
                     content_encoding='utf-8')
-        job = TaskRequest.from_message(m, m.decode(), app=self.app)
+        job = Request(m.decode(), message=m, app=self.app)
         self.assertIsInstance(job, Request)
         self.assertEqual(job.args, [])
         self.assertEqual(job.kwargs, {})
@@ -823,7 +798,7 @@ class test_Request(AppCase):
                     content_type='application/json',
                     content_encoding='utf-8')
         with self.assertRaises(KeyError):
-            TaskRequest.from_message(m, m.decode(), app=self.app)
+            Request(m.decode(), message=m, app=self.app)
 
     def test_from_message_nonexistant_task(self):
         body = {'task': 'cu.mytask.doesnotexist', 'id': uuid(),
@@ -832,15 +807,15 @@ class test_Request(AppCase):
                     content_type='application/json',
                     content_encoding='utf-8')
         with self.assertRaises(KeyError):
-            TaskRequest.from_message(m, m.decode(), app=self.app)
+            Request(m.decode(), message=m, app=self.app)
 
     def test_execute(self):
         tid = uuid()
-        job = TaskRequest(self.mytask.name, tid, [4], {'f': 'x'}, app=self.app)
+        job = self.xRequest({'id': tid, 'args': [4], 'kwargs': {}})
         self.assertEqual(job.execute(), 256)
         meta = self.mytask.backend.get_task_meta(tid)
-        self.assertEqual(meta['result'], 256)
         self.assertEqual(meta['status'], states.SUCCESS)
+        self.assertEqual(meta['result'], 256)
 
     def test_execute_success_no_kwargs(self):
 
@@ -849,7 +824,12 @@ class test_Request(AppCase):
             return i ** i
 
         tid = uuid()
-        job = TaskRequest(mytask_no_kwargs.name, tid, [4], {}, app=self.app)
+        job = self.xRequest({
+            'task': mytask_no_kwargs.name,
+            'id': tid,
+            'args': [4],
+            'kwargs': {},
+        })
         self.assertEqual(job.execute(), 256)
         meta = mytask_no_kwargs.backend.get_task_meta(tid)
         self.assertEqual(meta['result'], 256)
@@ -864,7 +844,12 @@ class test_Request(AppCase):
             return i ** i
 
         tid = uuid()
-        job = TaskRequest(mytask_some_kwargs.name, tid, [4], {}, app=self.app)
+        job = self.xRequest({
+            'task': mytask_some_kwargs.name,
+            'id': tid,
+            'args': [4],
+            'kwargs': {},
+        })
         self.assertEqual(job.execute(), 256)
         meta = mytask_some_kwargs.backend.get_task_meta(tid)
         self.assertEqual(scratch.get('task_id'), tid)
@@ -878,10 +863,7 @@ class test_Request(AppCase):
             scratch['ACK'] = True
 
         tid = uuid()
-        job = TaskRequest(
-            self.mytask.name, tid, [4], {'f': 'x'},
-            on_ack=on_ack, app=self.app,
-        )
+        job = self.xRequest({'id': tid, 'args': [4]}, on_ack=on_ack)
         self.assertEqual(job.execute(), 256)
         meta = self.mytask.backend.get_task_meta(tid)
         self.assertTrue(scratch['ACK'])
@@ -890,7 +872,12 @@ class test_Request(AppCase):
 
     def test_execute_fail(self):
         tid = uuid()
-        job = TaskRequest(self.mytask_raising.name, tid, [4], app=self.app)
+        job = self.xRequest({
+            'task': self.mytask_raising.name,
+            'id': tid,
+            'args': [4],
+            'kwargs': {},
+        })
         self.assertIsInstance(job.execute(), ExceptionInfo)
         meta = self.mytask_raising.backend.get_task_meta(tid)
         self.assertEqual(meta['status'], states.FAILURE)
@@ -898,7 +885,7 @@ class test_Request(AppCase):
 
     def test_execute_using_pool(self):
         tid = uuid()
-        job = TaskRequest(self.mytask.name, tid, [4], {'f': 'x'}, app=self.app)
+        job = self.xRequest({'id': tid, 'args': [4]})
 
         class MockPool(BasePool):
             target = None
@@ -927,8 +914,9 @@ class test_Request(AppCase):
         job.execute_using_pool(p)
 
     def test_default_kwargs(self):
+        self.maxDiff = 3000
         tid = uuid()
-        job = TaskRequest(self.mytask.name, tid, [4], {'f': 'x'}, app=self.app)
+        job = self.xRequest({'id': tid, 'args': [4]})
         self.assertDictEqual(
             job.extend_with_default_kwargs(), {
                 'f': 'x',
@@ -940,8 +928,8 @@ class test_Request(AppCase):
                 'delivery_info': {
                     'exchange': None,
                     'routing_key': None,
-                    'priority': None,
-                    'redelivered': None,
+                    'priority': 0,
+                    'redelivered': False,
                 },
                 'task_name': job.name})
 
@@ -949,7 +937,7 @@ class test_Request(AppCase):
     def _test_on_failure(self, exception, logger):
         app = self.app
         tid = uuid()
-        job = TaskRequest(self.mytask.name, tid, [4], {'f': 'x'}, app=self.app)
+        job = self.xRequest({'id': tid, 'args': [4]})
         try:
             raise exception
         except Exception:

+ 4 - 4
celery/tests/worker/test_worker.py

@@ -878,7 +878,7 @@ class test_WorkController(AppCase):
         backend = Mock()
         m = create_message(backend, task=self.foo_task.name, args=[4, 8, 10],
                            kwargs={})
-        task = Request.from_message(m, m.decode(), app=self.app)
+        task = Request(m.decode(), message=m, app=self.app)
         worker._process_task(task)
         self.assertEqual(worker.pool.apply_async.call_count, 1)
         worker.pool.stop()
@@ -890,7 +890,7 @@ class test_WorkController(AppCase):
         backend = Mock()
         m = create_message(backend, task=self.foo_task.name, args=[4, 8, 10],
                            kwargs={})
-        task = Request.from_message(m, m.decode(), app=self.app)
+        task = Request(m.decode(), message=m, app=self.app)
         worker.steps = []
         worker.blueprint.state = RUN
         with self.assertRaises(KeyboardInterrupt):
@@ -903,7 +903,7 @@ class test_WorkController(AppCase):
         backend = Mock()
         m = create_message(backend, task=self.foo_task.name, args=[4, 8, 10],
                            kwargs={})
-        task = Request.from_message(m, m.decode(), app=self.app)
+        task = Request(m.decode(), message=m, app=self.app)
         worker.steps = []
         worker.blueprint.state = RUN
         with self.assertRaises(SystemExit):
@@ -916,7 +916,7 @@ class test_WorkController(AppCase):
         backend = Mock()
         m = create_message(backend, task=self.foo_task.name, args=[4, 8, 10],
                            kwargs={})
-        task = Request.from_message(m, m.decode(), app=self.app)
+        task = Request(m.decode(), message=m, app=self.app)
         worker._process_task(task)
         worker.pool.stop()
 

+ 43 - 33
celery/worker/job.py

@@ -7,7 +7,7 @@
     which specifies how tasks are executed.
 
 """
-from __future__ import absolute_import
+from __future__ import absolute_import, unicode_literals
 
 import logging
 import socket
@@ -67,17 +67,30 @@ revoked_tasks = state.revoked
 
 NEEDS_KWDICT = sys.version_info <= (2, 6)
 
+#: Use when no message object passed to :class:`Request`.
+DEFAULT_FIELDS = {
+    'headers': None,
+    'reply_to': None,
+    'correlation_id': None,
+    'delivery_info': {
+        'exchange': None,
+        'routing_key': None,
+        'priority': 0,
+        'redelivered': False,
+    },
+}
+
 
 class Request(object):
     """A request for task execution."""
     if not IS_PYPY:  # pragma: no cover
         __slots__ = (
-            'app', 'name', 'id', 'args', 'kwargs', 'on_ack', 'delivery_info',
+            'app', 'name', 'id', 'args', 'kwargs', 'on_ack',
             'hostname', 'eventer', 'connection_errors', 'task', 'eta',
             'expires', 'request_dict', 'acknowledged', 'on_reject',
             'utc', 'time_start', 'worker_pid', '_already_revoked',
-            '_terminate_on_ack', 'headers',
-            '_tzlocal', '__weakref__',
+            '_terminate_on_ack',
+            '_tzlocal', '__weakref__', '__dict__',
         )
 
     #: Format string used to log task success.
@@ -109,8 +122,7 @@ class Request(object):
     def __init__(self, body, on_ack=noop,
                  hostname=None, eventer=None, app=None,
                  connection_errors=None, request_dict=None,
-                 delivery_info=None, headers=None, task=None,
-                 on_reject=noop, **opts):
+                 message=None, task=None, on_reject=noop, **opts):
         self.app = app
         name = self.name = body['task']
         self.id = body['id']
@@ -159,23 +171,28 @@ class Request(object):
         else:
             self.expires = None
 
-        delivery_info = {} if delivery_info is None else delivery_info
-        self.delivery_info = {
-            'exchange': delivery_info.get('exchange'),
-            'routing_key': delivery_info.get('routing_key'),
-            'priority': delivery_info.get('priority'),
-            'redelivered': delivery_info.get('redelivered'),
-        }
-        body['headers'] = headers  # pass application/headers
+        if message:
+            delivery_info = message.delivery_info or {}
+            properties = message.properties or {}
+            body.update({
+                'headers': message.headers,
+                'reply_to': properties.get('reply_to'),
+                'correlation_id': properties.get('correlation_id'),
+                'delivery_info': {
+                    'exchange': delivery_info.get('exchange'),
+                    'routing_key': delivery_info.get('routing_key'),
+                    'priority': delivery_info.get('priority'),
+                    'redelivered': delivery_info.get('redelivered'),
+                }
+
+            })
+        else:
+            body.update(DEFAULT_FIELDS)
         self.request_dict = body
 
-    @classmethod
-    def from_message(cls, message, body, **kwargs):
-        # should be deprecated
-        return Request(
-            body,
-            delivery_info=getattr(message, 'delivery_info', None), **kwargs
-        )
+    @property
+    def delivery_info(self):
+        return self.request_dict['delivery_info']
 
     def extend_with_default_kwargs(self):
         """Extend the tasks keyword arguments with standard task arguments.
@@ -205,7 +222,7 @@ class Request(object):
         return kwargs
 
     def execute_using_pool(self, pool, **kwargs):
-        """Like :meth:`execute`, but using a worker pool.
+        """Used by the worker to send this task to the pool.
 
         :param pool: A :class:`celery.concurrency.base.TaskPool` instance.
 
@@ -538,14 +555,7 @@ class Request(object):
         # used by rpc backend when failures reported by parent process
         return self.request_dict['reply_to']
 
-
-class TaskRequest(Request):
-
-    def __init__(self, name, id, args=(), kwargs={},
-                 eta=None, expires=None, **options):
-        """Compatibility class."""
-
-        super(TaskRequest, self).__init__({
-            'task': name, 'id': id, 'args': args,
-            'kwargs': kwargs, 'eta': eta,
-            'expires': expires}, **options)
+    @property
+    def correlation_id(self):
+        # used similarly to reply_to
+        return self.request_dict['correlation_id']

+ 1 - 2
celery/worker/strategy.py

@@ -47,8 +47,7 @@ def default(task, app, consumer,
                   app=app, hostname=hostname,
                   eventer=eventer, task=task,
                   connection_errors=connection_errors,
-                  delivery_info=message.delivery_info,
-                  headers=message.headers)
+                  message=message)
         if req.revoked():
             return
 

+ 10 - 0
docs/userguide/tasks.rst

@@ -241,6 +241,16 @@ The request defines the following attributes:
 :utc: Set to true the caller has utc enabled (:setting:`CELERY_ENABLE_UTC`).
 
 
+.. versionadded:: 3.1
+
+:headers:  Mapping of message headers (may be :const:`None`).
+
+:reply_to:  Where to send reply to (queue name).
+
+:correlation_id: Usually the same as the task id, often used in amqp
+                 to keep track of what a reply is for.
+
+
 An example task accessing information in the context is:
 
 .. code-block:: python

+ 1 - 0
funtests/stress/stress/templates.py

@@ -49,6 +49,7 @@ class default(object):
     CELERY_DEFAULT_QUEUE = CSTRESS_QUEUE
     CELERY_TASK_SERIALIZER = 'json'
     CELERY_RESULT_SERIALIZER = 'json'
+    CELERY_RESULT_PERSISTENT = False
     CELERY_QUEUES = [
         Queue(CSTRESS_QUEUE,
               exchange=Exchange(CSTRESS_QUEUE),