Browse Source

AMQP result backend now using Kombu

Ask Solem 14 years ago
parent
commit
df2184a794
3 changed files with 53 additions and 64 deletions
  1. 49 60
      celery/backends/amqp.py
  2. 2 1
      celery/bin/base.py
  3. 2 3
      celery/task/base.py

+ 49 - 60
celery/backends/amqp.py

@@ -5,7 +5,8 @@ import warnings
 
 
 from datetime import timedelta
 from datetime import timedelta
 
 
-from carrot.messaging import Consumer, Publisher
+from kombu.entity import Exchange, Queue
+from kombu.messaging import Consumer, Producer
 
 
 from celery import states
 from celery import states
 from celery.backends.base import BaseDictBackend
 from celery.backends.base import BaseDictBackend
@@ -17,25 +18,6 @@ class AMQResultWarning(UserWarning):
     pass
     pass
 
 
 
 
-class ResultPublisher(Publisher):
-    auto_delete = True
-
-    def __init__(self, connection, task_id, **kwargs):
-        super(ResultPublisher, self).__init__(connection,
-                        routing_key=task_id.replace("-", ""),
-                        **kwargs)
-
-
-class ResultConsumer(Consumer):
-    no_ack = True
-    auto_delete = True
-
-    def __init__(self, connection, task_id, **kwargs):
-        routing_key = task_id.replace("-", "")
-        super(ResultConsumer, self).__init__(connection,
-                queue=routing_key, routing_key=routing_key, **kwargs)
-
-
 class AMQPBackend(BaseDictBackend):
 class AMQPBackend(BaseDictBackend):
     """AMQP backend. Publish results by sending messages to the broker
     """AMQP backend. Publish results by sending messages to the broker
     using the task id as routing key.
     using the task id as routing key.
@@ -47,6 +29,7 @@ class AMQPBackend(BaseDictBackend):
     """
     """
 
 
     _connection = None
     _connection = None
