Browse Source

AMQP result backend now using Kombu

Ask Solem 14 years ago
parent
commit
df2184a794
3 changed files with 53 additions and 64 deletions
  1. 49 60
      celery/backends/amqp.py
  2. 2 1
      celery/bin/base.py
  3. 2 3
      celery/task/base.py

+ 49 - 60
celery/backends/amqp.py

@@ -5,7 +5,8 @@ import warnings
 
 from datetime import timedelta
 
-from carrot.messaging import Consumer, Publisher
+from kombu.entity import Exchange, Queue
+from kombu.messaging import Consumer, Producer
 
 from celery import states
 from celery.backends.base import BaseDictBackend
@@ -17,25 +18,6 @@ class AMQResultWarning(UserWarning):
     pass
 
 
-class ResultPublisher(Publisher):
-    auto_delete = True
-
-    def __init__(self, connection, task_id, **kwargs):
-        super(ResultPublisher, self).__init__(connection,
-                        routing_key=task_id.replace("-", ""),
-                        **kwargs)
-
-
-class ResultConsumer(Consumer):
-    no_ack = True
-    auto_delete = True
-
-    def __init__(self, connection, task_id, **kwargs):
-        routing_key = task_id.replace("-", "")
-        super(ResultConsumer, self).__init__(connection,
-                queue=routing_key, routing_key=routing_key, **kwargs)
-
-
 class AMQPBackend(BaseDictBackend):
     """AMQP backend. Publish results by sending messages to the broker
     using the task id as routing key.
@@ -47,6 +29,7 @@ class AMQPBackend(BaseDictBackend):
     """
 
     _connection = None
