|  | @@ -14,7 +14,7 @@ from celery.datastructures import ExceptionInfo
 | 
	
		
			
				|  |  |  from celery.exceptions import TimeoutError
 | 
	
		
			
				|  |  |  from celery.utils import uuid
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -from celery.tests.utils import Case, sleepdeprived
 | 
	
		
			
				|  |  | +from celery.tests.utils import AppCase, sleepdeprived
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  class SomeClass(object):
 | 
	
	
		
			
				|  | @@ -23,7 +23,7 @@ class SomeClass(object):
 | 
	
		
			
				|  |  |          self.data = data
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -class test_AMQPBackend(Case):
 | 
	
		
			
				|  |  | +class test_AMQPBackend(AppCase):
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def create_backend(self, **opts):
 | 
	
		
			
				|  |  |          opts = dict(dict(serializer="pickle", persistent=False), **opts)
 | 
	
	
		
			
				|  | @@ -101,35 +101,35 @@ class test_AMQPBackend(Case):
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      @sleepdeprived()
 | 
	
		
			
				|  |  |      def test_store_result_retries(self):
 | 
	
		
			
				|  |  | +        iterations = [0]
 | 
	
		
			
				|  |  | +        stop_raising_at = [5]
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -        class _Producer(object):
 | 
	
		
			
				|  |  | -            iterations = 0
 | 
	
		
			
				|  |  | -            stop_raising_at = 5
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -            def __init__(self, *args, **kwargs):
 | 
	
		
			
				|  |  | -                pass
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -            def publish(self, msg, *args, **kwargs):
 | 
	
		
			
				|  |  | -                if self.iterations > self.stop_raising_at:
 | 
	
		
			
				|  |  | -                    return
 | 
	
		
			
				|  |  | -                raise KeyError("foo")
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        class Backend(AMQPBackend):
 | 
	
		
			
				|  |  | -            Producer = _Producer
 | 
	
		
			
				|  |  | +        def publish(*args, **kwargs):
 | 
	
		
			
				|  |  | +            if iterations[0] > stop_raising_at[0]:
 | 
	
		
			
				|  |  | +                return
 | 
	
		
			
				|  |  | +            iterations[0] += 1
 | 
	
		
			
				|  |  | +            raise KeyError("foo")
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -        backend = Backend()
 | 
	
		
			
				|  |  | -        with self.assertRaises(KeyError):
 | 
	
		
			
				|  |  | -            backend.store_result("foo", "bar", "STARTED", max_retries=None)
 | 
	
		
			
				|  |  | +        backend = AMQPBackend()
 | 
	
		
			
				|  |  | +        from celery.app.amqp import TaskProducer
 | 
	
		
			
				|  |  | +        prod, TaskProducer.publish = TaskProducer.publish, publish
 | 
	
		
			
				|  |  | +        try:
 | 
	
		
			
				|  |  | +            with self.assertRaises(KeyError):
 | 
	
		
			
				|  |  | +                backend.retry_policy["max_retries"] = None
 | 
	
		
			
				|  |  | +                backend.store_result("foo", "bar", "STARTED")
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -        with self.assertRaises(KeyError):
 | 
	
		
			
				|  |  | -            backend.store_result("foo", "bar", "STARTED", max_retries=10)
 | 
	
		
			
				|  |  | +            with self.assertRaises(KeyError):
 | 
	
		
			
				|  |  | +                backend.retry_policy["max_retries"] = 10
 | 
	
		
			
				|  |  | +                backend.store_result("foo", "bar", "STARTED")
 | 
	
		
			
				|  |  | +        finally:
 | 
	
		
			
				|  |  | +            TaskProducer.publish = prod
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def assertState(self, retval, state):
 | 
	
		
			
				|  |  |          self.assertEqual(retval["status"], state)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def test_poll_no_messages(self):
 | 
	
		
			
				|  |  |          b = self.create_backend()
 | 
	
		
			
				|  |  | -        self.assertState(b.poll(uuid()), states.PENDING)
 | 
	
		
			
				|  |  | +        self.assertState(b.get_task_meta(uuid()), states.PENDING)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def test_poll_result(self):
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -167,7 +167,7 @@ class test_AMQPBackend(Case):
 | 
	
		
			
				|  |  |          results.put(Message(status=states.RECEIVED, seq=1))
 | 
	
		
			
				|  |  |          results.put(Message(status=states.STARTED, seq=2))
 | 
	
		
			
				|  |  |          results.put(Message(status=states.FAILURE, seq=3))
 | 
	
		
			
				|  |  | -        r1 = backend.poll(uuid())
 | 
	
		
			
				|  |  | +        r1 = backend.get_task_meta(uuid())
 | 
	
		
			
				|  |  |          self.assertDictContainsSubset({"status": states.FAILURE,
 | 
	
		
			
				|  |  |                                         "seq": 3}, r1,
 | 
	
		
			
				|  |  |                                         "FFWDs to the last state")
 | 
	
	
		
			
				|  | @@ -175,14 +175,14 @@ class test_AMQPBackend(Case):
 | 
	
		
			
				|  |  |          # Caches last known state.
 | 
	
		
			
				|  |  |          results.put(Message())
 | 
	
		
			
				|  |  |          tid = uuid()
 | 
	
		
			
				|  |  | -        backend.poll(tid)
 | 
	
		
			
				|  |  | +        backend.get_task_meta(tid)
 | 
	
		
			
				|  |  |          self.assertIn(tid, backend._cache, "Caches last known state")
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          # Returns cache if no new states.
 | 
	
		
			
				|  |  |          results.queue.clear()
 | 
	
		
			
				|  |  |          assert not results.qsize()
 | 
	
		
			
				|  |  |          backend._cache[tid] = "hello"
 | 
	
		
			
				|  |  | -        self.assertEqual(backend.poll(tid), "hello",
 | 
	
		
			
				|  |  | +        self.assertEqual(backend.get_task_meta(tid), "hello",
 | 
	
		
			
				|  |  |                           "Returns cache if no new states")
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def test_wait_for(self):
 | 
	
	
		
			
				|  | @@ -217,7 +217,7 @@ class test_AMQPBackend(Case):
 | 
	
		
			
				|  |  |          b = self.create_backend()
 | 
	
		
			
				|  |  |          with current_app.pool.acquire_channel(block=False) as (_, channel):
 | 
	
		
			
				|  |  |              binding = b._create_binding(uuid())
 | 
	
		
			
				|  |  | -            consumer = b._create_consumer(binding, channel)
 | 
	
		
			
				|  |  | +            consumer = b.Consumer(channel, binding, no_ack=True)
 | 
	
		
			
				|  |  |              with self.assertRaises(socket.timeout):
 | 
	
		
			
				|  |  |                  b.drain_events(Connection(), consumer, timeout=0.1)
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -249,7 +249,7 @@ class test_AMQPBackend(Case):
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          class Backend(AMQPBackend):
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -            def _create_consumer(self, *args, **kwargs):
 | 
	
		
			
				|  |  | +            def Consumer(*args, **kwargs):
 | 
	
		
			
				|  |  |                  raise KeyError("foo")
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          b = Backend()
 |