Browse Source

AMQP Backend: Now set as non-persistent+non-durable by default.

Set CELERY_RESULT_PERSISTENT=False to revert to the previous settings.
Ask Solem 15 years ago
parent
commit
296e8d10a5
3 changed files with 92 additions and 72 deletions
  1. 61 50
      celery/backends/amqp.py
  2. 8 1
      celery/conf.py
  3. 23 21
      celery/tests/test_backends/test_amqp.py

+ 61 - 50
celery/backends/amqp.py

@@ -10,6 +10,32 @@ from celery.backends.base import BaseDictBackend
 from celery.messaging import establish_connection
 from celery.messaging import establish_connection
 
 
 
 
+class ResultPublisher(Publisher):
+    exchange = conf.RESULT_EXCHANGE
+    exchange_type = conf.RESULT_EXCHANGE_TYPE
+    delivery_mode = conf.RESULT_PERSISTENT and 2 or 1
+    serializer = conf.RESULT_SERIALIZER
+    durable = conf.RESULT_PERSISTENT
+
+    def __init__(self, connection, task_id, **kwargs):
+        super(ResultPublisher, self).__init__(connection,
+                        routing_key=task_id.replace("-", ""),
+                        **kwargs)
+
+
+class ResultConsumer(Consumer):
+    exchange = conf.RESULT_EXCHANGE
+    exchange_type = conf.RESULT_EXCHANGE_TYPE
+    durable = conf.RESULT_PERSISTENT
+    no_ack = True
+    auto_delete = True
+
+    def __init__(self, connection, task_id, **kwargs):
+        routing_key = task_id.replace("-", "")
+        super(ResultConsumer, self).__init__(connection,
+                queue=routing_key, routing_key=routing_key, **kwargs)
+
+
 class AMQPBackend(BaseDictBackend):
 class AMQPBackend(BaseDictBackend):
     """AMQP backend. Publish results by sending messages to the broker
     """AMQP backend. Publish results by sending messages to the broker
     using the task id as routing key.
     using the task id as routing key.
@@ -21,49 +47,28 @@ class AMQPBackend(BaseDictBackend):
     """
     """
 
 
     exchange = conf.RESULT_EXCHANGE
     exchange = conf.RESULT_EXCHANGE
-    capabilities = ["ResultStore"]
+    exchange_type = conf.RESULT_EXCHANGE_TYPE
+    persistent = conf.RESULT_PERSISTENT
+    serializer = conf.RESULT_SERIALIZER
     _connection = None
     _connection = None
-    _use_debug_tracking = False
-    _seen = set()
 
 
-    def __init__(self, *args, **kwargs):
-        super(AMQPBackend, self).__init__(*args, **kwargs)
+    def _create_publisher(self, task_id, connection):
+        delivery_mode = self.persistent and 2 or 1
 
 
-    @property
-    def connection(self):
-        if not self._connection:
-            self._connection = establish_connection()
-        return self._connection
+        # Declares the queue.
+        self._create_consumer(task_id, connection).close()
 
 
-    def _declare_queue(self, task_id, connection):
-        routing_key = task_id.replace("-", "")
-        backend = connection.create_backend()
-        backend.queue_declare(queue=routing_key, durable=True,
-                                exclusive=False, auto_delete=True)
-        backend.exchange_declare(exchange=self.exchange,
-                                 type="direct",
-                                 durable=True,
-                                 auto_delete=False)
-        backend.queue_bind(queue=routing_key, exchange=self.exchange,
-                           routing_key=routing_key)
-        backend.close()
-
-    def _publisher_for_task_id(self, task_id, connection):
-        routing_key = task_id.replace("-", "")
-        self._declare_queue(task_id, connection)
-        return Publisher(connection, exchange=self.exchange,
-                      exchange_type="direct",
-                      routing_key=routing_key)
+        return ResultPublisher(connection, task_id,
+                               exchange=self.exchange,
+                               exchange_type=self.exchange_type,
+                               delivery_mode=delivery_mode,
+                               serializer=self.serializer)
 
 
-    def _consumer_for_task_id(self, task_id, connection):
-        routing_key = task_id.replace("-", "")
-        self._declare_queue(task_id, connection)
-        return Consumer(connection, queue=routing_key,
-                        exchange=self.exchange,
-                        exchange_type="direct",
-                        no_ack=False, auto_ack=False,
-                        auto_delete=True,
-                        routing_key=routing_key)
+    def _create_consumer(self, task_id, connection):
+        return ResultConsumer(connection, task_id,
+                              exchange=self.exchange,
+                              exchange_type=self.exchange_type,
+                              durable=self.persistent)
 
 
     def store_result(self, task_id, result, status, traceback=None):
     def store_result(self, task_id, result, status, traceback=None):
         """Send task return value and status."""
         """Send task return value and status."""
