Browse Source

AMQP Backend: Now set as non-persistent+non-durable by default.

Set CELERY_RESULT_PERSISTENT=False to revert to the previous settings.
Ask Solem 15 years ago
parent
commit
296e8d10a5
3 changed files with 92 additions and 72 deletions
  1. 61 50
      celery/backends/amqp.py
  2. 8 1
      celery/conf.py
  3. 23 21
      celery/tests/test_backends/test_amqp.py

+ 61 - 50
celery/backends/amqp.py

@@ -10,6 +10,32 @@ from celery.backends.base import BaseDictBackend
 from celery.messaging import establish_connection
 
 
+class ResultPublisher(Publisher):
+    exchange = conf.RESULT_EXCHANGE
+    exchange_type = conf.RESULT_EXCHANGE_TYPE
+    delivery_mode = conf.RESULT_PERSISTENT and 2 or 1
+    serializer = conf.RESULT_SERIALIZER
+    durable = conf.RESULT_PERSISTENT
+
+    def __init__(self, connection, task_id, **kwargs):
+        super(ResultPublisher, self).__init__(connection,
+                        routing_key=task_id.replace("-", ""),
+                        **kwargs)
+
+
+class ResultConsumer(Consumer):
+    exchange = conf.RESULT_EXCHANGE
+    exchange_type = conf.RESULT_EXCHANGE_TYPE
+    durable = conf.RESULT_PERSISTENT
+    no_ack = True
+    auto_delete = True
+
+    def __init__(self, connection, task_id, **kwargs):
+        routing_key = task_id.replace("-", "")
+        super(ResultConsumer, self).__init__(connection,
+                queue=routing_key, routing_key=routing_key, **kwargs)
+
+
 class AMQPBackend(BaseDictBackend):
     """AMQP backend. Publish results by sending messages to the broker
     using the task id as routing key.
@@ -21,49 +47,28 @@ class AMQPBackend(BaseDictBackend):
     """
 
     exchange = conf.RESULT_EXCHANGE
-    capabilities = ["ResultStore"]
+    exchange_type = conf.RESULT_EXCHANGE_TYPE
+    persistent = conf.RESULT_PERSISTENT
+    serializer = conf.RESULT_SERIALIZER
     _connection = None
-    _use_debug_tracking = False
-    _seen = set()
 
-    def __init__(self, *args, **kwargs):
-        super(AMQPBackend, self).__init__(*args, **kwargs)
+    def _create_publisher(self, task_id, connection):
+        delivery_mode = self.persistent and 2 or 1
 
-    @property
-    def connection(self):
-        if not self._connection:
-            self._connection = establish_connection()
-        return self._connection
+        # Declares the queue.
+        self._create_consumer(task_id, connection).close()
 
-    def _declare_queue(self, task_id, connection):
-        routing_key = task_id.replace("-", "")
-        backend = connection.create_backend()
-        backend.queue_declare(queue=routing_key, durable=True,
-                                exclusive=False, auto_delete=True)
-        backend.exchange_declare(exchange=self.exchange,
-                                 type="direct",
-                                 durable=True,
-                                 auto_delete=False)
-        backend.queue_bind(queue=routing_key, exchange=self.exchange,
-                           routing_key=routing_key)
-        backend.close()
-
-    def _publisher_for_task_id(self, task_id, connection):
-        routing_key = task_id.replace("-", "")
-        self._declare_queue(task_id, connection)
-        return Publisher(connection, exchange=self.exchange,
-                      exchange_type="direct",
-                      routing_key=routing_key)
+        return ResultPublisher(connection, task_id,
+                               exchange=self.exchange,
+                               exchange_type=self.exchange_type,
+                               delivery_mode=delivery_mode,
+                               serializer=self.serializer)
 
-    def _consumer_for_task_id(self, task_id, connection):
-        routing_key = task_id.replace("-", "")
-        self._declare_queue(task_id, connection)
-        return Consumer(connection, queue=routing_key,
-                        exchange=self.exchange,
-                        exchange_type="direct",
-                        no_ack=False, auto_ack=False,
-                        auto_delete=True,
-                        routing_key=routing_key)
+    def _create_consumer(self, task_id, connection):
+        return ResultConsumer(connection, task_id,
+                              exchange=self.exchange,
+                              exchange_type=self.exchange_type,
+                              durable=self.persistent)
 
     def store_result(self, task_id, result, status, traceback=None):
         """Send task return value and status."""
@@ -74,10 +79,11 @@ class AMQPBackend(BaseDictBackend):
                 "status": status,
                 "traceback": traceback}
 
-        connection = self.connection
-        publisher = self._publisher_for_task_id(task_id, connection)
-        publisher.send(meta, serializer="pickle")
-        publisher.close()
+        publisher = self._create_publisher(task_id, self.connection)
+        try:
+            publisher.send(meta)
+        finally:
+            publisher.close()
 
         return result
 
@@ -93,27 +99,22 @@ class AMQPBackend(BaseDictBackend):
             raise self.get_result(task_id)
 
     def _get_task_meta_for(self, task_id, timeout=None):
-        assert task_id not in self._seen
-        self._use_debug_tracking and self._seen.add(task_id)
-
         results = []
 
         def callback(message_data, message):
             results.append(message_data)
-            message.ack()
 
         routing_key = task_id.replace("-", "")
 
-        connection = self.connection
-        wait = connection.connection.wait_multi
-        consumer = self._consumer_for_task_id(task_id, connection)
+        wait = self.connection.connection.wait_multi
+        consumer = self._create_consumer(task_id, self.connection)
         consumer.register_callback(callback)
 
         consumer.consume()
         try:
             wait([consumer.backend.channel], timeout=timeout)
         finally:
