Sfoglia il codice sorgente

TaskProducer replaced by create_task_message and send_task_message

Ask Solem 11 anni fa
parent
commit
8b7e3f2e9a

+ 188 - 223
celery/app/amqp.py

@@ -10,13 +10,14 @@ from __future__ import absolute_import
 
 
 import numbers
 import numbers
 
 
+from collections import Mapping, namedtuple
 from datetime import timedelta
 from datetime import timedelta
 from weakref import WeakValueDictionary
 from weakref import WeakValueDictionary
 
 
 from kombu import Connection, Consumer, Exchange, Producer, Queue
 from kombu import Connection, Consumer, Exchange, Producer, Queue
 from kombu.common import Broadcast
 from kombu.common import Broadcast
 from kombu.pools import ProducerPool
 from kombu.pools import ProducerPool
-from kombu.utils import cached_property, uuid
+from kombu.utils import cached_property
 from kombu.utils.encoding import safe_repr
 from kombu.utils.encoding import safe_repr
 from kombu.utils.functional import maybe_list
 from kombu.utils.functional import maybe_list
 
 
@@ -25,10 +26,9 @@ from celery.five import items, string_t
 from celery.utils.text import indent as textindent
 from celery.utils.text import indent as textindent
 from celery.utils.timeutils import to_utc
 from celery.utils.timeutils import to_utc
 
 
-from . import app_or_default
 from . import routes as _routes
 from . import routes as _routes
 
 
-__all__ = ['AMQP', 'Queues', 'TaskProducer', 'TaskConsumer']
+__all__ = ['AMQP', 'Queues', 'task_message']
 
 
 #: Human readable queue declaration.
 #: Human readable queue declaration.
 QUEUE_FORMAT = """
 QUEUE_FORMAT = """