@@ -74,10 +79,11 @@ class AMQPBackend(BaseDictBackend):
                 "status": status,
                 "status": status,
                 "traceback": traceback}
                 "traceback": traceback}
 
 
-        connection = self.connection
-        publisher = self._publisher_for_task_id(task_id, connection)
-        publisher.send(meta, serializer="pickle")
-        publisher.close()
+        publisher = self._create_publisher(task_id, self.connection)
+        try:
+            publisher.send(meta)
+        finally:
+            publisher.close()
 
 
         return result
         return result
 
 
@@ -93,27 +99,22 @@ class AMQPBackend(BaseDictBackend):
             raise self.get_result(task_id)
             raise self.get_result(task_id)
 
 
     def _get_task_meta_for(self, task_id, timeout=None):
     def _get_task_meta_for(self, task_id, timeout=None):
-        assert task_id not in self._seen
-        self._use_debug_tracking and self._seen.add(task_id)
-
         results = []
         results = []
 
 
         def callback(message_data, message):
         def callback(message_data, message):
             results.append(message_data)
             results.append(message_data)
-            message.ack()
 
 
         routing_key = task_id.replace("-", "")
         routing_key = task_id.replace("-", "")
 
 
-        connection = self.connection
-        wait = connection.connection.wait_multi
-        consumer = self._consumer_for_task_id(task_id, connection)
+        wait = self.connection.connection.wait_multi
+        consumer = self._create_consumer(task_id, self.connection)
         consumer.register_callback(callback)
         consumer.register_callback(callback)
 
 
         consumer.consume()
         consumer.consume()
         try:
         try:
             wait([consumer.backend.channel], timeout=timeout)
             wait([consumer.backend.channel], timeout=timeout)
         finally:
         finally:
-            consumer.backend.channel.queue_delete(routing_key)
+            consumer.backend.queue_delete(routing_key)
             consumer.close()
             consumer.close()
 
 
         self._cache[task_id] = results[0]
         self._cache[task_id] = results[0]
@@ -137,3 +138,13 @@ class AMQPBackend(BaseDictBackend):
         """Get the result of a taskset."""
         """Get the result of a taskset."""
         raise NotImplementedError(
         raise NotImplementedError(
                 "restore_taskset is not supported by this backend.")
                 "restore_taskset is not supported by this backend.")
+
+    def close(self):
+        if self._connection is not None:
+            self._connection.close()
+
+    @property
+    def connection(self):
+        if not self._connection:
+            self._connection = establish_connection()
+        return self._connection

+ 8 - 1
celery/conf.py

@@ -63,10 +63,14 @@ _DEFAULTS = {
     "CELERY_EVENT_ROUTING_KEY": "celeryevent",
     "CELERY_EVENT_ROUTING_KEY": "celeryevent",
     "CELERY_EVENT_SERIALIZER": "json",
     "CELERY_EVENT_SERIALIZER": "json",
     "CELERY_RESULT_EXCHANGE": "celeryresults",
     "CELERY_RESULT_EXCHANGE": "celeryresults",
+    "CELERY_RESULT_EXCHANGE_TYPE": "direct",
+    "CELERY_RESULT_SERIALIZER": "pickle",
+    "CELERY_RESULT_PERSISTENT": False,
     "CELERY_MAX_CACHED_RESULTS": 5000,
     "CELERY_MAX_CACHED_RESULTS": 5000,
     "CELERY_TRACK_STARTED": False,
     "CELERY_TRACK_STARTED": False,
 }
 }
 
 
+
 _DEPRECATION_FMT = """
 _DEPRECATION_FMT = """
 %s is deprecated in favor of %s and is scheduled for removal in celery v1.2.
 %s is deprecated in favor of %s and is scheduled for removal in celery v1.2.
 """.strip()
 """.strip()
