|  | @@ -0,0 +1,295 @@
 | 
	
		
			
				|  |  | +import socket
 | 
	
		
			
				|  |  | +import sys
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +from datetime import timedelta
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +from celery import states
 | 
	
		
			
				|  |  | +from celery.app import app_or_default
 | 
	
		
			
				|  |  | +from celery.backends.amqp import AMQPBackend
 | 
	
		
			
				|  |  | +from celery.datastructures import ExceptionInfo
 | 
	
		
			
				|  |  | +from celery.exceptions import TimeoutError
 | 
	
		
			
				|  |  | +from celery.utils import gen_unique_id
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +from celery.tests.compat import catch_warnings
 | 
	
		
			
				|  |  | +from celery.tests.utils import unittest
 | 
	
		
			
				|  |  | +from celery.tests.utils import execute_context, sleepdeprived
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +class SomeClass(object):
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def __init__(self, data):
 | 
	
		
			
				|  |  | +        self.data = data
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +class test_AMQPBackend(unittest.TestCase):
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def create_backend(self, **opts):
 | 
	
		
			
				|  |  | +        opts = dict(dict(serializer="pickle", persistent=False), **opts)
 | 
	
		
			
				|  |  | +        return AMQPBackend(**opts)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def test_mark_as_done(self):
 | 
	
		
			
				|  |  | +        tb1 = self.create_backend()
 | 
	
		
			
				|  |  | +        tb2 = self.create_backend()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        tid = gen_unique_id()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        tb1.mark_as_done(tid, 42)
 | 
	
		
			
				|  |  | +        self.assertEqual(tb2.get_status(tid), states.SUCCESS)
 | 
	
		
			
				|  |  | +        self.assertEqual(tb2.get_result(tid), 42)
 | 
	
		
			
				|  |  | +        self.assertTrue(tb2._cache.get(tid))
 | 
	
		
			
				|  |  | +        self.assertTrue(tb2.get_result(tid), 42)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def test_is_pickled(self):
 | 
	
		
			
				|  |  | +        tb1 = self.create_backend()
 | 
	
		
			
				|  |  | +        tb2 = self.create_backend()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        tid2 = gen_unique_id()
 | 
	
		
			
				|  |  | +        result = {"foo": "baz", "bar": SomeClass(12345)}
 | 
	
		
			
				|  |  | +        tb1.mark_as_done(tid2, result)
 | 
	
		
			
				|  |  | +        # is serialized properly.
 | 
	
		
			
				|  |  | +        rindb = tb2.get_result(tid2)
 | 
	
		
			
				|  |  | +        self.assertEqual(rindb.get("foo"), "baz")
 | 
	
		
			
				|  |  | +        self.assertEqual(rindb.get("bar").data, 12345)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def test_mark_as_failure(self):
 | 
	
		
			
				|  |  | +        tb1 = self.create_backend()
 | 
	
		
			
				|  |  | +        tb2 = self.create_backend()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        tid3 = gen_unique_id()
 | 
	
		
			
				|  |  | +        try:
 | 
	
		
			
				|  |  | +            raise KeyError("foo")
 | 
	
		
			
				|  |  | +        except KeyError, exception:
 | 
	
		
			
				|  |  | +            einfo = ExceptionInfo(sys.exc_info())
 | 
	
		
			
				|  |  | +        tb1.mark_as_failure(tid3, exception, traceback=einfo.traceback)
 | 
	
		
			
				|  |  | +        self.assertEqual(tb2.get_status(tid3), states.FAILURE)
 | 
	
		
			
				|  |  | +        self.assertIsInstance(tb2.get_result(tid3), KeyError)
 | 
	
		
			
				|  |  | +        self.assertEqual(tb2.get_traceback(tid3), einfo.traceback)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def test_repair_uuid(self):
 | 
	
		
			
				|  |  | +        from celery.backends.amqp import repair_uuid
 | 
	
		
			
				|  |  | +        for i in range(10):
 | 
	
		
			
				|  |  | +            uuid = gen_unique_id()
 | 
	
		
			
				|  |  | +            self.assertEqual(repair_uuid(uuid.replace("-", "")), uuid)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def test_expires_defaults_to_config(self):
 | 
	
		
			
				|  |  | +        app = app_or_default()
 | 
	
		
			
				|  |  | +        prev = app.conf.CELERY_AMQP_TASK_RESULT_EXPIRES
 | 
	
		
			
				|  |  | +        app.conf.CELERY_AMQP_TASK_RESULT_EXPIRES = 10
 | 
	
		
			
				|  |  | +        try:
 | 
	
		
			
				|  |  | +            b = self.create_backend(expires=None)
 | 
	
		
			
				|  |  | +            self.assertEqual(b.queue_arguments.get("x-expires"), 10 * 1000.0)
 | 
	
		
			
				|  |  | +        finally:
 | 
	
		
			
				|  |  | +            app.conf.CELERY_AMQP_TASK_RESULT_EXPIRES = prev
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def test_expires_is_int(self):
 | 
	
		
			
				|  |  | +        b = self.create_backend(expires=48)
 | 
	
		
			
				|  |  | +        self.assertEqual(b.queue_arguments.get("x-expires"), 48 * 1000.0)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def test_expires_is_timedelta(self):
 | 
	
		
			
				|  |  | +        b = self.create_backend(expires=timedelta(minutes=1))
 | 
	
		
			
				|  |  | +        self.assertEqual(b.queue_arguments.get("x-expires"), 60 * 1000.0)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    @sleepdeprived()
 | 
	
		
			
				|  |  | +    def test_store_result_retries(self):
 | 
	
		
			
				|  |  | +        from celery.backends.amqp import AMQResultWarning
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        class _Producer(object):
 | 
	
		
			
				|  |  | +            iterations = 0
 | 
	
		
			
				|  |  | +            stop_raising_at = 5
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            def __init__(self, *args, **kwargs):
 | 
	
		
			
				|  |  | +                pass
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            def publish(self, msg, *args, **kwargs):
 | 
	
		
			
				|  |  | +                if self.iterations > self.stop_raising_at:
 | 
	
		
			
				|  |  | +                    return
 | 
	
		
			
				|  |  | +                raise KeyError("foo")
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        class Backend(AMQPBackend):
 | 
	
		
			
				|  |  | +            Producer = _Producer
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        backend = Backend()
 | 
	
		
			
				|  |  | +        self.assertRaises(KeyError, backend.store_result,
 | 
	
		
			
				|  |  | +                          "foo", "bar", "STARTED", max_retries=None)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        def with_catch_warnings(log):
 | 
	
		
			
				|  |  | +            backend.store_result("foo", "bar", "STARTED", max_retries=10)
 | 
	
		
			
				|  |  | +            return log[0].message
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        message = execute_context(catch_warnings(record=True),
 | 
	
		
			
				|  |  | +                                  with_catch_warnings)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        self.assertIsInstance(message, AMQResultWarning)
 | 
	
		
			
				|  |  | +        self.assertIn("Error sending result", message.args[0])
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def assertState(self, retval, state):
 | 
	
		
			
				|  |  | +        self.assertEqual(retval["status"], state)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def test_poll_no_messages(self):
 | 
	
		
			
				|  |  | +        b = self.create_backend()
 | 
	
		
			
				|  |  | +        self.assertState(b.poll(gen_unique_id()), states.PENDING)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def test_poll_result(self):
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        class MockBinding(object):
 | 
	
		
			
				|  |  | +            delete_raises = [False]
 | 
	
		
			
				|  |  | +            get_returns = [True]
 | 
	
		
			
				|  |  | +            tried_to_delete = []
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            def __init__(self, *args, **kwargs):
 | 
	
		
			
				|  |  | +                pass
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            def __call__(self, *args, **kwargs):
 | 
	
		
			
				|  |  | +                return self
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            def delete(self, **kwargs):
 | 
	
		
			
				|  |  | +                if self.delete_raises[0]:
 | 
	
		
			
				|  |  | +                    self.tried_to_delete.append(True)
 | 
	
		
			
				|  |  | +                    raise KeyError("foo")
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            def declare(self):
 | 
	
		
			
				|  |  | +                pass
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            def get(self):
 | 
	
		
			
				|  |  | +                if self.get_returns[0]:
 | 
	
		
			
				|  |  | +                    class Object(object):
 | 
	
		
			
				|  |  | +                        payload = {"status": "STARTED",
 | 
	
		
			
				|  |  | +                                "result": None}
 | 
	
		
			
				|  |  | +                    return Object()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        class MockBackend(AMQPBackend):
 | 
	
		
			
				|  |  | +            Queue = MockBinding
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        backend = MockBackend()
 | 
	
		
			
				|  |  | +        conn = backend.pool.acquire(block=False)
 | 
	
		
			
				|  |  | +        channel_errors = conn.transport.__class__.channel_errors
 | 
	
		
			
				|  |  | +        conn.transport.__class__.channel_errors = (KeyError, )
 | 
	
		
			
				|  |  | +        conn.release()
 | 
	
		
			
				|  |  | +        try:
 | 
	
		
			
				|  |  | +            MockBinding.delete_raises[0] = True
 | 
	
		
			
				|  |  | +            backend.poll(gen_unique_id())
 | 
	
		
			
				|  |  | +            self.assertTrue(MockBinding.tried_to_delete)
 | 
	
		
			
				|  |  | +            MockBinding.delete_raises[0] = False
 | 
	
		
			
				|  |  | +            uuid = gen_unique_id()
 | 
	
		
			
				|  |  | +            backend.poll(uuid)
 | 
	
		
			
				|  |  | +            self.assertIn(uuid, backend._cache)
 | 
	
		
			
				|  |  | +            MockBinding.get_returns[0] = False
 | 
	
		
			
				|  |  | +            backend._cache[uuid] = "hello"
 | 
	
		
			
				|  |  | +            self.assertEqual(backend.poll(uuid), "hello")
 | 
	
		
			
				|  |  | +        finally:
 | 
	
		
			
				|  |  | +            conn = backend.pool.acquire(block=False)
 | 
	
		
			
				|  |  | +            conn.transport.__class__.channel_errors = channel_errors
 | 
	
		
			
				|  |  | +            conn.release()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def test_wait_for(self):
 | 
	
		
			
				|  |  | +        b = self.create_backend()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        uuid = gen_unique_id()
 | 
	
		
			
				|  |  | +        self.assertRaises(TimeoutError, b.wait_for, uuid, timeout=0.1)
 | 
	
		
			
				|  |  | +        b.store_result(uuid, None, states.STARTED)
 | 
	
		
			
				|  |  | +        self.assertRaises(TimeoutError, b.wait_for, uuid, timeout=0.1)
 | 
	
		
			
				|  |  | +        b.store_result(uuid, None, states.RETRY)
 | 
	
		
			
				|  |  | +        self.assertRaises(TimeoutError, b.wait_for, uuid, timeout=0.1)
 | 
	
		
			
				|  |  | +        b.store_result(uuid, 42, states.SUCCESS)
 | 
	
		
			
				|  |  | +        self.assertEqual(b.wait_for(uuid, timeout=1), 42)
 | 
	
		
			
				|  |  | +        b.store_result(uuid, 56, states.SUCCESS)
 | 
	
		
			
				|  |  | +        self.assertEqual(b.wait_for(uuid, timeout=1), 42,
 | 
	
		
			
				|  |  | +                         "result is cached")
 | 
	
		
			
				|  |  | +        self.assertEqual(b.wait_for(uuid, timeout=1, cache=False), 56)
 | 
	
		
			
				|  |  | +        b.store_result(uuid, KeyError("foo"), states.FAILURE)
 | 
	
		
			
				|  |  | +        self.assertRaises(KeyError, b.wait_for, uuid, timeout=1, cache=False)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def test_drain_events_remaining_timeouts(self):
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        class Connection(object):
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            def drain_events(self, timeout=None):
 | 
	
		
			
				|  |  | +                pass
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        b = self.create_backend()
 | 
	
		
			
				|  |  | +        conn = b.pool.acquire(block=False)
 | 
	
		
			
				|  |  | +        channel = conn.channel()
 | 
	
		
			
				|  |  | +        try:
 | 
	
		
			
				|  |  | +            binding = b._create_binding(gen_unique_id())
 | 
	
		
			
				|  |  | +            consumer = b._create_consumer(binding, channel)
 | 
	
		
			
				|  |  | +            self.assertRaises(socket.timeout, b.drain_events,
 | 
	
		
			
				|  |  | +                              Connection(), consumer, timeout=0.1)
 | 
	
		
			
				|  |  | +        finally:
 | 
	
		
			
				|  |  | +            channel.close()
 | 
	
		
			
				|  |  | +            conn.release()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def test_get_many(self):
 | 
	
		
			
				|  |  | +        b = self.create_backend()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        uuids = []
 | 
	
		
			
				|  |  | +        for i in xrange(10):
 | 
	
		
			
				|  |  | +            uuid = gen_unique_id()
 | 
	
		
			
				|  |  | +            b.store_result(uuid, i, states.SUCCESS)
 | 
	
		
			
				|  |  | +            uuids.append(uuid)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        res = list(b.get_many(uuids, timeout=1))
 | 
	
		
			
				|  |  | +        expected_results = [(uuid, {"status": states.SUCCESS,
 | 
	
		
			
				|  |  | +                                    "result": i,
 | 
	
		
			
				|  |  | +                                    "traceback": None,
 | 
	
		
			
				|  |  | +                                    "task_id": uuid})
 | 
	
		
			
				|  |  | +                                for i, uuid in enumerate(uuids)]
 | 
	
		
			
				|  |  | +        self.assertItemsEqual(res, expected_results)
 | 
	
		
			
				|  |  | +        self.assertDictEqual(b._cache[res[0][0]], res[0][1])
 | 
	
		
			
				|  |  | +        cached_res = list(b.get_many(uuids, timeout=1))
 | 
	
		
			
				|  |  | +        self.assertItemsEqual(cached_res, expected_results)
 | 
	
		
			
				|  |  | +        b._cache[res[0][0]]["status"] = states.RETRY
 | 
	
		
			
				|  |  | +        self.assertRaises(socket.timeout, list,
 | 
	
		
			
				|  |  | +                          b.get_many(uuids, timeout=0.01))
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def test_test_get_many_raises_outer_block(self):
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        class Backend(AMQPBackend):
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            def _create_consumer(self, *args, **kwargs):
 | 
	
		
			
				|  |  | +                raise KeyError("foo")
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        b = Backend()
 | 
	
		
			
				|  |  | +        self.assertRaises(KeyError, b.get_many(["id1"]).next)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def test_test_get_many_raises_inner_block(self):
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        class Backend(AMQPBackend):
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            def drain_events(self, *args, **kwargs):
 | 
	
		
			
				|  |  | +                raise KeyError("foo")
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        b = Backend()
 | 
	
		
			
				|  |  | +        self.assertRaises(KeyError, b.get_many(["id1"]).next)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def test_no_expires(self):
 | 
	
		
			
				|  |  | +        b = self.create_backend(expires=None)
 | 
	
		
			
				|  |  | +        app = app_or_default()
 | 
	
		
			
				|  |  | +        prev = app.conf.CELERY_AMQP_TASK_RESULT_EXPIRES
 | 
	
		
			
				|  |  | +        app.conf.CELERY_AMQP_TASK_RESULT_EXPIRES = None
 | 
	
		
			
				|  |  | +        try:
 | 
	
		
			
				|  |  | +            b = self.create_backend(expires=None)
 | 
	
		
			
				|  |  | +            self.assertRaises(KeyError, b.queue_arguments.__getitem__,
 | 
	
		
			
				|  |  | +                              "x-expires")
 | 
	
		
			
				|  |  | +        finally:
 | 
	
		
			
				|  |  | +            app.conf.CELERY_AMQP_TASK_RESULT_EXPIRES = prev
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def test_process_cleanup(self):
 | 
	
		
			
				|  |  | +        self.create_backend().process_cleanup()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def test_reload_task_result(self):
 | 
	
		
			
				|  |  | +        self.assertRaises(NotImplementedError,
 | 
	
		
			
				|  |  | +                          self.create_backend().reload_task_result, "x")
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def test_reload_taskset_result(self):
 | 
	
		
			
				|  |  | +        self.assertRaises(NotImplementedError,
 | 
	
		
			
				|  |  | +                          self.create_backend().reload_taskset_result, "x")
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def test_save_taskset(self):
 | 
	
		
			
				|  |  | +        self.assertRaises(NotImplementedError,
 | 
	
		
			
				|  |  | +                          self.create_backend().save_taskset, "x", "x")
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def test_restore_taskset(self):
 | 
	
		
			
				|  |  | +        self.assertRaises(NotImplementedError,
 | 
	
		
			
				|  |  | +                          self.create_backend().restore_taskset, "x")
 |