@@ -36,6 +36,9 @@ QUEUE_FORMAT = """
 key={0.routing_key}
 key={0.routing_key}
 """
 """
 
 
+task_message = namedtuple('task_message',
+                          ('headers', 'properties', 'body', 'sent_event'))
+
 
 
 class Queues(dict):
 class Queues(dict):
     """Queue name⇒ declaration mapping.
     """Queue name⇒ declaration mapping.
@@ -184,204 +187,14 @@ class Queues(dict):
         return self
         return self
 
 
 
 
-class TaskProducer(Producer):
-    app = None
-    auto_declare = False
-    retry = False
-    retry_policy = None
-    utc = True
-    event_dispatcher = None
-    send_sent_event = False
-
-    def __init__(self, channel=None, exchange=None, *args, **kwargs):
-        self.retry = kwargs.pop('retry', self.retry)
-        self.retry_policy = kwargs.pop('retry_policy',
-                                       self.retry_policy or {})
-        self.send_sent_event = kwargs.pop('send_sent_event',
-                                          self.send_sent_event)
-        exchange = exchange or self.exchange
-        self.queues = self.app.amqp.queues  # shortcut
-        self.default_queue = self.app.amqp.default_queue
-        super(TaskProducer, self).__init__(channel, exchange, *args, **kwargs)
-
-    def publish_task(self, task_name, task_args=None, task_kwargs=None,
-                     countdown=None, eta=None, task_id=None, group_id=None,
-                     taskset_id=None,  # compat alias to group_id
-                     expires=None, exchange=None, exchange_type=None,
-                     event_dispatcher=None, retry=None, retry_policy=None,
-                     queue=None, now=None, retries=0, chord=None,
-                     callbacks=None, errbacks=None, routing_key=None,
-                     serializer=None, delivery_mode=None, compression=None,
-                     reply_to=None, time_limit=None, soft_time_limit=None,
-                     declare=None, headers=None,
-                     send_before_publish=signals.before_task_publish.send,
-                     before_receivers=signals.before_task_publish.receivers,
-                     send_after_publish=signals.after_task_publish.send,
-                     after_receivers=signals.after_task_publish.receivers,
-                     send_task_sent=signals.task_sent.send,  # XXX deprecated
-                     sent_receivers=signals.task_sent.receivers,
-                     **kwargs):
-        """Send task message."""
-        retry = self.retry if retry is None else retry
-        headers = {} if headers is None else headers
-
-        qname = queue
-        if queue is None and exchange is None:
-            queue = self.default_queue
-        if queue is not None:
-            if isinstance(queue, string_t):
-                qname, queue = queue, self.queues[queue]
-            else:
-                qname = queue.name
-            exchange = exchange or queue.exchange.name
-            routing_key = routing_key or queue.routing_key
-        if declare is None and queue and not isinstance(queue, Broadcast):
-            declare = [queue]
-
-        # merge default and custom policy
-        retry = self.retry if retry is None else retry
-        _rp = (dict(self.retry_policy, **retry_policy) if retry_policy
-               else self.retry_policy)
-        task_id = task_id or uuid()
-        task_args = task_args or []
-        task_kwargs = task_kwargs or {}
-        if not isinstance(task_args, (list, tuple)):
-            raise ValueError('task args must be a list or tuple')
-        if not isinstance(task_kwargs, dict):
-            raise ValueError('task kwargs must be a dictionary')
-        if countdown:  # Convert countdown to ETA.
-            now = now or self.app.now()
-            eta = now + timedelta(seconds=countdown)
-            if self.utc:
-                eta = to_utc(eta).astimezone(self.app.timezone)
-        if isinstance(expires, numbers.Real):
-            now = now or self.app.now()
-            expires = now + timedelta(seconds=expires)
-            if self.utc:
-                expires = to_utc(expires).astimezone(self.app.timezone)
-        eta = eta and eta.isoformat()
-        expires = expires and expires.isoformat()
-
-        body = {
-            'task': task_name,
-            'id': task_id,
-            'args': task_args,
-            'kwargs': task_kwargs,
-            'retries': retries or 0,
-            'eta': eta,
-            'expires': expires,
-            'utc': self.utc,
-            'callbacks': callbacks,
-            'errbacks': errbacks,
-            'timelimit': (time_limit, soft_time_limit),
-            'taskset': group_id or taskset_id,
-            'chord': chord,
-        }
-
-        if before_receivers:
-            send_before_publish(
-                sender=task_name, body=body,
-                exchange=exchange,
-                routing_key=routing_key,
-                declare=declare,
-                headers=headers,
-                properties=kwargs,
-                retry_policy=retry_policy,
-            )
-
-        self.publish(
-            body,
-            exchange=exchange, routing_key=routing_key,
-            serializer=serializer or self.serializer,
-            compression=compression or self.compression,
-            headers=headers,
-            retry=retry, retry_policy=_rp,
-            reply_to=reply_to,
-            correlation_id=task_id,
-            delivery_mode=delivery_mode, declare=declare,
-            **kwargs
-        )
-
-        if after_receivers:
-            send_after_publish(sender=task_name, body=body,
-                               exchange=exchange, routing_key=routing_key)
-
-        if sent_receivers:  # XXX deprecated
-            send_task_sent(sender=task_name, task_id=task_id,
-                           task=task_name, args=task_args,
-                           kwargs=task_kwargs, eta=eta,
-                           taskset=group_id or taskset_id)
-        if self.send_sent_event:
-            evd = event_dispatcher or self.event_dispatcher
-            exname = exchange or self.exchange
-            if isinstance(exname, Exchange):
-                exname = exname.name
-            evd.publish(
-                'task-sent',
-                {
-                    'uuid': task_id,
-                    'name': task_name,
-                    'args': safe_repr(task_args),
-                    'kwargs': safe_repr(task_kwargs),
-                    'retries': retries,
-                    'eta': eta,
-                    'expires': expires,
-                    'queue': qname,
-                    'exchange': exname,
-                    'routing_key': routing_key,
-                },
-                self, retry=retry, retry_policy=retry_policy,
-            )
-        return task_id
-    delay_task = publish_task   # XXX Compat
-
-    @cached_property
-    def event_dispatcher(self):
-        # We call Dispatcher.publish with a custom producer
-        # so don't need the dispatcher to be "enabled".
-        return self.app.events.Dispatcher(enabled=False)
-
-
-class TaskPublisher(TaskProducer):
-    """Deprecated version of :class:`TaskProducer`."""
-
-    def __init__(self, channel=None, exchange=None, *args, **kwargs):
-        self.app = app_or_default(kwargs.pop('app', self.app))
-        self.retry = kwargs.pop('retry', self.retry)
-        self.retry_policy = kwargs.pop('retry_policy',
-                                       self.retry_policy or {})
-        exchange = exchange or self.exchange
-        if not isinstance(exchange, Exchange):
-            exchange = Exchange(exchange,
-                                kwargs.pop('exchange_type', 'direct'))
-        self.queues = self.app.amqp.queues  # shortcut
-        super(TaskPublisher, self).__init__(channel, exchange, *args, **kwargs)
-
-
-class TaskConsumer(Consumer):
-    app = None
-
-    def __init__(self, channel, queues=None, app=None, accept=None, **kw):
-        self.app = app or self.app
-        if accept is None:
-            accept = self.app.conf.CELERY_ACCEPT_CONTENT
-        super(TaskConsumer, self).__init__(
-            channel,
-            queues or list(self.app.amqp.queues.consume_from.values()),
-            accept=accept,
-            **kw
-        )
-
-
 class AMQP(object):
 class AMQP(object):
     Connection = Connection
     Connection = Connection
     Consumer = Consumer
     Consumer = Consumer
+    Producer = Producer
 
 
     #: compat alias to Connection
     #: compat alias to Connection
     BrokerConnection = Connection
     BrokerConnection = Connection
 
 
-    producer_cls = TaskProducer
-    consumer_cls = TaskConsumer
     queues_cls = Queues
     queues_cls = Queues
 
 
     #: Cached and prepared routing table.
     #: Cached and prepared routing table.
@@ -400,6 +213,18 @@ class AMQP(object):
     def __init__(self, app):
     def __init__(self, app):
         self.app = app
         self.app = app
 
 
+    @cached_property
+    def _task_retry(self):
+        return self.app.conf.CELERY_TASK_PUBLISH_RETRY
+
+    @cached_property
+    def _task_retry_policy(self):
+        return self.app.conf.CELERY_TASK_PUBLISH_RETRY_POLICY
+
+    @cached_property
+    def _task_sent_event(self):
+        return self.app.conf.CELERY_SEND_TASK_SENT_EVENT
+
     def flush_routes(self):
     def flush_routes(self):
         self._rtable = _routes.prepare(self.app.conf.CELERY_ROUTES)
         self._rtable = _routes.prepare(self.app.conf.CELERY_ROUTES)
 
 
@@ -429,35 +254,14 @@ class AMQP(object):
                               self.app.either('CELERY_CREATE_MISSING_QUEUES',
                               self.app.either('CELERY_CREATE_MISSING_QUEUES',
                                               create_missing), app=self.app)
                                               create_missing), app=self.app)
 
 
-    @cached_property
-    def TaskConsumer(self):
-        """Return consumer configured to consume from the queues
-        we are configured for (``app.amqp.queues.consume_from``)."""
-        return self.app.subclass_with_self(self.consumer_cls,
-                                           reverse='amqp.TaskConsumer')
-    get_task_consumer = TaskConsumer  # XXX compat
-
-    @cached_property
-    def TaskProducer(self):
-        """Return publisher used to send tasks.
-
-        You should use `app.send_task` instead.
-
-        """
-        conf = self.app.conf
-        return self.app.subclass_with_self(
-            self.producer_cls,
-            reverse='amqp.TaskProducer',
-            exchange=self.default_exchange,
-            routing_key=conf.CELERY_DEFAULT_ROUTING_KEY,
-            serializer=conf.CELERY_TASK_SERIALIZER,
-            compression=conf.CELERY_MESSAGE_COMPRESSION,
-            retry=conf.CELERY_TASK_PUBLISH_RETRY,
-            retry_policy=conf.CELERY_TASK_PUBLISH_RETRY_POLICY,
-            send_sent_event=conf.CELERY_SEND_TASK_SENT_EVENT,
-            utc=conf.CELERY_ENABLE_UTC,
+    def TaskConsumer(self, channel, queues=None, accept=None, **kw):
+        if accept is None:
+            accept = self.app.conf.CELERY_ACCEPT_CONTENT
+        return self.Consumer(
+            channel, accept=accept,
+            queues=queues or list(self.queues.consume_from.values()),
+            **kw
         )
         )
-    TaskPublisher = TaskProducer  # compat
 
 
     @cached_property
     @cached_property
     def default_queue(self):
     def default_queue(self):
@@ -488,7 +292,7 @@ class AMQP(object):
             self._producer_pool = ProducerPool(
             self._producer_pool = ProducerPool(
                 self.app.pool,
                 self.app.pool,
                 limit=self.app.pool.limit,
                 limit=self.app.pool.limit,
-                Producer=self.TaskProducer,
+                Producer=self.Producer,
             )
             )
         return self._producer_pool
         return self._producer_pool
     publisher_pool = producer_pool  # compat alias
     publisher_pool = producer_pool  # compat alias
@@ -497,3 +301,164 @@ class AMQP(object):
     def default_exchange(self):
     def default_exchange(self):
         return Exchange(self.app.conf.CELERY_DEFAULT_EXCHANGE,
         return Exchange(self.app.conf.CELERY_DEFAULT_EXCHANGE,
                         self.app.conf.CELERY_DEFAULT_EXCHANGE_TYPE)
                         self.app.conf.CELERY_DEFAULT_EXCHANGE_TYPE)
+
+    def create_task_message(self, task_id, name, args=None, kwargs=None,
+                            countdown=None, eta=None, group_id=None,
+                            expires=None, now=None, retries=0, chord=None,
+                            callbacks=None, errbacks=None, reply_to=None,
+                            time_limit=None, soft_time_limit=None,
+                            create_sent_event=False):
+        args = args or ()
+        kwargs = kwargs or {}
+        utc = self.utc
+        if not isinstance(args, (list, tuple)):
+            raise ValueError('task args must be a list or tuple')
+        if not isinstance(kwargs, Mapping):
+            raise ValueError('task keyword arguments must be a mapping')
+        if countdown:  # convert countdown to ETA
+            now = now or self.app.now()
+            eta = now + timedelta(seconds=countdown)
+            if utc:
+                eta = to_utc(eta).astimezone(self.app.timezone)
+        if isinstance(expires, numbers.Real):
+            now = now or self.app.now()
+            expires = now + timedelta(seconds=expires)
+            if utc:
+                expires = to_utc(expires).astimezone(self.app.timezone)
+        eta = eta and eta.isoformat()
+        expires = expires and expires.isoformat()
+
+        return task_message(
+            {},
+            {
+                'correlation_id': task_id,
+                'reply_to': reply_to,
+            },
+            {
+                'task': name,
+                'id': task_id,
+                'args': args,
+                'kwargs': kwargs,
+                'retries': retries,
+                'eta': eta,
+                'expires': expires,
+                'utc': utc,
+                'callbacks': callbacks,
+                'errbacks': errbacks,
+                'timelimit': (time_limit, soft_time_limit),
+                'taskset': group_id,
+                'chord': chord,
+            },
+            {
+                'uuid': task_id,
+                'name': name,
+                'args': safe_repr(args),
+                'kwargs': safe_repr(kwargs),
+                'retries': retries,
+                'eta': eta,
+                'expires': expires,
+            } if create_sent_event else None,
+        )
+
+    def _create_task_sender(self):
+        default_retry = self.app.conf.CELERY_TASK_PUBLISH_RETRY
+        default_policy = self.app.conf.CELERY_TASK_PUBLISH_RETRY_POLICY
+        default_queue = self.default_queue
+        queues = self.queues
+        send_before_publish = signals.before_task_publish.send
+        before_receivers = signals.before_task_publish.receivers
+        send_after_publish = signals.after_task_publish.send
+        after_receivers = signals.after_task_publish.receivers
+
+        send_task_sent = signals.task_sent.send   # XXX compat
+        sent_receivers = signals.task_sent.receivers
+
+        default_evd = self._event_dispatcher
+        default_exchange = self.default_exchange
+
+        default_rkey = self.app.conf.CELERY_DEFAULT_ROUTING_KEY
+        default_serializer = self.app.conf.CELERY_TASK_SERIALIZER
+        default_compressor = self.app.conf.CELERY_MESSAGE_COMPRESSION
+
+        def publish_task(producer, name, message,
+                         exchange=None, routing_key=None, queue=None,
+                         event_dispatcher=None, retry=None, retry_policy=None,
+                         serializer=None, delivery_mode=None,
+                         compression=None, declare=None,
+                         headers=None, **kwargs):
+            retry = default_retry if retry is None else retry
+            headers, properties, body, sent_event = message
+            if kwargs:
+                properties.update(kwargs)
+
+            qname = queue
+            if queue is None and exchange is None:
+                queue = default_queue
+            if queue is not None:
+                if isinstance(queue, string_t):
+                    qname, queue = queue, queues[queue]
+                else:
+                    qname = queue.name
+            exchange = exchange or queue.exchange.name
+            routing_key = routing_key or queue.routing_key
+            if declare is None and queue and not isinstance(queue, Broadcast):
+                declare = [queue]
+
+            # merge default and custom policy
+            retry = default_retry if retry is None else retry
+            _rp = (dict(default_policy, **retry_policy) if retry_policy
+                   else default_policy)
+
+            if before_receivers:
+                send_before_publish(
+                    sender=name, body=body,
+                    exchange=exchange, routing_key=routing_key,
+                    declare=declare, headers=headers,
+                    properties=kwargs,  retry_policy=retry_policy,
+                )
+            ret = producer.publish(
+                body,
+                exchange=exchange or default_exchange,
+                routing_key=routing_key or default_rkey,
+                serializer=serializer or default_serializer,
+                compression=compression or default_compressor,
+                retry=retry, retry_policy=_rp,
+                delivery_mode=delivery_mode, declare=declare,
+                headers=headers,
+                **properties
+            )
+            if after_receivers:
+                send_after_publish(sender=name, body=body,
+                                   exchange=exchange, routing_key=routing_key)
+            if sent_receivers:  # XXX deprecated
+                send_task_sent(sender=name, task_id=body['id'], task=name,
+                               args=body['args'], kwargs=body['kwargs'],
+                               eta=body['eta'], taskset=body['taskset'])
+            if sent_event:
+                evd = event_dispatcher or default_evd
+                exname = exchange or self.exchange
+                if isinstance(name, Exchange):
+                    exname = exname.name
+                sent_event.update({
+                    'queue': qname,
+                    'exchange': exname,
+                    'routing_key': routing_key,
+                })
+                evd.publish('task-sent', sent_event,
+                            self, retry=retry, retry_policy=retry_policy)
+            return ret
+        return publish_task
+
+    @cached_property
+    def send_task_message(self):
+        return self._create_task_sender()
+
+    @cached_property
+    def utc(self):
+        return self.app.conf.CELERY_ENABLE_UTC
+
+    @cached_property
+    def _event_dispatcher(self):
+        # We call Dispatcher.publish with a custom producer
+        # so don't need the diuspatcher to be enabled.
+        return self.app.events.Dispatcher(enabled=False)

+ 16 - 9
celery/app/base.py

@@ -302,26 +302,33 @@ class Celery(object):
                   eta=None, task_id=None, producer=None, connection=None,
                   eta=None, task_id=None, producer=None, connection=None,
                   router=None, result_cls=None, expires=None,
                   router=None, result_cls=None, expires=None,
                   publisher=None, link=None, link_error=None,
                   publisher=None, link=None, link_error=None,
-                  add_to_parent=True, reply_to=None, **options):
+                  add_to_parent=True, group_id=None, retries=0, chord=None,
+                  reply_to=None, time_limit=None, soft_time_limit=None,
+                  **options):
+        amqp = self.amqp
         task_id = task_id or uuid()
         task_id = task_id or uuid()
         producer = producer or publisher  # XXX compat
         producer = producer or publisher  # XXX compat
-        router = router or self.amqp.router
+        router = router or amqp.router
         conf = self.conf
         conf = self.conf
         if conf.CELERY_ALWAYS_EAGER:  # pragma: no cover
         if conf.CELERY_ALWAYS_EAGER:  # pragma: no cover
             warnings.warn(AlwaysEagerIgnored(
             warnings.warn(AlwaysEagerIgnored(
                 'CELERY_ALWAYS_EAGER has no effect on send_task',
                 'CELERY_ALWAYS_EAGER has no effect on send_task',
             ), stacklevel=2)
             ), stacklevel=2)
         options = router.route(options, name, args, kwargs)
         options = router.route(options, name, args, kwargs)
+
+        message = amqp.create_task_message(
+            task_id, name, args, kwargs, countdown, eta, group_id,
+            expires, retries, chord,
+            maybe_list(link), maybe_list(link_error),
+            reply_to or self.oid, time_limit, soft_time_limit,
+            self.conf.CELERY_SEND_TASK_SENT_EVENT,
+        )
+
         if connection:
         if connection:
-            producer = self.amqp.TaskProducer(connection)
+            producer = amqp.Producer(connection)
         with self.producer_or_acquire(producer) as P:
         with self.producer_or_acquire(producer) as P:
             self.backend.on_task_call(P, task_id)
             self.backend.on_task_call(P, task_id)
-            task_id = P.publish_task(
-                name, args, kwargs, countdown=countdown, eta=eta,
-                task_id=task_id, expires=expires,
-                callbacks=maybe_list(link), errbacks=maybe_list(link_error),
-                reply_to=reply_to or self.oid, **options
-            )
+            amqp.send_task_message(P, name, message, **options)
         result = (result_cls or self.AsyncResult)(task_id)
         result = (result_cls or self.AsyncResult)(task_id)
         if add_to_parent:
         if add_to_parent:
             parent = get_current_worker_task()
             parent = get_current_worker_task()

+ 1 - 1
celery/app/task.py

@@ -525,7 +525,7 @@ class Task(object):
         :keyword link_error: A single, or a list of tasks to apply
         :keyword link_error: A single, or a list of tasks to apply
                       if an error occurs while executing the task.
                       if an error occurs while executing the task.
 
 
-        :keyword producer: :class:~@amqp.TaskProducer` instance to use.
+        :keyword producer: :class:~@kombu.Producer` instance to use.
         :keyword add_to_parent: If set to True (default) and the task
         :keyword add_to_parent: If set to True (default) and the task
             is applied while executing another task, then the result
             is applied while executing another task, then the result
             will be appended to the parent tasks ``request.children``
             will be appended to the parent tasks ``request.children``

+ 1 - 1
celery/beat.py

@@ -179,7 +179,7 @@ class Scheduler(object):
         self.sync_every_tasks = (
         self.sync_every_tasks = (
             app.conf.CELERYBEAT_SYNC_EVERY if sync_every_tasks is None
             app.conf.CELERYBEAT_SYNC_EVERY if sync_every_tasks is None
             else sync_every_tasks)
             else sync_every_tasks)
-        self.Publisher = Publisher or app.amqp.TaskProducer
+        self.Publisher = Publisher or app.amqp.Producer
         if not lazy:
         if not lazy:
             self.setup_schedule()
             self.setup_schedule()
 
 

+ 2 - 2
celery/contrib/migrate.py

@@ -99,7 +99,7 @@ def migrate_tasks(source, dest, migrate=migrate_task, app=None,
                   queues=None, **kwargs):
                   queues=None, **kwargs):
     app = app_or_default(app)
     app = app_or_default(app)
     queues = prepare_queues(queues)
     queues = prepare_queues(queues)
-    producer = app.amqp.TaskProducer(dest)
+    producer = app.amqp.Producer(dest)
     migrate = partial(migrate, producer, queues=queues)
     migrate = partial(migrate, producer, queues=queues)
 
 
     def on_declare_queue(queue):
     def on_declare_queue(queue):
@@ -186,7 +186,7 @@ def move(predicate, connection=None, exchange=None, routing_key=None,
     app = app_or_default(app)
     app = app_or_default(app)
     queues = [_maybe_queue(app, queue) for queue in source or []] or None
     queues = [_maybe_queue(app, queue) for queue in source or []] or None
     with app.connection_or_acquire(connection, pool=False) as conn:
     with app.connection_or_acquire(connection, pool=False) as conn:
-        producer = app.amqp.TaskProducer(conn)
+        producer = app.amqp.Producer(conn)
         state = State()
         state = State()
 
 
         def on_task(body, message):
         def on_task(body, message):

+ 0 - 1
celery/five.py

@@ -238,7 +238,6 @@ COMPAT_MODULES = {
             'redirect_stdouts_to_logger': 'log.redirect_stdouts_to_logger',
             'redirect_stdouts_to_logger': 'log.redirect_stdouts_to_logger',
         },
         },
         'messaging': {
         'messaging': {
-            'TaskPublisher': 'amqp.TaskPublisher',
             'TaskConsumer': 'amqp.TaskConsumer',
             'TaskConsumer': 'amqp.TaskConsumer',
             'establish_connection': 'connection',
             'establish_connection': 'connection',
             'get_consumer_set': 'amqp.TaskConsumer',
             'get_consumer_set': 'amqp.TaskConsumer',

+ 11 - 4
celery/task/base.py

@@ -106,12 +106,19 @@ class Task(BaseTask):
                       exchange_type=None, **options):
                       exchange_type=None, **options):
         """Deprecated method to get the task publisher (now called producer).
         """Deprecated method to get the task publisher (now called producer).
 
 
-        Should be replaced with :class:`@amqp.TaskProducer`:
+        Should be replaced with :class:`@kombu.Producer`:
 
 
         .. code-block:: python
         .. code-block:: python
 
 
-            with celery.connection() as conn:
-                with celery.amqp.TaskProducer(conn) as prod:
+            with app.connection() as conn:
+                with app.amqp.Producer(conn) as prod:
+                    my_task.apply_async(producer=prod)
+
+            or event better is to use the :class:`@amqp.producer_pool`:
+
+            .. code-block:: python
+
+                with app.producer_or_acquire() as prod:
                     my_task.apply_async(producer=prod)
                     my_task.apply_async(producer=prod)
 
 
         """
         """
@@ -119,7 +126,7 @@ class Task(BaseTask):
         if exchange_type is None:
         if exchange_type is None:
             exchange_type = self.exchange_type
             exchange_type = self.exchange_type
         connection = connection or self.establish_connection()
         connection = connection or self.establish_connection()
-        return self._get_app().amqp.TaskProducer(
+        return self._get_app().amqp.Producer(
             connection,
             connection,
             exchange=exchange and Exchange(exchange, exchange_type),
             exchange=exchange and Exchange(exchange, exchange_type),
             routing_key=self.routing_key, **options
             routing_key=self.routing_key, **options

+ 1 - 1
celery/task/sets.py

@@ -46,7 +46,7 @@ class TaskSet(list):
         super(TaskSet, self).__init__(
         super(TaskSet, self).__init__(
             maybe_signature(t, app=self.app) for t in tasks or []
             maybe_signature(t, app=self.app) for t in tasks or []
         )
         )
-        self.Publisher = Publisher or self.app.amqp.TaskProducer
+        self.Publisher = Publisher or self.app.amqp.Producer
         self.total = len(self)  # XXX compat
         self.total = len(self)  # XXX compat
 
 
     def apply_async(self, connection=None, publisher=None, taskset_id=None):
     def apply_async(self, connection=None, publisher=None, taskset_id=None):

+ 2 - 94
celery/tests/app/test_amqp.py

@@ -1,86 +1,10 @@
 from __future__ import absolute_import
 from __future__ import absolute_import
 
 
-import datetime
-
-import pytz
-
 from kombu import Exchange, Queue
 from kombu import Exchange, Queue
 
 
-from celery.app.amqp import Queues, TaskPublisher
+from celery.app.amqp import Queues
 from celery.five import keys
 from celery.five import keys
-from celery.tests.case import AppCase, Mock
-
-
-class test_TaskProducer(AppCase):
-
-    def test__exit__(self):
-        publisher = self.app.amqp.TaskProducer(self.app.connection())
-        publisher.release = Mock()
-        with publisher:
-            pass
-        publisher.release.assert_called_with()
-
-    def test_declare(self):
-        publisher = self.app.amqp.TaskProducer(self.app.connection())
-        publisher.exchange.name = 'foo'
-        publisher.declare()
-        publisher.exchange.name = None
-        publisher.declare()
-
-    def test_retry_policy(self):
-        prod = self.app.amqp.TaskProducer(Mock())
-        prod.channel.connection.client.declared_entities = set()
-        prod.publish_task('tasks.add', (2, 2), {},
-                          retry_policy={'frobulate': 32.4})
-
-    def test_publish_no_retry(self):
-        prod = self.app.amqp.TaskProducer(Mock())
-        prod.channel.connection.client.declared_entities = set()
-        prod.publish_task('tasks.add', (2, 2), {}, retry=False, chord=123)
-        self.assertFalse(prod.connection.ensure.call_count)
-
-    def test_publish_custom_queue(self):
-        prod = self.app.amqp.TaskProducer(Mock())
-        self.app.amqp.queues['some_queue'] = Queue(
-            'xxx', Exchange('yyy'), 'zzz',
-        )
-        prod.channel.connection.client.declared_entities = set()
-        prod.publish = Mock()
-        prod.publish_task('tasks.add', (8, 8), {}, retry=False,
-                          queue='some_queue')
-        self.assertEqual(prod.publish.call_args[1]['exchange'], 'yyy')
-        self.assertEqual(prod.publish.call_args[1]['routing_key'], 'zzz')
-
-    def test_publish_with_countdown(self):
-        prod = self.app.amqp.TaskProducer(Mock())
-        prod.channel.connection.client.declared_entities = set()
-        prod.publish = Mock()
-        now = datetime.datetime(2013, 11, 26, 16, 48, 46)
-        prod.publish_task('tasks.add', (1, 1), {}, retry=False,
-                          countdown=10, now=now)
-        self.assertEqual(
-            prod.publish.call_args[0][0]['eta'],
-            '2013-11-26T16:48:56+00:00',
-        )
-
-    def test_publish_with_countdown_and_timezone(self):
-        # use timezone with fixed offset to be sure it won't be changed
-        self.app.conf.CELERY_TIMEZONE = pytz.FixedOffset(120)
-        prod = self.app.amqp.TaskProducer(Mock())
-        prod.channel.connection.client.declared_entities = set()
-        prod.publish = Mock()
-        now = datetime.datetime(2013, 11, 26, 16, 48, 46)
-        prod.publish_task('tasks.add', (2, 2), {}, retry=False,
-                          countdown=20, now=now)
-        self.assertEqual(
-            prod.publish.call_args[0][0]['eta'],
-            '2013-11-26T18:49:06+02:00',
-        )
-
-    def test_event_dispatcher(self):
-        prod = self.app.amqp.TaskProducer(Mock())
-        self.assertTrue(prod.event_dispatcher)
-        self.assertFalse(prod.event_dispatcher.enabled)
+from celery.tests.case import AppCase
 
 
 
 
 class test_TaskConsumer(AppCase):
 class test_TaskConsumer(AppCase):
@@ -98,22 +22,6 @@ class test_TaskConsumer(AppCase):
             )
             )
 
 
 
 
-class test_compat_TaskPublisher(AppCase):
-
-    def test_compat_exchange_is_string(self):
-        producer = TaskPublisher(exchange='foo', app=self.app)
-        self.assertIsInstance(producer.exchange, Exchange)
-        self.assertEqual(producer.exchange.name, 'foo')
-        self.assertEqual(producer.exchange.type, 'direct')
-        producer = TaskPublisher(exchange='foo', exchange_type='topic',
-                                 app=self.app)
-        self.assertEqual(producer.exchange.type, 'topic')
-
-    def test_compat_exchange_is_Exchange(self):
-        producer = TaskPublisher(exchange=Exchange('foo'), app=self.app)
-        self.assertEqual(producer.exchange.name, 'foo')
-
-
 class test_PublisherPool(AppCase):
 class test_PublisherPool(AppCase):
 
 
     def test_setup_nolimit(self):
     def test_setup_nolimit(self):

+ 19 - 16
celery/tests/app/test_app.py

@@ -8,7 +8,6 @@ from copy import deepcopy
 from pickle import loads, dumps
 from pickle import loads, dumps
 
 
 from amqp import promise
 from amqp import promise
-from kombu import Exchange
 
 
 from celery import shared_task, current_app
 from celery import shared_task, current_app
 from celery import app as _app
 from celery import app as _app
@@ -336,10 +335,13 @@ class test_App(AppCase):
         def aawsX():
         def aawsX():
             pass
             pass
 
 
-        with patch('celery.app.amqp.TaskProducer.publish_task') as dt:
-            aawsX.apply_async((4, 5))
-            args = dt.call_args[0][1]
-            self.assertEqual(args, ('hello', 4, 5))
+        with patch('celery.app.amqp.AMQP.create_task_message') as create:
+            with patch('celery.app.amqp.AMQP.send_task_message') as send:
+                create.return_value = Mock(), Mock(), Mock(), Mock()
+                aawsX.apply_async((4, 5))
+                args = create.call_args[0][2]
+                self.assertEqual(args, ('hello', 4, 5))
+                self.assertTrue(send.called)
 
 
     def test_apply_async_adds_children(self):
     def test_apply_async_adds_children(self):
         from celery._state import _task_stack
         from celery._state import _task_stack
@@ -609,22 +611,23 @@ class test_App(AppCase):
             chan.close()
             chan.close()
         assert conn.transport_cls == 'memory'
         assert conn.transport_cls == 'memory'
 
 
-        prod = self.app.amqp.TaskProducer(
-            conn, exchange=Exchange('foo_exchange'),
-            send_sent_event=True,
+        message = self.app.amqp.create_task_message(
+            'id', 'footask', (), {}, create_sent_event=True,
         )
         )
 
 
+        prod = self.app.amqp.Producer(conn)
         dispatcher = Dispatcher()
         dispatcher = Dispatcher()
-        self.assertTrue(prod.publish_task('footask', (), {},
-                                          exchange='moo_exchange',
-                                          routing_key='moo_exchange',
-                                          event_dispatcher=dispatcher))
+        self.app.amqp.send_task_message(
+            prod, 'footask', message,
+            exchange='moo_exchange', routing_key='moo_exchange',
+            event_dispatcher=dispatcher,
+        )
         self.assertTrue(dispatcher.sent)
         self.assertTrue(dispatcher.sent)
         self.assertEqual(dispatcher.sent[0][0], 'task-sent')
         self.assertEqual(dispatcher.sent[0][0], 'task-sent')
-        self.assertTrue(prod.publish_task('footask', (), {},
-                                          event_dispatcher=dispatcher,
-                                          exchange='bar_exchange',
-                                          routing_key='bar_exchange'))
+        self.app.amqp.send_task_message(
+            prod, 'footask', message, event_dispatcher=dispatcher,
+            exchange='bar_exchange', routing_key='bar_exchange',
+        )
 
 
     def test_error_mail_sender(self):
     def test_error_mail_sender(self):
         x = ErrorMail.subject % {'name': 'task_name',
         x = ErrorMail.subject % {'name': 'task_name',

+ 3 - 3
celery/tests/backends/test_amqp.py

@@ -108,8 +108,8 @@ class test_AMQPBackend(AppCase):
             raise KeyError('foo')
             raise KeyError('foo')
 
 
         backend = AMQPBackend(self.app)
         backend = AMQPBackend(self.app)
-        from celery.app.amqp import TaskProducer
-        prod, TaskProducer.publish = TaskProducer.publish, publish
+        from celery.app.amqp import Producer
+        prod, Producer.publish = Producer.publish, publish
         try:
         try:
             with self.assertRaises(KeyError):
             with self.assertRaises(KeyError):
                 backend.retry_policy['max_retries'] = None
                 backend.retry_policy['max_retries'] = None
@@ -119,7 +119,7 @@ class test_AMQPBackend(AppCase):
                 backend.retry_policy['max_retries'] = 10
                 backend.retry_policy['max_retries'] = 10
                 backend.store_result('foo', 'bar', 'STARTED')
                 backend.store_result('foo', 'bar', 'STARTED')
         finally:
         finally:
-            TaskProducer.publish = prod
+            Producer.publish = prod
 
 
     def assertState(self, retval, state):
     def assertState(self, retval, state):
         self.assertEqual(retval['status'], state)
         self.assertEqual(retval['status'], state)

+ 0 - 5
celery/tests/tasks/test_tasks.py

@@ -381,11 +381,6 @@ class test_tasks(TasksCase):
         finally:
         finally:
             self.mytask.pop_request()
             self.mytask.pop_request()
 
 
-    def test_send_task_sent_event(self):
-        with self.app.connection() as conn:
-            self.app.conf.CELERY_SEND_TASK_SENT_EVENT = True
-            self.assertTrue(self.app.amqp.TaskProducer(conn).send_sent_event)
-
     def test_update_state(self):
     def test_update_state(self):
 
 
         @self.app.task(shared=False)
         @self.app.task(shared=False)

+ 8 - 11
docs/reference/celery.app.amqp.rst

@@ -17,7 +17,11 @@
 
 
         .. attribute:: Consumer
         .. attribute:: Consumer
 
 
-            Base Consumer class used.  Default is :class:`kombu.compat.Consumer`.
+            Base Consumer class used.  Default is :class:`kombu.Consumer`.
+
+        .. attribute:: Producer
+
+            Base Producer class used.  Default is :class:`kombu.Producer`.
 
 
         .. attribute:: queues
         .. attribute:: queues
 
 
@@ -25,13 +29,13 @@
 
 
         .. automethod:: Queues
         .. automethod:: Queues
         .. automethod:: Router
         .. automethod:: Router
-        .. autoattribute:: TaskConsumer
-        .. autoattribute:: TaskProducer
         .. automethod:: flush_routes
         .. automethod:: flush_routes
 
 
+        .. autoattribute:: create_task_message
+        .. autoattribute:: send_task_message
         .. autoattribute:: default_queue
         .. autoattribute:: default_queue
         .. autoattribute:: default_exchange
         .. autoattribute:: default_exchange
-        .. autoattribute:: publisher_pool
+        .. autoattribute:: producer_pool
         .. autoattribute:: router
         .. autoattribute:: router
         .. autoattribute:: routes
         .. autoattribute:: routes
 
 
@@ -41,10 +45,3 @@
     .. autoclass:: Queues
     .. autoclass:: Queues
         :members:
         :members:
         :undoc-members:
         :undoc-members:
-
-    TaskPublisher
-    -------------
-
-    .. autoclass:: TaskPublisher
-        :members:
-        :undoc-members:

+ 15 - 11
examples/eventlet/bulk_task_producer.py

@@ -3,8 +3,6 @@ from eventlet import spawn_n, monkey_patch, Timeout
 from eventlet.queue import LightQueue
 from eventlet.queue import LightQueue
 from eventlet.event import Event
 from eventlet.event import Event
 
 
-from celery import current_app
-
 monkey_patch()
 monkey_patch()
 
 
 
 
@@ -27,9 +25,16 @@ class Receipt(object):
 
 
 
 
 class ProducerPool(object):
 class ProducerPool(object):
+    """Usage::
+
+        >>> app = Celery(broker='amqp://')
+        >>> ProducerPool(app)
+
+    """
     Receipt = Receipt
     Receipt = Receipt
 
 
-    def __init__(self, size=20):
+    def __init__(self, app, size=20):
+        self.app = app
         self.size = size
         self.size = size
         self.inqueue = LightQueue()
         self.inqueue = LightQueue()
         self._running = None
         self._running = None
@@ -48,13 +53,12 @@ class ProducerPool(object):
         ]
         ]
 
 
     def _producer(self):
     def _producer(self):
-        connection = current_app.connection()
-        publisher = current_app.amqp.TaskProducer(connection)
         inqueue = self.inqueue
         inqueue = self.inqueue
 
 
-        while 1:
-            task, args, kwargs, options, receipt = inqueue.get()
-            result = task.apply_async(args, kwargs,
-                                      publisher=publisher,
-                                      **options)
-            receipt.finished(result)
+        with self.app.producer_or_acquire() as producer:
+            while 1:
+                task, args, kwargs, options, receipt = inqueue.get()
+                result = task.apply_async(args, kwargs,
+                                          producer=producer,
+                                          **options)
+                receipt.finished(result)