|  | @@ -1,6 +1,7 @@
 | 
	
		
			
				|  |  |  from datetime import datetime, timedelta
 | 
	
		
			
				|  |  |  from functools import wraps
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +from mock import Mock
 | 
	
		
			
				|  |  |  from pyparsing import ParseException
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  from celery import task
 | 
	
	
		
			
				|  | @@ -60,15 +61,15 @@ class RetryTask(task.Task):
 | 
	
		
			
				|  |  |      max_retries = 3
 | 
	
		
			
				|  |  |      iterations = 0
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    def run(self, arg1, arg2, kwarg=1, **kwargs):
 | 
	
		
			
				|  |  | +    def run(self, arg1, arg2, kwarg=1, max_retries=None, care=True):
 | 
	
		
			
				|  |  |          self.__class__.iterations += 1
 | 
	
		
			
				|  |  | +        rmax = self.max_retries if max_retries is None else max_retries
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -        retries = kwargs["task_retries"]
 | 
	
		
			
				|  |  | -        if retries >= 3:
 | 
	
		
			
				|  |  | +        retries = self.request.retries
 | 
	
		
			
				|  |  | +        if care and retries >= rmax:
 | 
	
		
			
				|  |  |              return arg1
 | 
	
		
			
				|  |  |          else:
 | 
	
		
			
				|  |  | -            kwargs.update({"kwarg": kwarg})
 | 
	
		
			
				|  |  | -            return self.retry(args=[arg1, arg2], kwargs=kwargs, countdown=0)
 | 
	
		
			
				|  |  | +            return self.retry(countdown=0, max_retries=max_retries)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  class RetryTaskNoArgs(task.Task):
 | 
	
	
		
			
				|  | @@ -137,6 +138,12 @@ class TestTaskRetries(unittest.TestCase):
 | 
	
		
			
				|  |  |          self.assertEqual(result.get(), 0xFF)
 | 
	
		
			
				|  |  |          self.assertEqual(RetryTask.iterations, 4)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +        RetryTask.max_retries = 3
 | 
	
		
			
				|  |  | +        RetryTask.iterations = 0
 | 
	
		
			
				|  |  | +        result = RetryTask.apply([0xFF, 0xFFFF], {"max_retries": 10})
 | 
	
		
			
				|  |  | +        self.assertEqual(result.get(), 0xFF)
 | 
	
		
			
				|  |  | +        self.assertEqual(RetryTask.iterations, 11)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |      def test_retry_no_args(self):
 | 
	
		
			
				|  |  |          RetryTaskNoArgs.max_retries = 3
 | 
	
		
			
				|  |  |          RetryTaskNoArgs.iterations = 0
 | 
	
	
		
			
				|  | @@ -183,14 +190,14 @@ class TestTaskRetries(unittest.TestCase):
 | 
	
		
			
				|  |  |      def test_max_retries_exceeded(self):
 | 
	
		
			
				|  |  |          RetryTask.max_retries = 2
 | 
	
		
			
				|  |  |          RetryTask.iterations = 0
 | 
	
		
			
				|  |  | -        result = RetryTask.apply([0xFF, 0xFFFF])
 | 
	
		
			
				|  |  | +        result = RetryTask.apply([0xFF, 0xFFFF], {"care": False})
 | 
	
		
			
				|  |  |          self.assertRaises(RetryTask.MaxRetriesExceededError,
 | 
	
		
			
				|  |  |                            result.get)
 | 
	
		
			
				|  |  |          self.assertEqual(RetryTask.iterations, 3)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          RetryTask.max_retries = 1
 | 
	
		
			
				|  |  |          RetryTask.iterations = 0
 | 
	
		
			
				|  |  | -        result = RetryTask.apply([0xFF, 0xFFFF])
 | 
	
		
			
				|  |  | +        result = RetryTask.apply([0xFF, 0xFFFF], {"care": False})
 | 
	
		
			
				|  |  |          self.assertRaises(RetryTask.MaxRetriesExceededError,
 | 
	
		
			
				|  |  |                            result.get)
 | 
	
		
			
				|  |  |          self.assertEqual(RetryTask.iterations, 2)
 | 
	
	
		
			
				|  | @@ -318,6 +325,23 @@ class TestCeleryTasks(unittest.TestCase):
 | 
	
		
			
				|  |  |          publisher = t1.get_publisher()
 | 
	
		
			
				|  |  |          self.assertTrue(publisher.exchange)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +    def test_context_get(self):
 | 
	
		
			
				|  |  | +        request = self.createTaskCls("T1", "c.unittest.t.c.g").request
 | 
	
		
			
				|  |  | +        request.foo = 32
 | 
	
		
			
				|  |  | +        self.assertEqual(request.get("foo"), 32)
 | 
	
		
			
				|  |  | +        self.assertEqual(request.get("bar", 36), 36)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def test_repr(self):
 | 
	
		
			
				|  |  | +        task = self.createTaskCls("T1", "c.unittest.t.repr")
 | 
	
		
			
				|  |  | +        self.assertIn("class Task of", repr(task.app.Task))
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def test_after_return(self):
 | 
	
		
			
				|  |  | +        task = self.createTaskCls("T1", "c.unittest.t.after_return")()
 | 
	
		
			
				|  |  | +        task.backend = Mock()
 | 
	
		
			
				|  |  | +        task.request.chord = 123
 | 
	
		
			
				|  |  | +        task.after_return("SUCCESS", 1.0, "foobar", (), {}, None)
 | 
	
		
			
				|  |  | +        task.backend.on_chord_part_return.assert_called_with(task)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |      def test_send_task_sent_event(self):
 | 
	
		
			
				|  |  |          T1 = self.createTaskCls("T1", "c.unittest.t.t1")
 | 
	
		
			
				|  |  |          conn = T1.app.broker_connection()
 |