
AMQP-related improvements

Ask Solem, 13 years ago
parent
commit
fa957dbc29

+ 2 - 1
celery/app/amqp.py

@@ -319,9 +319,10 @@ class AMQP(object):
         return self.Router()

     @cached_property
-    def publisher_pool(self):
+    def producer_pool(self):
         return ProducerPool(self.app.pool, limit=self.app.pool.limit,
                             Producer=self.TaskProducer)
+    publisher_pool = producer_pool  # compat alias

     @cached_property
     def default_exchange(self):
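Note: ``producer_pool`` is the new canonical name, and the old ``publisher_pool`` survives as a class-level alias of the same cached property, so both attributes resolve to the same pool. A minimal sketch, assuming ``app`` is any configured Celery application:

    >>> app.amqp.producer_pool is app.amqp.publisher_pool
    True
    >>> with app.amqp.producer_pool.acquire(block=True) as producer:
    ...     pass  # publish task messages with `producer`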

+ 1 - 1
celery/app/base.py

@@ -223,7 +223,7 @@ class Celery(object):
         if producer:
             yield producer
         else:
-            with self.amqp.publisher_pool.acquire(block=True) as producer:
+            with self.amqp.producer_pool.acquire(block=True) as producer:
                 yield producer

     def with_default_connection(self, fun):
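Assembled, the hunk above is the body of a small context manager on the app object; a sketch of the full method (the ``default_producer`` name is taken from the ``celery/app/task.py`` hunk below, so treat this as a reconstruction, not the verbatim source):

    from contextlib import contextmanager

    @contextmanager
    def default_producer(self, producer=None):
        # reuse the producer the caller passed in, otherwise borrow one
        # from the shared pool for the duration of the block.
        if producer:
            yield producer
        else:
            with self.amqp.producer_pool.acquire(block=True) as producer:
                yield producer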

+ 21 - 22
celery/app/task.py

@@ -9,8 +9,8 @@
     :license: BSD, see LICENSE for more details.

 """
-
 from __future__ import absolute_import
+from __future__ import with_statement

 import logging
 import sys
@@ -460,8 +460,8 @@ class BaseTask(object):
         return self.apply_async(args, kwargs)

     def apply_async(self, args=None, kwargs=None,
-            task_id=None, publisher=None, connection=None,
-            router=None, link=None, link_error=None, **options):
+            task_id=None, producer=None, connection=None, router=None,
+            link=None, link_error=None, publisher=None, **options):
         """Apply tasks asynchronously by sending a message.

         :keyword args: The positional arguments to pass on to the
@@ -494,7 +494,7 @@ class BaseTask(object):
                         in the event of connection loss or failure.  Default
                         is taken from the :setting:`CELERY_TASK_PUBLISH_RETRY`
                         setting.  Note you need to handle the
-                        publisher/connection manually for this to work.
+                        producer/connection manually for this to work.

         :keyword retry_policy:  Override the retry policy used.  See the
                                 :setting:`CELERY_TASK_PUBLISH_RETRY` setting.
@@ -543,11 +543,15 @@ class BaseTask(object):
         :keyword link_error: A single, or a list of subtasks to apply
                       if an error occurs while executing the task.

+        :keyword producer: :class:`~@amqp.TaskProducer` instance to use.
+        :keyword publisher: Deprecated alias to ``producer``.
+
         .. note::
             If the :setting:`CELERY_ALWAYS_EAGER` setting is set, it will
             be replaced by a local :func:`apply` call instead.

         """
+        producer = producer or publisher
         app = self._get_app()
         router = router or self.app.amqp.router
         conf = app.conf
@@ -562,24 +566,19 @@ class BaseTask(object):
         options = router.route(options, self.name, args, kwargs)

         if connection:
-            publisher = app.amqp.TaskProducer(connection)
-        publish = publisher or app.amqp.publisher_pool.acquire(block=True)
-        evd = None
-        if conf.CELERY_SEND_TASK_SENT_EVENT:
-            evd = app.events.Dispatcher(channel=publish.channel,
-                                        buffer_while_offline=False)
-
-        try:
-            task_id = publish.delay_task(self.name, args, kwargs,
-                                         task_id=task_id,
-                                         event_dispatcher=evd,
-                                         callbacks=maybe_list(link),
-                                         errbacks=maybe_list(link_error),
-                                         **options)
-        finally:
-            if not publisher:
-                publish.release()
-
+            producer = app.amqp.TaskProducer(connection)
+        with app.default_producer(producer) as P:
+            evd = None
+            if conf.CELERY_SEND_TASK_SENT_EVENT:
+                evd = app.events.Dispatcher(channel=P.channel,
+                                            buffer_while_offline=False)
+
+            task_id = P.delay_task(self.name, args, kwargs,
+                                   task_id=task_id,
+                                   event_dispatcher=evd,
+                                   callbacks=maybe_list(link),
+                                   errbacks=maybe_list(link_error),
+                                   **options)
         result = self.AsyncResult(task_id)
         parent = get_current_worker_task()
         if parent:
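Since ``apply_async`` now accepts ``producer``, a caller can hold one pooled producer across many sends instead of paying an acquire/release per call; ``publisher=`` keeps working as the deprecated spelling. A hypothetical usage sketch (``app`` and an ``add`` task are assumptions, not defined in this commit):

    with app.amqp.producer_pool.acquire(block=True) as producer:
        for i in xrange(100):
            add.apply_async((i, i), producer=producer)
        # old keyword still accepted during the transition:
        add.apply_async((0, 0), publisher=producer)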

+ 45 - 55
celery/backends/amqp.py

@@ -6,16 +6,17 @@ import socket
 import threading
 import time

-from itertools import count
-
 from kombu.entity import Exchange, Queue
 from kombu.messaging import Consumer, Producer

 from celery import states
 from celery.exceptions import TimeoutError
+from celery.utils.log import get_logger

 from .base import BaseDictBackend

+logger = get_logger(__name__)
+

 class BacklogLimitExceeded(Exception):
     """Too much state history to fast-forward."""
@@ -39,6 +40,13 @@ class AMQPBackend(BaseDictBackend):

     supports_native_join = True

+    retry_policy = {
+            "max_retries": 20,
+            "interval_start": 0,
+            "interval_step": 1,
+            "interval_max": 1,
+    }
+
     def __init__(self, connection=None, exchange=None, exchange_type=None,
             persistent=None, serializer=None, auto_delete=True,
             **kwargs):
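Under kombu's usual retry semantics, the delay before each retry starts at ``interval_start``, grows by ``interval_step`` per attempt, and is capped at ``interval_max``, so this policy waits 0s before the first retry and 1s before each of the remaining nineteen. A back-of-the-envelope check (assumed interpretation, not code from this commit):

    delays = [min(0 + n * 1, 1) for n in range(20)]  # [0, 1, 1, ..., 1]
    assert sum(delays) == 19  # roughly 19s of waiting, worst case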
@@ -83,19 +91,6 @@ class AMQPBackend(BaseDictBackend):
                           auto_delete=self.auto_delete,
                           queue_arguments=self.queue_arguments)

-    def _create_producer(self, task_id, connection):
-        self._create_binding(task_id)(connection.default_channel).declare()
-        return self.Producer(connection, exchange=self.exchange,
-                             routing_key=task_id.replace("-", ""),
-                             serializer=self.serializer)
-
-    def _create_consumer(self, bindings, channel):
-        return self.Consumer(channel, bindings, no_ack=True)
-
-    def _publish_result(self, connection, task_id, meta):
-        # cache single channel
-        self._create_producer(task_id, connection).publish(meta)
-
     def revive(self, channel):
         pass
 
 
@@ -104,27 +99,18 @@ class AMQPBackend(BaseDictBackend):
             interval_max=1):
         """Send task return value and status."""
         with self.mutex:
-            with self.app.pool.acquire(block=True) as conn:
-
-                def errback(error, delay):
-                    print("Couldn't send result for %r: %r. Retry in %rs." % (
-                            task_id, error, delay))
-
-                send = conn.ensure(self, self._publish_result,
-                            max_retries=max_retries,
-                            errback=errback,
-                            interval_start=interval_start,
-                            interval_step=interval_step,
-                            interval_max=interval_max)
-                send(conn, task_id, {"task_id": task_id, "status": status,
-                                "result": self.encode_result(result, status),
-                                "traceback": traceback,
-                                "children": self.current_task_children()})
+            with self.app.amqp.producer_pool.acquire(block=True) as pub:
+                pub.publish({"task_id": task_id, "status": status,
+                             "result": self.encode_result(result, status),
+                             "traceback": traceback,
+                             "children": self.current_task_children()},
+                            exchange=self.exchange,
+                            routing_key=task_id.replace("-", ""),
+                            serializer=self.serializer,
+                            retry=True, retry_policy=self.retry_policy,
+                            declare=[self._create_binding(task_id)])
         return result
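The removed ``conn.ensure`` bookkeeping is now done by kombu itself: ``Producer.publish`` accepts ``retry``/``retry_policy`` to wrap the send in the connection's ensure machinery, and ``declare=[...]`` to declare the listed entities before publishing. A standalone sketch of the same call shape, with illustrative names (queue name, URL, and payload are placeholders, not values from this commit):

    from kombu import BrokerConnection
    from kombu.entity import Exchange, Queue
    from kombu.messaging import Producer

    exchange = Exchange("celeryresults", type="direct")
    binding = Queue("abc123", exchange=exchange, routing_key="abc123")

    conn = BrokerConnection("amqp://guest@localhost//")
    producer = Producer(conn.default_channel, serializer="json")
    producer.publish({"status": "SUCCESS", "result": 42},
                     exchange=exchange, routing_key="abc123",
                     retry=True,
                     retry_policy={"max_retries": 20, "interval_start": 0,
                                   "interval_step": 1, "interval_max": 1},
                     declare=[binding])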
 
 
-    def get_task_meta(self, task_id, cache=True):
-        return self.poll(task_id)
-
     def wait_for(self, task_id, timeout=None, cache=True, propagate=True,
             **kwargs):
         cached_meta = self._cache.get(task_id)
@@ -147,23 +133,30 @@ class AMQPBackend(BaseDictBackend):
         else:
             return self.wait_for(task_id, timeout, cache)

-    def poll(self, task_id, backlog_limit=100):
+    def get_task_meta(self, task_id, backlog_limit=1000):
+        # Polling and using basic_get
         with self.app.pool.acquire_channel(block=True) as (_, channel):
             binding = self._create_binding(task_id)(channel)
             binding.declare()
             latest, acc = None, None
-            for i in count():  # fast-forward
+            for i in xrange(backlog_limit):
                 latest, acc = acc, binding.get(no_ack=True)
-                if not acc:
+                if not acc:  # no more messages
                     break
-                if i > backlog_limit:
-                    raise self.BacklogLimitExceeded(task_id)
+            else:
+                raise self.BacklogLimitExceeded(task_id)
+
             if latest:
+                # new state to report
                 payload = self._cache[task_id] = latest.payload
                 return payload
-            elif task_id in self._cache:  # use previously received state.
-                return self._cache[task_id]
-            return {"status": states.PENDING, "result": None}
+            else:
+                # no new state, use previous
+                try:
+                    return self._cache[task_id]
+                except KeyError:
+                    # result probably pending.
+                    return {"status": states.PENDING, "result": None}
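The rewritten fast-forward loop relies on Python's ``for``/``else``: the ``else`` suite runs only when the loop exhausts all ``backlog_limit`` iterations without hitting ``break``, i.e. when the queue still had messages left to drain. The construct in miniature:

    for attempt in xrange(3):
        if attempt == 1:
            break   # found what we wanted; the else: suite is skipped
    else:
        raise RuntimeError("gave up after 3 attempts")  # only if no break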
 
 
     def drain_events(self, connection, consumer, timeout=None, now=time.time):
     def drain_events(self, connection, consumer, timeout=None, now=time.time):
         wait = connection.drain_events
         wait = connection.drain_events
@@ -190,13 +183,12 @@ class AMQPBackend(BaseDictBackend):
     def consume(self, task_id, timeout=None):
         with self.app.pool.acquire_channel(block=True) as (conn, channel):
             binding = self._create_binding(task_id)
-            with self._create_consumer(binding, channel) as consumer:
+            with self.Consumer(channel, binding, no_ack=True) as consumer:
                 return self.drain_events(conn, consumer, timeout).values()[0]
 
 
     def get_many(self, task_ids, timeout=None, **kwargs):
         with self.app.pool.acquire_channel(block=True) as (conn, channel):
             ids = set(task_ids)
-            cached_ids = set()
             for task_id in ids:
                 try:
                     cached = self._cache[task_id]
@@ -205,11 +197,10 @@ class AMQPBackend(BaseDictBackend):
                 else:
                     if cached["status"] in states.READY_STATES:
                         yield task_id, cached
-                        cached_ids.add(task_id)
+                        ids.discard(task_id)

-            ids ^= cached_ids
             bindings = [self._create_binding(task_id) for task_id in task_ids]
-            with self._create_consumer(bindings, channel) as consumer:
+            with self.Consumer(channel, bindings, no_ack=True) as consumer:
                 while ids:
                     r = self.drain_events(conn, consumer, timeout)
                     ids ^= set(r)
@@ -238,12 +229,11 @@ class AMQPBackend(BaseDictBackend):
                 "delete_taskset is not supported by this backend.")
                 "delete_taskset is not supported by this backend.")
 
 
     def __reduce__(self, args=(), kwargs={}):
     def __reduce__(self, args=(), kwargs={}):
-        kwargs.update(
-            dict(connection=self._connection,
-                 exchange=self.exchange.name,
-                 exchange_type=self.exchange.type,
-                 persistent=self.persistent,
-                 serializer=self.serializer,
-                 auto_delete=self.auto_delete,
-                 expires=self.expires))
+        kwargs.update(connection=self._connection,
+                      exchange=self.exchange.name,
+                      exchange_type=self.exchange.type,
+                      persistent=self.persistent,
+                      serializer=self.serializer,
+                      auto_delete=self.auto_delete,
+                      expires=self.expires)
         return super(AMQPBackend, self).__reduce__(args, kwargs)
         return super(AMQPBackend, self).__reduce__(args, kwargs)

+ 1 - 1
celery/events/__init__.py

@@ -261,7 +261,7 @@ class Events(object):
     @contextmanager
     def default_dispatcher(self, hostname=None, enabled=True,
             buffer_while_offline=False):
-        with self.app.amqp.publisher_pool.acquire(block=True) as pub:
+        with self.app.amqp.producer_pool.acquire(block=True) as pub:
             with self.Dispatcher(pub.connection, hostname, enabled,
                                  pub.channel, buffer_while_offline) as d:
                 yield d

+ 1 - 1
celery/task/trace.py

@@ -290,7 +290,7 @@ def trace_task(task, uuid, args, kwargs, request=None, **opts):


 def trace_task_ret(task, uuid, args, kwargs, request):
-    task.__tracer__(uuid, args, kwargs, request)
+    return task.__tracer__(uuid, args, kwargs, request)[0]


 def eager_trace_task(task, uuid, args, kwargs, request=None, **opts):
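``trace_task_ret`` previously ran the tracer and discarded its result; it now forwards the first element of whatever ``task.__tracer__`` returns. A hedged reading of the contract implied by this call site (the tuple shape is inferred, not spelled out in this diff):

    # if __tracer__(...) returns e.g. (retval, extra_info), then callers
    # of trace_task_ret now receive retval instead of None:
    def trace_task_ret(task, uuid, args, kwargs, request):
        return task.__tracer__(uuid, args, kwargs, request)[0]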

+ 4 - 4
celery/tests/app/test_amqp.py

@@ -46,9 +46,9 @@ class test_PublisherPool(AppCase):
             delattr(self.app, "_pool")
             delattr(self.app, "_pool")
         except AttributeError:
         except AttributeError:
             pass
             pass
-        self.app.amqp.__dict__.pop("publisher_pool", None)
+        self.app.amqp.__dict__.pop("producer_pool", None)
         try:
         try:
-            pool = self.app.amqp.publisher_pool
+            pool = self.app.amqp.producer_pool
             self.assertEqual(pool.limit, self.app.pool.limit)
             self.assertEqual(pool.limit, self.app.pool.limit)
             self.assertFalse(pool._resource.queue)
             self.assertFalse(pool._resource.queue)
 
 
@@ -68,9 +68,9 @@ class test_PublisherPool(AppCase):
             delattr(self.app, "_pool")
             delattr(self.app, "_pool")
         except AttributeError:
         except AttributeError:
             pass
             pass
-        self.app.amqp.__dict__.pop("publisher_pool", None)
+        self.app.amqp.__dict__.pop("producer_pool", None)
         try:
         try:
-            pool = self.app.amqp.publisher_pool
+            pool = self.app.amqp.producer_pool
             self.assertEqual(pool.limit, self.app.pool.limit)
             self.assertEqual(pool.limit, self.app.pool.limit)
             self.assertTrue(pool._resource.queue)
             self.assertTrue(pool._resource.queue)
 
 

+ 2 - 1
celery/worker/job.py

@@ -355,7 +355,8 @@ class Request(object):
         if _does_info:
             info(self.retry_msg.strip(), {
                 "id": self.id, "name": self.name,
-                "exc": safe_repr(exc_info.exception.exc)}, exc_info=exc_info.exc_info)
+                "exc": safe_repr(exc_info.exception.exc)},
+                exc_info=exc_info.exc_info)

     def on_failure(self, exc_info):
         """Handler called if the task raised an exception."""