+    _channel = None
 
 
     def __init__(self, connection=None, exchange=None, exchange_type=None,
     def __init__(self, connection=None, exchange=None, exchange_type=None,
             persistent=None, serializer=None, auto_delete=True,
             persistent=None, serializer=None, auto_delete=True,
@@ -55,11 +38,17 @@ class AMQPBackend(BaseDictBackend):
         conf = self.app.conf
         conf = self.app.conf
         self._connection = connection
         self._connection = connection
         self.queue_arguments = {}
         self.queue_arguments = {}
-        self.exchange = exchange or conf.CELERY_RESULT_EXCHANGE
-        self.exchange_type = exchange_type or conf.CELERY_RESULT_EXCHANGE_TYPE
+        exchange = exchange or conf.CELERY_RESULT_EXCHANGE
+        exchange_type = exchange_type or conf.CELERY_RESULT_EXCHANGE_TYPE
         if persistent is None:
         if persistent is None:
             persistent = conf.CELERY_RESULT_PERSISTENT
             persistent = conf.CELERY_RESULT_PERSISTENT
         self.persistent = persistent
         self.persistent = persistent
+        delivery_mode = persistent and "persistent" or "transient"
+        self.exchange = Exchange(name=exchange,
+                                 type=exchange_type,
+                                 delivery_mode=delivery_mode,
+                                 durable=self.persistent,
+                                 auto_delete=auto_delete)
         self.serializer = serializer or conf.CELERY_RESULT_SERIALIZER
         self.serializer = serializer or conf.CELERY_RESULT_SERIALIZER
         self.auto_delete = auto_delete
         self.auto_delete = auto_delete
         self.expires = expires
         self.expires = expires
@@ -74,27 +63,25 @@ class AMQPBackend(BaseDictBackend):
             # the expiry time in milliseconds.
             # the expiry time in milliseconds.
             self.queue_arguments["x-expires"] = int(self.expires * 1000.0)
             self.queue_arguments["x-expires"] = int(self.expires * 1000.0)
 
 
-    def _create_publisher(self, task_id, connection):
-        delivery_mode = self.persistent and 2 or 1
+    def _create_binding(self, task_id):
+        name = task_id.replace("-", "")
+        return Queue(name=name,
+                     exchange=self.exchange,
+                     routing_key=name,
+                     durable=self.persistent,
+                     auto_delete=self.auto_delete)
 
 
-        # Declares the queue.
-        self._create_consumer(task_id, connection).close()
+    def _create_producer(self, task_id):
+        binding = self._create_binding(task_id)
+        binding(self.channel).declare()
 
 
-        return ResultPublisher(connection, task_id,
-                               exchange=self.exchange,
-                               exchange_type=self.exchange_type,
-                               delivery_mode=delivery_mode,
-                               durable=self.persistent,
-                               serializer=self.serializer,
-                               auto_delete=self.auto_delete)
+        return Producer(self.channel, exchange=self.exchange,
+                        routing_key=task_id.replace("-", ""),
+                        serializer=self.serializer)
 
 
-    def _create_consumer(self, task_id, connection):
-        return ResultConsumer(connection, task_id,
-                              exchange=self.exchange,
-                              exchange_type=self.exchange_type,
-                              durable=self.persistent,
-                              auto_delete=self.auto_delete,
-                              queue_arguments=self.queue_arguments)
+    def _create_consumer(self, task_id):
+        binding = self._create_binding(task_id)
+        return Consumer(self.channel, [binding], no_ack=True)
 
 
     def store_result(self, task_id, result, status, traceback=None,
     def store_result(self, task_id, result, status, traceback=None,
             max_retries=20, retry_delay=0.2):
             max_retries=20, retry_delay=0.2):
@@ -108,12 +95,11 @@ class AMQPBackend(BaseDictBackend):
 
 
         for i in range(max_retries + 1):
         for i in range(max_retries + 1):
             try:
             try:
-                publisher = self._create_publisher(task_id, self.connection)
-                publisher.send(meta)
-                publisher.close()
+                self._create_producer(task_id).publish(meta)
             except Exception, exc:
             except Exception, exc:
                 if not max_retries:
                 if not max_retries:
                     raise
                     raise
+                self._channel = None
                 self._connection = None
                 self._connection = None
                 warnings.warn(AMQResultWarning(
                 warnings.warn(AMQResultWarning(
                     "Error sending result %s: %r" % (task_id, exc)))
                     "Error sending result %s: %r" % (task_id, exc)))
@@ -144,21 +130,14 @@ class AMQPBackend(BaseDictBackend):
             return self.wait_for(task_id, timeout, cache)
             return self.wait_for(task_id, timeout, cache)
 
 
     def poll(self, task_id):
     def poll(self, task_id):
-        consumer = self._create_consumer(task_id, self.connection)
-        result = consumer.fetch()
-        try:
-            if result:
-                payload = self._cache[task_id] = result.payload
-                return payload
-            else:
-
-                # Use previously received status if any.
-                if task_id in self._cache:
-                    return self._cache[task_id]
-
-                return {"status": states.PENDING, "result": None}
-        finally:
-            consumer.close()
+        binding = self._create_binding(task_id)(self.channel)
+        result = binding.get()
+        if result:
+            payload = self._cache[task_id] = result.payload
+            return payload
+        elif task_id in self._cache:
+            return self._cache[task_id]     # use previously received state.
+        return {"status": states.PENDING, "result": None}
 
 
     def consume(self, task_id, timeout=None):
     def consume(self, task_id, timeout=None):
         results = []
         results = []
@@ -168,7 +147,7 @@ class AMQPBackend(BaseDictBackend):
                 results.append(meta)
                 results.append(meta)
 
 
         wait = self.connection.drain_events
         wait = self.connection.drain_events
-        consumer = self._create_consumer(task_id, self.connection)
+        consumer = self._create_consumer(task_id)
         consumer.register_callback(callback)
         consumer.register_callback(callback)
 
 
         consumer.consume()
         consumer.consume()
@@ -183,14 +162,18 @@ class AMQPBackend(BaseDictBackend):
                     # Got event on the wanted channel.
                     # Got event on the wanted channel.
                     break
                     break
         finally:
         finally:
-            consumer.close()
+            consumer.cancel()
 
 
         self._cache[task_id] = results[0]
         self._cache[task_id] = results[0]
         return results[0]
         return results[0]
 
 
     def close(self):
     def close(self):
+        if self._channel is not None:
+            self._channel.close()
+            self._channel = None
         if self._connection is not None:
         if self._connection is not None:
             self._connection.close()
             self._connection.close()
+            self._connection = None
 
 
     @property
     @property
     def connection(self):
     def connection(self):
@@ -198,6 +181,12 @@ class AMQPBackend(BaseDictBackend):
             self._connection = self.app.broker_connection()
             self._connection = self.app.broker_connection()
         return self._connection
         return self._connection
 
 
+    @property
+    def channel(self):
+        if not self._channel:
+            self._channel = self.connection.channel()
+        return self._channel
+
     def reload_task_result(self, task_id):
     def reload_task_result(self, task_id):
         raise NotImplementedError(
         raise NotImplementedError(
                 "reload_task_result is not supported by this backend.")
                 "reload_task_result is not supported by this backend.")

+ 2 - 1
celery/bin/base.py

@@ -31,7 +31,8 @@ class Command(object):
     Parser = OptionParser
     Parser = OptionParser
 
 
     def __init__(self, app=None, get_app=None):
     def __init__(self, app=None, get_app=None):
-        self.app = app
+        from celery.app import app_or_default
+        self.app = app_or_default(app)
         self.get_app = get_app or self._get_default_app
         self.get_app = get_app or self._get_default_app
 
 
     def usage(self):
     def usage(self):

+ 2 - 3
celery/task/base.py

@@ -486,9 +486,8 @@ class BaseTask(object):
                                          eta=eta, expires=expires,
                                          eta=eta, expires=expires,
                                          **options)
                                          **options)
         finally:
         finally:
-            publisher or publish.close()
-            if not connection:
-                # close automatically created connection
+            if not publisher:
+                publish.close()
                 publish.connection.close()
                 publish.connection.close()
 
 
         return self.AsyncResult(task_id)
         return self.AsyncResult(task_id)