-            consumer.backend.channel.queue_delete(routing_key)
+            consumer.backend.queue_delete(routing_key)
             consumer.close()
 
         self._cache[task_id] = results[0]
@@ -137,3 +138,13 @@ class AMQPBackend(BaseDictBackend):
         """Get the result of a taskset."""
         raise NotImplementedError(
                 "restore_taskset is not supported by this backend.")
+
+    def close(self):
+        if self._connection is not None:
+            self._connection.close()
+
+    @property
+    def connection(self):
+        if not self._connection:
+            self._connection = establish_connection()
+        return self._connection

+ 8 - 1
celery/conf.py

@@ -63,10 +63,14 @@ _DEFAULTS = {
     "CELERY_EVENT_ROUTING_KEY": "celeryevent",
     "CELERY_EVENT_SERIALIZER": "json",
     "CELERY_RESULT_EXCHANGE": "celeryresults",
+    "CELERY_RESULT_EXCHANGE_TYPE": "direct",
+    "CELERY_RESULT_SERIALIZER": "pickle",
+    "CELERY_RESULT_PERSISTENT": False,
     "CELERY_MAX_CACHED_RESULTS": 5000,
     "CELERY_TRACK_STARTED": False,
 }
 
+
 _DEPRECATION_FMT = """
 %s is deprecated in favor of %s and is scheduled for removal in celery v1.2.
 """.strip()
@@ -211,9 +215,12 @@ BROKER_CONNECTION_RETRY = _get("CELERY_BROKER_CONNECTION_RETRY",
 BROKER_CONNECTION_MAX_RETRIES = _get("CELERY_BROKER_CONNECTION_MAX_RETRIES",
                                 compat=["CELERY_AMQP_CONNECTION_MAX_RETRIES"])
 
-# :--- Backend settings                             <-   --   --- - ----- -- #
+# :--- AMQP Backend settings                        <-   --   --- - ----- -- #
 
 RESULT_EXCHANGE = _get("CELERY_RESULT_EXCHANGE")
+RESULT_EXCHANGE_TYPE = _get("CELERY_RESULT_EXCHANGE_TYPE")
+RESULT_SERIALIZER = _get("CELERY_RESULT_SERIALIZER")
+RESULT_PERSISTENT = _get("CELERY_RESULT_PERSISTENT")
 
 # :--- Celery Beat                                  <-   --   --- - ----- -- #
 CELERYBEAT_LOG_LEVEL = _get("CELERYBEAT_LOG_LEVEL")

+ 23 - 21
celery/tests/test_backends/test_amqp.py

@@ -13,48 +13,50 @@ class SomeClass(object):
         self.data = data
 
 
-class TestRedisBackend(unittest.TestCase):
+class test_AMQPBackend(unittest.TestCase):
 
-    def setUp(self):
-        self.backend = AMQPBackend()
-        self.backend._use_debug_tracking = True
+    def create_backend(self):
+        return AMQPBackend(serializer="pickle", persistent=False)
 
     def test_mark_as_done(self):
-        tb = self.backend
+        tb1 = self.create_backend()
+        tb2 = self.create_backend()
 
         tid = gen_unique_id()
 
-        tb.mark_as_done(tid, 42)
-        self.assertTrue(tb.is_successful(tid))
-        self.assertEqual(tb.get_status(tid), states.SUCCESS)
-        self.assertEqual(tb.get_result(tid), 42)
-        self.assertTrue(tb._cache.get(tid))
-        self.assertTrue(tb.get_result(tid), 42)
+        tb1.mark_as_done(tid, 42)
+        self.assertTrue(tb2.is_successful(tid))
+        self.assertEqual(tb2.get_status(tid), states.SUCCESS)
+        self.assertEqual(tb2.get_result(tid), 42)
+        self.assertTrue(tb2._cache.get(tid))
+        self.assertTrue(tb2.get_result(tid), 42)
 
     def test_is_pickled(self):
-        tb = self.backend
+        tb1 = self.create_backend()
+        tb2 = self.create_backend()
 
         tid2 = gen_unique_id()
         result = {"foo": "baz", "bar": SomeClass(12345)}
-        tb.mark_as_done(tid2, result)
+        tb1.mark_as_done(tid2, result)
         # is serialized properly.
-        rindb = tb.get_result(tid2)
+        rindb = tb2.get_result(tid2)
         self.assertEqual(rindb.get("foo"), "baz")
         self.assertEqual(rindb.get("bar").data, 12345)
 
     def test_mark_as_failure(self):
-        tb = self.backend
+        tb1 = self.create_backend()
+        tb2 = self.create_backend()
 
         tid3 = gen_unique_id()
         try:
             raise KeyError("foo")
         except KeyError, exception:
             einfo = ExceptionInfo(sys.exc_info())
-        tb.mark_as_failure(tid3, exception, traceback=einfo.traceback)
-        self.assertFalse(tb.is_successful(tid3))
-        self.assertEqual(tb.get_status(tid3), states.FAILURE)
-        self.assertIsInstance(tb.get_result(tid3), KeyError)
-        self.assertEqual(tb.get_traceback(tid3), einfo.traceback)
+        tb1.mark_as_failure(tid3, exception, traceback=einfo.traceback)
+        self.assertFalse(tb2.is_successful(tid3))
+        self.assertEqual(tb2.get_status(tid3), states.FAILURE)
+        self.assertIsInstance(tb2.get_result(tid3), KeyError)
+        self.assertEqual(tb2.get_traceback(tid3), einfo.traceback)
 
     def test_process_cleanup(self):
-        self.backend.process_cleanup()
+        self.create_backend().process_cleanup()