|  | @@ -8,6 +8,7 @@ from celery.result import EagerResult
 | 
	
		
			
				|  |  |  from celery.backends import default_backend
 | 
	
		
			
				|  |  |  from celery.decorators import task as task_dec
 | 
	
		
			
				|  |  |  from celery.worker.listener import parse_iso8601
 | 
	
		
			
				|  |  | +from celery.exceptions import RetryTaskError
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  def return_True(*args, **kwargs):
 | 
	
		
			
				|  |  |      # Task run functions can't be closures/lambdas, as they're pickled.
 | 
	
	
		
			
				|  | @@ -21,6 +22,16 @@ def raise_exception(self, **kwargs):
 | 
	
		
			
				|  |  |      raise Exception("%s error" % self.__class__)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +class MockApplyTask(task.Task):
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def run(self, x, y):
 | 
	
		
			
				|  |  | +        return x * y
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    @classmethod
 | 
	
		
			
				|  |  | +    def apply_async(self, *args, **kwargs):
 | 
	
		
			
				|  |  | +        pass
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  class IncrementCounterTask(task.Task):
 | 
	
		
			
				|  |  |      name = "c.unittest.increment_counter_task"
 | 
	
		
			
				|  |  |      count = 0
 | 
	
	
		
			
				|  | @@ -53,6 +64,27 @@ class RetryTask(task.Task):
 | 
	
		
			
				|  |  |              return self.retry(args=[arg1, arg2], kwargs=kwargs, countdown=0)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +class RetryTaskMockApply(task.Task):
 | 
	
		
			
				|  |  | +    max_retries = 3
 | 
	
		
			
				|  |  | +    iterations = 0
 | 
	
		
			
				|  |  | +    applied = 0
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def run(self, arg1, arg2, kwarg=1, **kwargs):
 | 
	
		
			
				|  |  | +        self.__class__.iterations += 1
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        retries = kwargs["task_retries"]
 | 
	
		
			
				|  |  | +        if retries >= 3:
 | 
	
		
			
				|  |  | +            return arg1
 | 
	
		
			
				|  |  | +        else:
 | 
	
		
			
				|  |  | +            kwargs.update({"kwarg": kwarg})
 | 
	
		
			
				|  |  | +            return self.retry(args=[arg1, arg2], kwargs=kwargs, countdown=0)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    @classmethod
 | 
	
		
			
				|  |  | +    def apply_async(self, *args, **kwargs):
 | 
	
		
			
				|  |  | +        self.applied = 1
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  class MyCustomException(Exception):
 | 
	
		
			
				|  |  |      """Random custom exception."""
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -85,6 +117,23 @@ class TestTaskRetries(unittest.TestCase):
 | 
	
		
			
				|  |  |          self.assertEquals(result.get(), 0xFF)
 | 
	
		
			
				|  |  |          self.assertEquals(RetryTask.iterations, 4)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +    def test_retry_not_eager(self):
 | 
	
		
			
				|  |  | +        exc = Exception("baz")
 | 
	
		
			
				|  |  | +        try:
 | 
	
		
			
				|  |  | +            RetryTaskMockApply.retry(args=[4, 4], kwargs={},
 | 
	
		
			
				|  |  | +                                     exc=exc, throw=False)
 | 
	
		
			
				|  |  | +            self.assertTrue(RetryTaskMockApply.applied)
 | 
	
		
			
				|  |  | +        finally:
 | 
	
		
			
				|  |  | +            RetryTaskMockApply.applied = 0
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        try:
 | 
	
		
			
				|  |  | +            self.assertRaises(RetryTaskError, RetryTaskMockApply.retry,
 | 
	
		
			
				|  |  | +                    args=[4, 4], kwargs={}, exc=exc, throw=True)
 | 
	
		
			
				|  |  | +            self.assertTrue(RetryTaskMockApply.applied)
 | 
	
		
			
				|  |  | +        finally:
 | 
	
		
			
				|  |  | +            RetryTaskMockApply.applied = 0
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |      def test_retry_with_kwargs(self):
 | 
	
		
			
				|  |  |          RetryTaskCustomExc.max_retries = 3
 | 
	
		
			
				|  |  |          RetryTaskCustomExc.iterations = 0
 | 
	
	
		
			
				|  | @@ -116,6 +165,12 @@ class TestTaskRetries(unittest.TestCase):
 | 
	
		
			
				|  |  |          self.assertEquals(RetryTask.iterations, 2)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +class MockPublisher(object):
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def __init__(self, *args, **kwargs):
 | 
	
		
			
				|  |  | +        self.kwargs = kwargs
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  class TestCeleryTasks(unittest.TestCase):
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def createTaskCls(self, cls_name, task_name=None):
 | 
	
	
		
			
				|  | @@ -232,6 +287,18 @@ class TestCeleryTasks(unittest.TestCase):
 | 
	
		
			
				|  |  |          publisher = t1.get_publisher()
 | 
	
		
			
				|  |  |          self.assertTrue(isinstance(publisher, messaging.TaskPublisher))
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +    def test_get_publisher(self):
 | 
	
		
			
				|  |  | +        from celery.task import base
 | 
	
		
			
				|  |  | +        old_pub = base.TaskPublisher
 | 
	
		
			
				|  |  | +        base.TaskPublisher = MockPublisher
 | 
	
		
			
				|  |  | +        try:
 | 
	
		
			
				|  |  | +            p = IncrementCounterTask.get_publisher(exchange="foo",
 | 
	
		
			
				|  |  | +                                                   connection="bar")
 | 
	
		
			
				|  |  | +            self.assertEquals(p.kwargs["exchange"], "foo")
 | 
	
		
			
				|  |  | +        finally:
 | 
	
		
			
				|  |  | +            base.TaskPublisher = old_pub
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |      def test_get_logger(self):
 | 
	
		
			
				|  |  |          T1 = self.createTaskCls("T1", "c.unittest.t.t1")
 | 
	
		
			
				|  |  |          t1 = T1()
 | 
	
	
		
			
				|  | @@ -308,3 +375,45 @@ class TestTaskApply(unittest.TestCase):
 | 
	
		
			
				|  |  |          self.assertFalse(f.successful())
 | 
	
		
			
				|  |  |          self.assertTrue(f.traceback)
 | 
	
		
			
				|  |  |          self.assertRaises(KeyError, f.get)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +class MyPeriodic(task.PeriodicTask):
 | 
	
		
			
				|  |  | +    run_every = timedelta(hours=1)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +class TestPeriodicTask(unittest.TestCase):
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def test_must_have_run_every(self):
 | 
	
		
			
				|  |  | +        self.assertRaises(NotImplementedError, type, "Foo",
 | 
	
		
			
				|  |  | +            (task.PeriodicTask, ), {"__module__": __name__})
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def test_remaining_estimate(self):
 | 
	
		
			
				|  |  | +        self.assertTrue(isinstance(
 | 
	
		
			
				|  |  | +            MyPeriodic().remaining_estimate(datetime.now()),
 | 
	
		
			
				|  |  | +            timedelta))
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def test_timedelta_seconds_returns_0_on_negative_time(self):
 | 
	
		
			
				|  |  | +        delta = timedelta(days=-2)
 | 
	
		
			
				|  |  | +        self.assertEquals(MyPeriodic().timedelta_seconds(delta), 0)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def test_timedelta_seconds(self):
 | 
	
		
			
				|  |  | +        deltamap = ((timedelta(seconds=1), 1),
 | 
	
		
			
				|  |  | +                    (timedelta(seconds=27), 27),
 | 
	
		
			
				|  |  | +                    (timedelta(minutes=3), 3 * 60),
 | 
	
		
			
				|  |  | +                    (timedelta(hours=4), 4 * 60 * 60),
 | 
	
		
			
				|  |  | +                    (timedelta(days=3), 3 * 86400))
 | 
	
		
			
				|  |  | +        for delta, seconds in deltamap:
 | 
	
		
			
				|  |  | +            self.assertEquals(MyPeriodic().timedelta_seconds(delta), seconds)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def test_is_due_not_due(self):
 | 
	
		
			
				|  |  | +        due, remaining = MyPeriodic().is_due(datetime.now())
 | 
	
		
			
				|  |  | +        self.assertFalse(due)
 | 
	
		
			
				|  |  | +        self.assertTrue(remaining > 60)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def test_is_due(self):
 | 
	
		
			
				|  |  | +        p = MyPeriodic()
 | 
	
		
			
				|  |  | +        due, remaining = p.is_due(datetime.now() - p.run_every)
 | 
	
		
			
				|  |  | +        self.assertTrue(due)
 | 
	
		
			
				|  |  | +        self.assertEquals(remaining, p.timedelta_seconds(p.run_every))
 | 
	
		
			
				|  |  | +
 |