@@ -211,9 +215,12 @@ BROKER_CONNECTION_RETRY = _get("CELERY_BROKER_CONNECTION_RETRY",
 BROKER_CONNECTION_MAX_RETRIES = _get("CELERY_BROKER_CONNECTION_MAX_RETRIES",
 BROKER_CONNECTION_MAX_RETRIES = _get("CELERY_BROKER_CONNECTION_MAX_RETRIES",
                                 compat=["CELERY_AMQP_CONNECTION_MAX_RETRIES"])
                                 compat=["CELERY_AMQP_CONNECTION_MAX_RETRIES"])
 
 
-# :--- Backend settings                             <-   --   --- - ----- -- #
+# :--- AMQP Backend settings                        <-   --   --- - ----- -- #
 
 
 RESULT_EXCHANGE = _get("CELERY_RESULT_EXCHANGE")
 RESULT_EXCHANGE = _get("CELERY_RESULT_EXCHANGE")
+RESULT_EXCHANGE_TYPE = _get("CELERY_RESULT_EXCHANGE_TYPE")
+RESULT_SERIALIZER = _get("CELERY_RESULT_SERIALIZER")
+RESULT_PERSISTENT = _get("CELERY_RESULT_PERSISTENT")
 
 
 # :--- Celery Beat                                  <-   --   --- - ----- -- #
 # :--- Celery Beat                                  <-   --   --- - ----- -- #
 CELERYBEAT_LOG_LEVEL = _get("CELERYBEAT_LOG_LEVEL")
 CELERYBEAT_LOG_LEVEL = _get("CELERYBEAT_LOG_LEVEL")

+ 23 - 21
celery/tests/test_backends/test_amqp.py

@@ -13,48 +13,50 @@ class SomeClass(object):
         self.data = data
         self.data = data
 
 
 
 
-class TestRedisBackend(unittest.TestCase):
+class test_AMQPBackend(unittest.TestCase):
 
 
-    def setUp(self):
-        self.backend = AMQPBackend()
-        self.backend._use_debug_tracking = True
+    def create_backend(self):
+        return AMQPBackend(serializer="pickle", persistent=False)
 
 
     def test_mark_as_done(self):
     def test_mark_as_done(self):
-        tb = self.backend
+        tb1 = self.create_backend()
+        tb2 = self.create_backend()
 
 
         tid = gen_unique_id()
         tid = gen_unique_id()
 
 
-        tb.mark_as_done(tid, 42)
-        self.assertTrue(tb.is_successful(tid))
-        self.assertEqual(tb.get_status(tid), states.SUCCESS)
-        self.assertEqual(tb.get_result(tid), 42)
-        self.assertTrue(tb._cache.get(tid))
-        self.assertTrue(tb.get_result(tid), 42)
+        tb1.mark_as_done(tid, 42)
+        self.assertTrue(tb2.is_successful(tid))
+        self.assertEqual(tb2.get_status(tid), states.SUCCESS)
+        self.assertEqual(tb2.get_result(tid), 42)
+        self.assertTrue(tb2._cache.get(tid))
+        self.assertTrue(tb2.get_result(tid), 42)
 
 
     def test_is_pickled(self):
     def test_is_pickled(self):
-        tb = self.backend
+        tb1 = self.create_backend()
+        tb2 = self.create_backend()
 
 
         tid2 = gen_unique_id()
         tid2 = gen_unique_id()
         result = {"foo": "baz", "bar": SomeClass(12345)}
         result = {"foo": "baz", "bar": SomeClass(12345)}
-        tb.mark_as_done(tid2, result)
+        tb1.mark_as_done(tid2, result)
         # is serialized properly.
         # is serialized properly.
-        rindb = tb.get_result(tid2)
+        rindb = tb2.get_result(tid2)
         self.assertEqual(rindb.get("foo"), "baz")
         self.assertEqual(rindb.get("foo"), "baz")
         self.assertEqual(rindb.get("bar").data, 12345)
         self.assertEqual(rindb.get("bar").data, 12345)
 
 
     def test_mark_as_failure(self):
     def test_mark_as_failure(self):
-        tb = self.backend
+        tb1 = self.create_backend()
+        tb2 = self.create_backend()
 
 
         tid3 = gen_unique_id()
         tid3 = gen_unique_id()
         try:
         try:
             raise KeyError("foo")
             raise KeyError("foo")
         except KeyError, exception:
         except KeyError, exception:
             einfo = ExceptionInfo(sys.exc_info())
             einfo = ExceptionInfo(sys.exc_info())
-        tb.mark_as_failure(tid3, exception, traceback=einfo.traceback)
-        self.assertFalse(tb.is_successful(tid3))
-        self.assertEqual(tb.get_status(tid3), states.FAILURE)
-        self.assertIsInstance(tb.get_result(tid3), KeyError)
-        self.assertEqual(tb.get_traceback(tid3), einfo.traceback)
+        tb1.mark_as_failure(tid3, exception, traceback=einfo.traceback)
+        self.assertFalse(tb2.is_successful(tid3))
+        self.assertEqual(tb2.get_status(tid3), states.FAILURE)
+        self.assertIsInstance(tb2.get_result(tid3), KeyError)
+        self.assertEqual(tb2.get_traceback(tid3), einfo.traceback)
 
 
     def test_process_cleanup(self):
     def test_process_cleanup(self):
-        self.backend.process_cleanup()
+        self.create_backend().process_cleanup()