AMQP-related improvements

Ask Solem, 13 years ago
parent commit fa957dbc29

+ 2 - 1
celery/app/amqp.py

@@ -319,9 +319,10 @@ class AMQP(object):
         return self.Router()
 
     @cached_property
-    def publisher_pool(self):
+    def producer_pool(self):
         return ProducerPool(self.app.pool, limit=self.app.pool.limit,
                             Producer=self.TaskProducer)
+    publisher_pool = producer_pool  # compat alias
 
     @cached_property
     def default_exchange(self):
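
The rename is purely cosmetic: the pool behaves as before, and the
class-level alias keeps old attribute access working. A minimal sketch of
acquiring a producer from the pool (broker URL is illustrative):

    from celery import Celery

    app = Celery(broker="amqp://guest@localhost//")

    # Preferred name after this commit:
    with app.amqp.producer_pool.acquire(block=True) as producer:
        pass  # producer is a TaskProducer bound to a pooled connection

    # The compat alias resolves to the same cached pool object:
    assert app.amqp.publisher_pool is app.amqp.producer_pool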

+ 1 - 1
celery/app/base.py

@@ -223,7 +223,7 @@ class Celery(object):
         if producer:
             yield producer
         else:
-            with self.amqp.publisher_pool.acquire(block=True) as producer:
+            with self.amqp.producer_pool.acquire(block=True) as producer:
                 yield producer
 
     def with_default_connection(self, fun):
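
Callers now go through ``app.default_producer()``, which reuses a
caller-supplied producer or falls back to the shared pool. A hedged usage
sketch — the payload and routing key are illustrative, and ``my_producer``
stands for any existing producer instance:

    # Passing None acquires from app.amqp.producer_pool
    # and releases the producer when the block exits:
    with app.default_producer(None) as producer:
        producer.publish({"hello": "world"}, routing_key="celery",
                         serializer="json")

    # Passing an existing producer skips the pool entirely;
    # its lifetime stays under the caller's control:
    with app.default_producer(my_producer) as producer:
        producer.publish({"hello": "again"}, routing_key="celery")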

+ 21 - 22
celery/app/task.py

@@ -9,8 +9,8 @@
     :license: BSD, see LICENSE for more details.
 
 """
-
 from __future__ import absolute_import
+from __future__ import with_statement
 
 import logging
 import sys
@@ -460,8 +460,8 @@ class BaseTask(object):
         return self.apply_async(args, kwargs)
 
     def apply_async(self, args=None, kwargs=None,
-            task_id=None, publisher=None, connection=None,
-            router=None, link=None, link_error=None, **options):
+            task_id=None, producer=None, connection=None, router=None,
+            link=None, link_error=None, publisher=None, **options):
         """Apply tasks asynchronously by sending a message.
 
         :keyword args: The positional arguments to pass on to the
@@ -494,7 +494,7 @@ class BaseTask(object):
                         in the event of connection loss or failure.  Default
                         is taken from the :setting:`CELERY_TASK_PUBLISH_RETRY`
                         setting.  Note you need to handle the
-                        publisher/connection manually for this to work.
+                        producer/connection manually for this to work.
 
         :keyword retry_policy:  Override the retry policy used.  See the
                                 :setting:`CELERY_TASK_PUBLISH_RETRY` setting.
@@ -543,11 +543,15 @@ class BaseTask(object):
         :keyword link_error: A single, or a list of subtasks to apply
                       if an error occurs while executing the task.
 
+        :keyword producer: :class:`~@amqp.TaskProducer` instance to use.
+        :keyword publisher: Deprecated alias for ``producer``.
+
         .. note::
             If the :setting:`CELERY_ALWAYS_EAGER` setting is set, it will
             be replaced by a local :func:`apply` call instead.
 
         """
+        producer = producer or publisher
         app = self._get_app()
         router = router or self.app.amqp.router
         conf = app.conf
@@ -562,24 +566,19 @@ class BaseTask(object):
         options = router.route(options, self.name, args, kwargs)
 
         if connection:
-            publisher = app.amqp.TaskProducer(connection)
-        publish = publisher or app.amqp.publisher_pool.acquire(block=True)
-        evd = None
-        if conf.CELERY_SEND_TASK_SENT_EVENT:
-            evd = app.events.Dispatcher(channel=publish.channel,
-                                        buffer_while_offline=False)
-
-        try:
-            task_id = publish.delay_task(self.name, args, kwargs,
-                                         task_id=task_id,
-                                         event_dispatcher=evd,
-                                         callbacks=maybe_list(link),
-                                         errbacks=maybe_list(link_error),
-                                         **options)
-        finally:
-            if not publisher:
-                publish.release()
-
+            producer = app.amqp.TaskProducer(connection)
+        with app.default_producer(producer) as P:
+            evd = None
+            if conf.CELERY_SEND_TASK_SENT_EVENT:
+                evd = app.events.Dispatcher(channel=P.channel,
+                                            buffer_while_offline=False)
+
+            task_id = P.delay_task(self.name, args, kwargs,
+                                   task_id=task_id,
+                                   event_dispatcher=evd,
+                                   callbacks=maybe_list(link),
+                                   errbacks=maybe_list(link_error),
+                                   **options)
         result = self.AsyncResult(task_id)
         parent = get_current_worker_task()
         if parent:
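
Call sites are unaffected unless they passed ``publisher`` by keyword,
which continues to work through the ``producer = producer or publisher``
compat mapping. A sketch (the task definition is illustrative):

    @app.task
    def add(x, y):
        return x + y

    # New spelling:
    with app.amqp.producer_pool.acquire(block=True) as producer:
        add.apply_async((2, 2), producer=producer)

    # Old spelling, silently mapped onto producer:
    with app.amqp.producer_pool.acquire(block=True) as producer:
        add.apply_async((2, 2), publisher=producer)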

+ 45 - 55
celery/backends/amqp.py

@@ -6,16 +6,17 @@ import socket
 import threading
 import time
 
-from itertools import count
-
 from kombu.entity import Exchange, Queue
 from kombu.messaging import Consumer, Producer
 
 from celery import states
 from celery.exceptions import TimeoutError
+from celery.utils.log import get_logger
 
 from .base import BaseDictBackend
 
+logger = get_logger(__name__)
+
 
 class BacklogLimitExceeded(Exception):
     """Too much state history to fast-forward."""
@@ -39,6 +40,13 @@ class AMQPBackend(BaseDictBackend):
 
     supports_native_join = True
 
+    retry_policy = {
+            "max_retries": 20,
+            "interval_start": 0,
+            "interval_step": 1,
+            "interval_max": 1,
+    }
+
     def __init__(self, connection=None, exchange=None, exchange_type=None,
             persistent=None, serializer=None, auto_delete=True,
             **kwargs):
@@ -83,19 +91,6 @@ class AMQPBackend(BaseDictBackend):
                           auto_delete=self.auto_delete,
                           queue_arguments=self.queue_arguments)
 
-    def _create_producer(self, task_id, connection):
-        self._create_binding(task_id)(connection.default_channel).declare()
-        return self.Producer(connection, exchange=self.exchange,
-                             routing_key=task_id.replace("-", ""),
-                             serializer=self.serializer)
-
-    def _create_consumer(self, bindings, channel):
-        return self.Consumer(channel, bindings, no_ack=True)
-
-    def _publish_result(self, connection, task_id, meta):
-        # cache single channel
-        self._create_producer(task_id, connection).publish(meta)
-
     def revive(self, channel):
         pass
 
@@ -104,27 +99,18 @@ class AMQPBackend(BaseDictBackend):
             interval_max=1):
         """Send task return value and status."""
         with self.mutex:
-            with self.app.pool.acquire(block=True) as conn:
-
-                def errback(error, delay):
-                    print("Couldn't send result for %r: %r. Retry in %rs." % (
-                            task_id, error, delay))
-
-                send = conn.ensure(self, self._publish_result,
-                            max_retries=max_retries,
-                            errback=errback,
-                            interval_start=interval_start,
-                            interval_step=interval_step,
-                            interval_max=interval_max)
-                send(conn, task_id, {"task_id": task_id, "status": status,
-                                "result": self.encode_result(result, status),
-                                "traceback": traceback,
-                                "children": self.current_task_children()})
+            with self.app.amqp.producer_pool.acquire(block=True) as pub:
+                pub.publish({"task_id": task_id, "status": status,
+                             "result": self.encode_result(result, status),
+                             "traceback": traceback,
+                             "children": self.current_task_children()},
+                            exchange=self.exchange,
+                            routing_key=task_id.replace("-", ""),
+                            serializer=self.serializer,
+                            retry=True, retry_policy=self.retry_policy,
+                            declare=[self._create_binding(task_id)])
         return result
 
-    def get_task_meta(self, task_id, cache=True):
-        return self.poll(task_id)
-
     def wait_for(self, task_id, timeout=None, cache=True, propagate=True,
             **kwargs):
         cached_meta = self._cache.get(task_id)
@@ -147,23 +133,30 @@ class AMQPBackend(BaseDictBackend):
         else:
             return self.wait_for(task_id, timeout, cache)
 
-    def poll(self, task_id, backlog_limit=100):
+    def get_task_meta(self, task_id, backlog_limit=1000):
+        # Polling and using basic_get
         with self.app.pool.acquire_channel(block=True) as (_, channel):
             binding = self._create_binding(task_id)(channel)
             binding.declare()
             latest, acc = None, None
-            for i in count():  # fast-forward
+            for i in xrange(backlog_limit):
                 latest, acc = acc, binding.get(no_ack=True)
-                if not acc:
+                if not acc:  # no more messages
                     break
-                if i > backlog_limit:
-                    raise self.BacklogLimitExceeded(task_id)
+            else:
+                raise self.BacklogLimitExceeded(task_id)
+
             if latest:
+                # new state to report
                 payload = self._cache[task_id] = latest.payload
                 return payload
-            elif task_id in self._cache:  # use previously received state.
-                return self._cache[task_id]
-            return {"status": states.PENDING, "result": None}
+            else:
+                # no new state, use previous
+                try:
+                    return self._cache[task_id]
+                except KeyError:
+                    # result probably pending.
+                    return {"status": states.PENDING, "result": None}
 
     def drain_events(self, connection, consumer, timeout=None, now=time.time):
         wait = connection.drain_events
@@ -190,13 +183,12 @@ class AMQPBackend(BaseDictBackend):
     def consume(self, task_id, timeout=None):
         with self.app.pool.acquire_channel(block=True) as (conn, channel):
             binding = self._create_binding(task_id)
-            with self._create_consumer(binding, channel) as consumer:
+            with self.Consumer(channel, binding, no_ack=True) as consumer:
                 return self.drain_events(conn, consumer, timeout).values()[0]
 
     def get_many(self, task_ids, timeout=None, **kwargs):
         with self.app.pool.acquire_channel(block=True) as (conn, channel):
             ids = set(task_ids)
-            cached_ids = set()
             for task_id in ids:
                 try:
                     cached = self._cache[task_id]
@@ -205,11 +197,10 @@ class AMQPBackend(BaseDictBackend):
                 else:
                     if cached["status"] in states.READY_STATES:
                         yield task_id, cached
-                        cached_ids.add(task_id)
+                        ids.discard(task_id)
 
-            ids ^= cached_ids
             bindings = [self._create_binding(task_id) for task_id in task_ids]
-            with self._create_consumer(bindings, channel) as consumer:
+            with self.Consumer(channel, bindings, no_ack=True) as consumer:
                 while ids:
                     r = self.drain_events(conn, consumer, timeout)
                     ids ^= set(r)
@@ -238,12 +229,11 @@ class AMQPBackend(BaseDictBackend):
                 "delete_taskset is not supported by this backend.")
 
     def __reduce__(self, args=(), kwargs={}):
-        kwargs.update(
-            dict(connection=self._connection,
-                 exchange=self.exchange.name,
-                 exchange_type=self.exchange.type,
-                 persistent=self.persistent,
-                 serializer=self.serializer,
-                 auto_delete=self.auto_delete,
-                 expires=self.expires))
+        kwargs.update(connection=self._connection,
+                      exchange=self.exchange.name,
+                      exchange_type=self.exchange.type,
+                      persistent=self.persistent,
+                      serializer=self.serializer,
+                      auto_delete=self.auto_delete,
+                      expires=self.expires)
         return super(AMQPBackend, self).__reduce__(args, kwargs)
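
``store_result`` now leans on kombu's built-in publish retries instead of a
hand-rolled ``Connection.ensure`` with a print-based errback. With the
class-level ``retry_policy`` above, a failed publish is retried up to 20
times: the first retry fires immediately (``interval_start=0``) and each
subsequent wait grows by ``interval_step=1`` second, capped at
``interval_max=1`` second. The same mechanism in isolation — connection URL
and names are illustrative:

    from kombu import Connection, Exchange, Producer

    with Connection("amqp://guest@localhost//") as conn:
        producer = Producer(conn.channel(),
                            exchange=Exchange("celeryresults",
                                              type="direct"))
        # retry=True makes kombu re-establish the connection and
        # re-publish according to retry_policy on channel/connection errors.
        producer.publish(
            {"task_id": "someid", "status": "SUCCESS", "result": 4},
            routing_key="someid",
            serializer="json",
            retry=True,
            retry_policy={"max_retries": 20, "interval_start": 0,
                          "interval_step": 1, "interval_max": 1},
        )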

+ 1 - 1
celery/events/__init__.py

@@ -261,7 +261,7 @@ class Events(object):
     @contextmanager
     def default_dispatcher(self, hostname=None, enabled=True,
             buffer_while_offline=False):
-        with self.app.amqp.publisher_pool.acquire(block=True) as pub:
+        with self.app.amqp.producer_pool.acquire(block=True) as pub:
             with self.Dispatcher(pub.connection, hostname, enabled,
                                  pub.channel, buffer_while_offline) as d:
                 yield d
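
Only the pool attribute changes here; dispatch itself is untouched. A short
usage sketch (the event type and fields are illustrative):

    # Borrows a producer from app.amqp.producer_pool for the duration:
    with app.events.default_dispatcher(hostname="worker1.example.com") as d:
        d.send("worker-online", freq=5)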

+ 1 - 1
celery/task/trace.py

@@ -290,7 +290,7 @@ def trace_task(task, uuid, args, kwargs, request=None, **opts):
 
 
 def trace_task_ret(task, uuid, args, kwargs, request):
-    task.__tracer__(uuid, args, kwargs, request)
+    return task.__tracer__(uuid, args, kwargs, request)[0]
 
 
 def eager_trace_task(task, uuid, args, kwargs, request=None, **opts):
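
Previously the traced result was discarded. Assuming the tracer returns a
tuple whose first element is the task's result (an assumption about the
tracer contract at this point in the codebase), the change lets callers read
that result directly — ``task``, ``uuid``, ``args``, ``kwargs`` and
``request`` below are placeholders:

    # Before: always returned None.
    # After: returns the first element of the tracer's result tuple,
    # i.e. whatever the traced task produced.
    retval = trace_task_ret(task, uuid, args, kwargs, request)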

+ 4 - 4
celery/tests/app/test_amqp.py

@@ -46,9 +46,9 @@ class test_PublisherPool(AppCase):
             delattr(self.app, "_pool")
         except AttributeError:
             pass
-        self.app.amqp.__dict__.pop("publisher_pool", None)
+        self.app.amqp.__dict__.pop("producer_pool", None)
         try:
-            pool = self.app.amqp.publisher_pool
+            pool = self.app.amqp.producer_pool
             self.assertEqual(pool.limit, self.app.pool.limit)
             self.assertFalse(pool._resource.queue)
 
@@ -68,9 +68,9 @@ class test_PublisherPool(AppCase):
             delattr(self.app, "_pool")
         except AttributeError:
             pass
-        self.app.amqp.__dict__.pop("publisher_pool", None)
+        self.app.amqp.__dict__.pop("producer_pool", None)
         try:
-            pool = self.app.amqp.publisher_pool
+            pool = self.app.amqp.producer_pool
             self.assertEqual(pool.limit, self.app.pool.limit)
             self.assertTrue(pool._resource.queue)
 

+ 2 - 1
celery/worker/job.py

@@ -355,7 +355,8 @@ class Request(object):
         if _does_info:
             info(self.retry_msg.strip(), {
                 "id": self.id, "name": self.name,
-                "exc": safe_repr(exc_info.exception.exc)}, exc_info=exc_info.exc_info)
+                "exc": safe_repr(exc_info.exception.exc)},
+                exc_info=exc_info.exc_info)
 
     def on_failure(self, exc_info):
         """Handler called if the task raised an exception."""