+    _channel = None
 
     def __init__(self, connection=None, exchange=None, exchange_type=None,
             persistent=None, serializer=None, auto_delete=True,
@@ -55,11 +38,17 @@ class AMQPBackend(BaseDictBackend):
         conf = self.app.conf
         self._connection = connection
         self.queue_arguments = {}
-        self.exchange = exchange or conf.CELERY_RESULT_EXCHANGE
-        self.exchange_type = exchange_type or conf.CELERY_RESULT_EXCHANGE_TYPE
+        exchange = exchange or conf.CELERY_RESULT_EXCHANGE
+        exchange_type = exchange_type or conf.CELERY_RESULT_EXCHANGE_TYPE
         if persistent is None:
             persistent = conf.CELERY_RESULT_PERSISTENT
         self.persistent = persistent
+        delivery_mode = persistent and "persistent" or "transient"
+        self.exchange = Exchange(name=exchange,
+                                 type=exchange_type,
+                                 delivery_mode=delivery_mode,
+                                 durable=self.persistent,
+                                 auto_delete=auto_delete)
         self.serializer = serializer or conf.CELERY_RESULT_SERIALIZER
         self.auto_delete = auto_delete
         self.expires = expires
@@ -74,27 +63,25 @@ class AMQPBackend(BaseDictBackend):
             # the expiry time in milliseconds.
             self.queue_arguments["x-expires"] = int(self.expires * 1000.0)
 
-    def _create_publisher(self, task_id, connection):
-        delivery_mode = self.persistent and 2 or 1
+    def _create_binding(self, task_id):
+        name = task_id.replace("-", "")
+        return Queue(name=name,
+                     exchange=self.exchange,
+                     routing_key=name,
+                     durable=self.persistent,
+                     auto_delete=self.auto_delete)
 
-        # Declares the queue.
-        self._create_consumer(task_id, connection).close()
+    def _create_producer(self, task_id):
+        binding = self._create_binding(task_id)
+        binding(self.channel).declare()
 
-        return ResultPublisher(connection, task_id,
-                               exchange=self.exchange,
-                               exchange_type=self.exchange_type,
-                               delivery_mode=delivery_mode,
-                               durable=self.persistent,
-                               serializer=self.serializer,
-                               auto_delete=self.auto_delete)
+        return Producer(self.channel, exchange=self.exchange,
+                        routing_key=task_id.replace("-", ""),
+                        serializer=self.serializer)
 
-    def _create_consumer(self, task_id, connection):
-        return ResultConsumer(connection, task_id,
-                              exchange=self.exchange,
-                              exchange_type=self.exchange_type,
-                              durable=self.persistent,
-                              auto_delete=self.auto_delete,
-                              queue_arguments=self.queue_arguments)
+    def _create_consumer(self, task_id):
+        binding = self._create_binding(task_id)
+        return Consumer(self.channel, [binding], no_ack=True)
 
     def store_result(self, task_id, result, status, traceback=None,
             max_retries=20, retry_delay=0.2):
@@ -108,12 +95,11 @@ class AMQPBackend(BaseDictBackend):
 
         for i in range(max_retries + 1):
             try:
-                publisher = self._create_publisher(task_id, self.connection)
-                publisher.send(meta)
-                publisher.close()
+                self._create_producer(task_id).publish(meta)
             except Exception, exc:
                 if not max_retries:
                     raise
+                self._channel = None
                 self._connection = None
                 warnings.warn(AMQResultWarning(
                     "Error sending result %s: %r" % (task_id, exc)))
@@ -144,21 +130,14 @@ class AMQPBackend(BaseDictBackend):
             return self.wait_for(task_id, timeout, cache)
 
     def poll(self, task_id):
-        consumer = self._create_consumer(task_id, self.connection)
-        result = consumer.fetch()
-        try:
-            if result:
-                payload = self._cache[task_id] = result.payload
-                return payload
-            else:
-
-                # Use previously received status if any.
-                if task_id in self._cache:
-                    return self._cache[task_id]
-
-                return {"status": states.PENDING, "result": None}
-        finally:
-            consumer.close()
+        binding = self._create_binding(task_id)(self.channel)
+        result = binding.get()
+        if result:
+            payload = self._cache[task_id] = result.payload
+            return payload
+        elif task_id in self._cache:
+            return self._cache[task_id]     # use previously received state.
+        return {"status": states.PENDING, "result": None}
 
     def consume(self, task_id, timeout=None):
         results = []
@@ -168,7 +147,7 @@ class AMQPBackend(BaseDictBackend):
                 results.append(meta)
 
         wait = self.connection.drain_events
-        consumer = self._create_consumer(task_id, self.connection)
+        consumer = self._create_consumer(task_id)
         consumer.register_callback(callback)
 
         consumer.consume()
@@ -183,14 +162,18 @@ class AMQPBackend(BaseDictBackend):
                     # Got event on the wanted channel.
                     break
         finally:
-            consumer.close()
+            consumer.cancel()
 
         self._cache[task_id] = results[0]
         return results[0]
 
     def close(self):
+        if self._channel is not None:
+            self._channel.close()
+            self._channel = None
         if self._connection is not None:
             self._connection.close()
+            self._connection = None
 
     @property
     def connection(self):
@@ -198,6 +181,12 @@ class AMQPBackend(BaseDictBackend):
             self._connection = self.app.broker_connection()
         return self._connection
 
+    @property
+    def channel(self):
+        if not self._channel:
+            self._channel = self.connection.channel()
+        return self._channel
+
     def reload_task_result(self, task_id):
         raise NotImplementedError(
                 "reload_task_result is not supported by this backend.")

+ 2 - 1
celery/bin/base.py

@@ -31,7 +31,8 @@ class Command(object):
     Parser = OptionParser
 
     def __init__(self, app=None, get_app=None):
-        self.app = app
+        from celery.app import app_or_default
+        self.app = app_or_default(app)
         self.get_app = get_app or self._get_default_app
 
     def usage(self):

+ 2 - 3
celery/task/base.py

@@ -486,9 +486,8 @@ class BaseTask(object):
                                          eta=eta, expires=expires,
                                          **options)
         finally:
-            publisher or publish.close()
-            if not connection:
-                # close automatically created connection
+            if not publisher:
+                publish.close()
                 publish.connection.close()
 
         return self.AsyncResult(task_id)