|
@@ -8,6 +8,7 @@ from celery.result import EagerResult
|
|
|
from celery.backends import default_backend
|
|
|
from celery.decorators import task as task_dec
|
|
|
from celery.worker.listener import parse_iso8601
|
|
|
+from celery.exceptions import RetryTaskError
|
|
|
|
|
|
def return_True(*args, **kwargs):
|
|
|
# Task run functions can't be closures/lambdas, as they're pickled.
|
|
@@ -21,6 +22,16 @@ def raise_exception(self, **kwargs):
|
|
|
raise Exception("%s error" % self.__class__)
|
|
|
|
|
|
|
|
|
+class MockApplyTask(task.Task):
|
|
|
+
|
|
|
+ def run(self, x, y):
|
|
|
+ return x * y
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def apply_async(self, *args, **kwargs):
|
|
|
+ pass
|
|
|
+
|
|
|
+
|
|
|
class IncrementCounterTask(task.Task):
|
|
|
name = "c.unittest.increment_counter_task"
|
|
|
count = 0
|
|
@@ -53,6 +64,27 @@ class RetryTask(task.Task):
|
|
|
return self.retry(args=[arg1, arg2], kwargs=kwargs, countdown=0)
|
|
|
|
|
|
|
|
|
+class RetryTaskMockApply(task.Task):
|
|
|
+ max_retries = 3
|
|
|
+ iterations = 0
|
|
|
+ applied = 0
|
|
|
+
|
|
|
+ def run(self, arg1, arg2, kwarg=1, **kwargs):
|
|
|
+ self.__class__.iterations += 1
|
|
|
+
|
|
|
+ retries = kwargs["task_retries"]
|
|
|
+ if retries >= 3:
|
|
|
+ return arg1
|
|
|
+ else:
|
|
|
+ kwargs.update({"kwarg": kwarg})
|
|
|
+ return self.retry(args=[arg1, arg2], kwargs=kwargs, countdown=0)
|
|
|
+
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def apply_async(self, *args, **kwargs):
|
|
|
+ self.applied = 1
|
|
|
+
|
|
|
+
|
|
|
class MyCustomException(Exception):
|
|
|
"""Random custom exception."""
|
|
|
|
|
@@ -85,6 +117,23 @@ class TestTaskRetries(unittest.TestCase):
|
|
|
self.assertEquals(result.get(), 0xFF)
|
|
|
self.assertEquals(RetryTask.iterations, 4)
|
|
|
|
|
|
+ def test_retry_not_eager(self):
|
|
|
+ exc = Exception("baz")
|
|
|
+ try:
|
|
|
+ RetryTaskMockApply.retry(args=[4, 4], kwargs={},
|
|
|
+ exc=exc, throw=False)
|
|
|
+ self.assertTrue(RetryTaskMockApply.applied)
|
|
|
+ finally:
|
|
|
+ RetryTaskMockApply.applied = 0
|
|
|
+
|
|
|
+ try:
|
|
|
+ self.assertRaises(RetryTaskError, RetryTaskMockApply.retry,
|
|
|
+ args=[4, 4], kwargs={}, exc=exc, throw=True)
|
|
|
+ self.assertTrue(RetryTaskMockApply.applied)
|
|
|
+ finally:
|
|
|
+ RetryTaskMockApply.applied = 0
|
|
|
+
|
|
|
+
|
|
|
def test_retry_with_kwargs(self):
|
|
|
RetryTaskCustomExc.max_retries = 3
|
|
|
RetryTaskCustomExc.iterations = 0
|
|
@@ -116,6 +165,12 @@ class TestTaskRetries(unittest.TestCase):
|
|
|
self.assertEquals(RetryTask.iterations, 2)
|
|
|
|
|
|
|
|
|
+class MockPublisher(object):
|
|
|
+
|
|
|
+ def __init__(self, *args, **kwargs):
|
|
|
+ self.kwargs = kwargs
|
|
|
+
|
|
|
+
|
|
|
class TestCeleryTasks(unittest.TestCase):
|
|
|
|
|
|
def createTaskCls(self, cls_name, task_name=None):
|
|
@@ -232,6 +287,18 @@ class TestCeleryTasks(unittest.TestCase):
|
|
|
publisher = t1.get_publisher()
|
|
|
self.assertTrue(isinstance(publisher, messaging.TaskPublisher))
|
|
|
|
|
|
+ def test_get_publisher(self):
|
|
|
+ from celery.task import base
|
|
|
+ old_pub = base.TaskPublisher
|
|
|
+ base.TaskPublisher = MockPublisher
|
|
|
+ try:
|
|
|
+ p = IncrementCounterTask.get_publisher(exchange="foo",
|
|
|
+ connection="bar")
|
|
|
+ self.assertEquals(p.kwargs["exchange"], "foo")
|
|
|
+ finally:
|
|
|
+ base.TaskPublisher = old_pub
|
|
|
+
|
|
|
+
|
|
|
def test_get_logger(self):
|
|
|
T1 = self.createTaskCls("T1", "c.unittest.t.t1")
|
|
|
t1 = T1()
|
|
@@ -308,3 +375,45 @@ class TestTaskApply(unittest.TestCase):
|
|
|
self.assertFalse(f.successful())
|
|
|
self.assertTrue(f.traceback)
|
|
|
self.assertRaises(KeyError, f.get)
|
|
|
+
|
|
|
+
|
|
|
+class MyPeriodic(task.PeriodicTask):
|
|
|
+ run_every = timedelta(hours=1)
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+class TestPeriodicTask(unittest.TestCase):
|
|
|
+
|
|
|
+ def test_must_have_run_every(self):
|
|
|
+ self.assertRaises(NotImplementedError, type, "Foo",
|
|
|
+ (task.PeriodicTask, ), {"__module__": __name__})
|
|
|
+
|
|
|
+ def test_remaining_estimate(self):
|
|
|
+ self.assertTrue(isinstance(
|
|
|
+ MyPeriodic().remaining_estimate(datetime.now()),
|
|
|
+ timedelta))
|
|
|
+
|
|
|
+ def test_timedelta_seconds_returns_0_on_negative_time(self):
|
|
|
+ delta = timedelta(days=-2)
|
|
|
+ self.assertEquals(MyPeriodic().timedelta_seconds(delta), 0)
|
|
|
+
|
|
|
+ def test_timedelta_seconds(self):
|
|
|
+ deltamap = ((timedelta(seconds=1), 1),
|
|
|
+ (timedelta(seconds=27), 27),
|
|
|
+ (timedelta(minutes=3), 3 * 60),
|
|
|
+ (timedelta(hours=4), 4 * 60 * 60),
|
|
|
+ (timedelta(days=3), 3 * 86400))
|
|
|
+ for delta, seconds in deltamap:
|
|
|
+ self.assertEquals(MyPeriodic().timedelta_seconds(delta), seconds)
|
|
|
+
|
|
|
+ def test_is_due_not_due(self):
|
|
|
+ due, remaining = MyPeriodic().is_due(datetime.now())
|
|
|
+ self.assertFalse(due)
|
|
|
+ self.assertTrue(remaining > 60)
|
|
|
+
|
|
|
+ def test_is_due(self):
|
|
|
+ p = MyPeriodic()
|
|
|
+ due, remaining = p.is_due(datetime.now() - p.run_every)
|
|
|
+ self.assertTrue(due)
|
|
|
+ self.assertEquals(remaining, p.timedelta_seconds(p.run_every))
|
|
|
+
|