Browse Source

99% Coverage for celery.backends.amqpÃ

Ask Solem 14 years ago
parent
commit
18d2b79f76

+ 26 - 25
celery/backends/amqp.py

@@ -35,6 +35,10 @@ class AMQPBackend(BaseDictBackend):
     still cached locally by the backend instance).
 
     """
+    Exchange = Exchange
+    Queue = Queue
+    Consumer = Consumer
+    Producer = Producer
 
     _pool = None
 
@@ -51,11 +55,11 @@ class AMQPBackend(BaseDictBackend):
             persistent = conf.CELERY_RESULT_PERSISTENT
         self.persistent = persistent
         delivery_mode = persistent and "persistent" or "transient"
-        self.exchange = Exchange(name=exchange,
-                                 type=exchange_type,
-                                 delivery_mode=delivery_mode,
-                                 durable=self.persistent,
-                                 auto_delete=auto_delete)
+        self.exchange = self.Exchange(name=exchange,
+                                      type=exchange_type,
+                                      delivery_mode=delivery_mode,
+                                      durable=self.persistent,
+                                      auto_delete=auto_delete)
         self.serializer = serializer or conf.CELERY_RESULT_SERIALIZER
         self.auto_delete = auto_delete
         self.expires = expires
@@ -74,22 +78,22 @@ class AMQPBackend(BaseDictBackend):
 
     def _create_binding(self, task_id):
         name = task_id.replace("-", "")
-        return Queue(name=name,
-                     exchange=self.exchange,
-                     routing_key=name,
-                     durable=self.persistent,
-                     auto_delete=self.auto_delete)
+        return self.Queue(name=name,
+                          exchange=self.exchange,
+                          routing_key=name,
+                          durable=self.persistent,
+                          auto_delete=self.auto_delete)
 
     def _create_producer(self, task_id, channel):
         binding = self._create_binding(task_id)
         binding(channel).declare()
 
-        return Producer(channel, exchange=self.exchange,
-                        routing_key=task_id.replace("-", ""),
-                        serializer=self.serializer)
+        return self.Producer(channel, exchange=self.exchange,
+                             routing_key=task_id.replace("-", ""),
+                             serializer=self.serializer)
 
     def _create_consumer(self, bindings, channel):
-        return Consumer(channel, bindings, no_ack=True)
+        return self.Consumer(channel, bindings, no_ack=True)
 
     def store_result(self, task_id, result, status, traceback=None,
             max_retries=20, retry_delay=0.2):
@@ -101,7 +105,7 @@ class AMQPBackend(BaseDictBackend):
                 "status": status,
                 "traceback": traceback}
 
-        for i in range(max_retries + 1):
+        for i in xrange((max_retries or 0) + 1):
             conn = self.pool.acquire(block=True)
             channel = conn.channel()
             try:
@@ -130,7 +134,8 @@ class AMQPBackend(BaseDictBackend):
     def wait_for(self, task_id, timeout=None, cache=True):
         cached_meta = self._cache.get(task_id)
 
-        if cached_meta and cached_meta["status"] in states.READY_STATES:
+        if cache and cached_meta and \
+                cached_meta["status"] in states.READY_STATES:
             meta = cached_meta
         else:
             try:
@@ -167,8 +172,8 @@ class AMQPBackend(BaseDictBackend):
             channel.close()
             conn.release()
 
-    def drain_events(self, consumer, timeout=None):
-        wait = consumer.channel.connection.drain_events
+    def drain_events(self, connection, consumer, timeout=None):
+        wait = connection.drain_events
         results = {}
 
         def callback(meta, message):
@@ -199,7 +204,8 @@ class AMQPBackend(BaseDictBackend):
             consumer = self._create_consumer(binding, channel)
             consumer.consume()
             try:
-                return self.drain_events(consumer, timeout=timeout).values()[0]
+                return self.drain_events(conn, consumer,
+                                         timeout=timeout).values()[0]
             finally:
                 consumer.cancel()
         finally:
@@ -228,7 +234,7 @@ class AMQPBackend(BaseDictBackend):
             consumer.consume()
             try:
                 while ids:
-                    r = self.drain_events(consumer, timeout=timeout)
+                    r = self.drain_events(conn, consumer, timeout=timeout)
                     ids ^= set(r.keys())
                     for ready_id, ready_meta in r.items():
                         yield ready_id, ready_meta
@@ -244,11 +250,6 @@ class AMQPBackend(BaseDictBackend):
         channel.close()
         conn.release()
 
-    def close(self):
-        if self._pool is not None:
-            self._pool.close()
-            self._pool = None
-
     def reload_task_result(self, task_id):
         raise NotImplementedError(
                 "reload_task_result is not supported by this backend.")

+ 0 - 60
celery/tests/test_backends/disabled_amqp.py

@@ -1,60 +0,0 @@
-import sys
-from celery.tests.utils import unittest
-
-from celery import states
-from celery.utils import gen_unique_id
-from celery.backends.amqp import AMQPBackend
-from celery.datastructures import ExceptionInfo
-
-
-class SomeClass(object):
-
-    def __init__(self, data):
-        self.data = data
-
-
-class test_AMQPBackend(unittest.TestCase):
-
-    def create_backend(self):
-        return AMQPBackend(serializer="pickle", persistent=False)
-
-    def test_mark_as_done(self):
-        tb1 = self.create_backend()
-        tb2 = self.create_backend()
-
-        tid = gen_unique_id()
-
-        tb1.mark_as_done(tid, 42)
-        self.assertEqual(tb2.get_status(tid), states.SUCCESS)
-        self.assertEqual(tb2.get_result(tid), 42)
-        self.assertTrue(tb2._cache.get(tid))
-        self.assertTrue(tb2.get_result(tid), 42)
-
-    def test_is_pickled(self):
-        tb1 = self.create_backend()
-        tb2 = self.create_backend()
-
-        tid2 = gen_unique_id()
-        result = {"foo": "baz", "bar": SomeClass(12345)}
-        tb1.mark_as_done(tid2, result)
-        # is serialized properly.
-        rindb = tb2.get_result(tid2)
-        self.assertEqual(rindb.get("foo"), "baz")
-        self.assertEqual(rindb.get("bar").data, 12345)
-
-    def test_mark_as_failure(self):
-        tb1 = self.create_backend()
-        tb2 = self.create_backend()
-
-        tid3 = gen_unique_id()
-        try:
-            raise KeyError("foo")
-        except KeyError, exception:
-            einfo = ExceptionInfo(sys.exc_info())
-        tb1.mark_as_failure(tid3, exception, traceback=einfo.traceback)
-        self.assertEqual(tb2.get_status(tid3), states.FAILURE)
-        self.assertIsInstance(tb2.get_result(tid3), KeyError)
-        self.assertEqual(tb2.get_traceback(tid3), einfo.traceback)
-
-    def test_process_cleanup(self):
-        self.create_backend().process_cleanup()

+ 295 - 0
celery/tests/test_backends/test_amqp.py

@@ -0,0 +1,295 @@
+import socket
+import sys
+
+from datetime import timedelta
+
+from celery import states
+from celery.app import app_or_default
+from celery.backends.amqp import AMQPBackend
+from celery.datastructures import ExceptionInfo
+from celery.exceptions import TimeoutError
+from celery.utils import gen_unique_id
+
+from celery.tests.compat import catch_warnings
+from celery.tests.utils import unittest
+from celery.tests.utils import execute_context, sleepdeprived
+
+
+class SomeClass(object):
+
+    def __init__(self, data):
+        self.data = data
+
+
+class test_AMQPBackend(unittest.TestCase):
+
+    def create_backend(self, **opts):
+        opts = dict(dict(serializer="pickle", persistent=False), **opts)
+        return AMQPBackend(**opts)
+
+    def test_mark_as_done(self):
+        tb1 = self.create_backend()
+        tb2 = self.create_backend()
+
+        tid = gen_unique_id()
+
+        tb1.mark_as_done(tid, 42)
+        self.assertEqual(tb2.get_status(tid), states.SUCCESS)
+        self.assertEqual(tb2.get_result(tid), 42)
+        self.assertTrue(tb2._cache.get(tid))
+        self.assertTrue(tb2.get_result(tid), 42)
+
+    def test_is_pickled(self):
+        tb1 = self.create_backend()
+        tb2 = self.create_backend()
+
+        tid2 = gen_unique_id()
+        result = {"foo": "baz", "bar": SomeClass(12345)}
+        tb1.mark_as_done(tid2, result)
+        # is serialized properly.
+        rindb = tb2.get_result(tid2)
+        self.assertEqual(rindb.get("foo"), "baz")
+        self.assertEqual(rindb.get("bar").data, 12345)
+
+    def test_mark_as_failure(self):
+        tb1 = self.create_backend()
+        tb2 = self.create_backend()
+
+        tid3 = gen_unique_id()
+        try:
+            raise KeyError("foo")
+        except KeyError, exception:
+            einfo = ExceptionInfo(sys.exc_info())
+        tb1.mark_as_failure(tid3, exception, traceback=einfo.traceback)
+        self.assertEqual(tb2.get_status(tid3), states.FAILURE)
+        self.assertIsInstance(tb2.get_result(tid3), KeyError)
+        self.assertEqual(tb2.get_traceback(tid3), einfo.traceback)
+
+    def test_repair_uuid(self):
+        from celery.backends.amqp import repair_uuid
+        for i in range(10):
+            uuid = gen_unique_id()
+            self.assertEqual(repair_uuid(uuid.replace("-", "")), uuid)
+
+    def test_expires_defaults_to_config(self):
+        app = app_or_default()
+        prev = app.conf.CELERY_AMQP_TASK_RESULT_EXPIRES
+        app.conf.CELERY_AMQP_TASK_RESULT_EXPIRES = 10
+        try:
+            b = self.create_backend(expires=None)
+            self.assertEqual(b.queue_arguments.get("x-expires"), 10 * 1000.0)
+        finally:
+            app.conf.CELERY_AMQP_TASK_RESULT_EXPIRES = prev
+
+    def test_expires_is_int(self):
+        b = self.create_backend(expires=48)
+        self.assertEqual(b.queue_arguments.get("x-expires"), 48 * 1000.0)
+
+    def test_expires_is_timedelta(self):
+        b = self.create_backend(expires=timedelta(minutes=1))
+        self.assertEqual(b.queue_arguments.get("x-expires"), 60 * 1000.0)
+
+    @sleepdeprived()
+    def test_store_result_retries(self):
+        from celery.backends.amqp import AMQResultWarning
+
+        class _Producer(object):
+            iterations = 0
+            stop_raising_at = 5
+
+            def __init__(self, *args, **kwargs):
+                pass
+
+            def publish(self, msg, *args, **kwargs):
+                if self.iterations > self.stop_raising_at:
+                    return
+                raise KeyError("foo")
+
+        class Backend(AMQPBackend):
+            Producer = _Producer
+
+        backend = Backend()
+        self.assertRaises(KeyError, backend.store_result,
+                          "foo", "bar", "STARTED", max_retries=None)
+
+        def with_catch_warnings(log):
+            backend.store_result("foo", "bar", "STARTED", max_retries=10)
+            return log[0].message
+
+        message = execute_context(catch_warnings(record=True),
+                                  with_catch_warnings)
+
+        self.assertIsInstance(message, AMQResultWarning)
+        self.assertIn("Error sending result", message.args[0])
+
+    def assertState(self, retval, state):
+        self.assertEqual(retval["status"], state)
+
+    def test_poll_no_messages(self):
+        b = self.create_backend()
+        self.assertState(b.poll(gen_unique_id()), states.PENDING)
+
+    def test_poll_result(self):
+
+        class MockBinding(object):
+            delete_raises = [False]
+            get_returns = [True]
+            tried_to_delete = []
+
+            def __init__(self, *args, **kwargs):
+                pass
+
+            def __call__(self, *args, **kwargs):
+                return self
+
+            def delete(self, **kwargs):
+                if self.delete_raises[0]:
+                    self.tried_to_delete.append(True)
+                    raise KeyError("foo")
+
+            def declare(self):
+                pass
+
+            def get(self):
+                if self.get_returns[0]:
+                    class Object(object):
+                        payload = {"status": "STARTED",
+                                "result": None}
+                    return Object()
+
+        class MockBackend(AMQPBackend):
+            Queue = MockBinding
+
+        backend = MockBackend()
+        conn = backend.pool.acquire(block=False)
+        channel_errors = conn.transport.__class__.channel_errors
+        conn.transport.__class__.channel_errors = (KeyError, )
+        conn.release()
+        try:
+            MockBinding.delete_raises[0] = True
+            backend.poll(gen_unique_id())
+            self.assertTrue(MockBinding.tried_to_delete)
+            MockBinding.delete_raises[0] = False
+            uuid = gen_unique_id()
+            backend.poll(uuid)
+            self.assertIn(uuid, backend._cache)
+            MockBinding.get_returns[0] = False
+            backend._cache[uuid] = "hello"
+            self.assertEqual(backend.poll(uuid), "hello")
+        finally:
+            conn = backend.pool.acquire(block=False)
+            conn.transport.__class__.channel_errors = channel_errors
+            conn.release()
+
+    def test_wait_for(self):
+        b = self.create_backend()
+
+        uuid = gen_unique_id()
+        self.assertRaises(TimeoutError, b.wait_for, uuid, timeout=0.1)
+        b.store_result(uuid, None, states.STARTED)
+        self.assertRaises(TimeoutError, b.wait_for, uuid, timeout=0.1)
+        b.store_result(uuid, None, states.RETRY)
+        self.assertRaises(TimeoutError, b.wait_for, uuid, timeout=0.1)
+        b.store_result(uuid, 42, states.SUCCESS)
+        self.assertEqual(b.wait_for(uuid, timeout=1), 42)
+        b.store_result(uuid, 56, states.SUCCESS)
+        self.assertEqual(b.wait_for(uuid, timeout=1), 42,
+                         "result is cached")
+        self.assertEqual(b.wait_for(uuid, timeout=1, cache=False), 56)
+        b.store_result(uuid, KeyError("foo"), states.FAILURE)
+        self.assertRaises(KeyError, b.wait_for, uuid, timeout=1, cache=False)
+
+    def test_drain_events_remaining_timeouts(self):
+
+        class Connection(object):
+
+            def drain_events(self, timeout=None):
+                pass
+
+        b = self.create_backend()
+        conn = b.pool.acquire(block=False)
+        channel = conn.channel()
+        try:
+            binding = b._create_binding(gen_unique_id())
+            consumer = b._create_consumer(binding, channel)
+            self.assertRaises(socket.timeout, b.drain_events,
+                              Connection(), consumer, timeout=0.1)
+        finally:
+            channel.close()
+            conn.release()
+
+    def test_get_many(self):
+        b = self.create_backend()
+
+        uuids = []
+        for i in xrange(10):
+            uuid = gen_unique_id()
+            b.store_result(uuid, i, states.SUCCESS)
+            uuids.append(uuid)
+
+        res = list(b.get_many(uuids, timeout=1))
+        expected_results = [(uuid, {"status": states.SUCCESS,
+                                    "result": i,
+                                    "traceback": None,
+                                    "task_id": uuid})
+                                for i, uuid in enumerate(uuids)]
+        self.assertItemsEqual(res, expected_results)
+        self.assertDictEqual(b._cache[res[0][0]], res[0][1])
+        cached_res = list(b.get_many(uuids, timeout=1))
+        self.assertItemsEqual(cached_res, expected_results)
+        b._cache[res[0][0]]["status"] = states.RETRY
+        self.assertRaises(socket.timeout, list,
+                          b.get_many(uuids, timeout=0.01))
+
+
+    def test_test_get_many_raises_outer_block(self):
+
+        class Backend(AMQPBackend):
+
+            def _create_consumer(self, *args, **kwargs):
+                raise KeyError("foo")
+
+        b = Backend()
+        self.assertRaises(KeyError, b.get_many(["id1"]).next)
+
+    def test_test_get_many_raises_inner_block(self):
+
+        class Backend(AMQPBackend):
+
+            def drain_events(self, *args, **kwargs):
+                raise KeyError("foo")
+
+        b = Backend()
+        self.assertRaises(KeyError, b.get_many(["id1"]).next)
+
+
+    def test_no_expires(self):
+        b = self.create_backend(expires=None)
+        app = app_or_default()
+        prev = app.conf.CELERY_AMQP_TASK_RESULT_EXPIRES
+        app.conf.CELERY_AMQP_TASK_RESULT_EXPIRES = None
+        try:
+            b = self.create_backend(expires=None)
+            self.assertRaises(KeyError, b.queue_arguments.__getitem__,
+                              "x-expires")
+        finally:
+            app.conf.CELERY_AMQP_TASK_RESULT_EXPIRES = prev
+
+    def test_process_cleanup(self):
+        self.create_backend().process_cleanup()
+
+    def test_reload_task_result(self):
+        self.assertRaises(NotImplementedError,
+                          self.create_backend().reload_task_result, "x")
+
+    def test_reload_taskset_result(self):
+        self.assertRaises(NotImplementedError,
+                          self.create_backend().reload_taskset_result, "x")
+
+    def test_save_taskset(self):
+        self.assertRaises(NotImplementedError,
+                          self.create_backend().save_taskset, "x", "x")
+
+    def test_restore_taskset(self):
+        self.assertRaises(NotImplementedError,
+                          self.create_backend().restore_taskset, "x")

+ 0 - 1
setup.cfg

@@ -30,7 +30,6 @@ cover3-exclude = celery
                  celery.backends.mongodb
                  celery.backends.tyrant
                  celery.backends.pyredis
-                 celery.backends.amqp
                  celery.backends.cassandra
                  celery.events.dumper
                  celery.events.cursesmon