|
@@ -1,6 +1,7 @@
|
|
|
from datetime import datetime, timedelta
|
|
|
from functools import wraps
|
|
|
|
|
|
+from mock import Mock
|
|
|
from pyparsing import ParseException
|
|
|
|
|
|
from celery import task
|
|
@@ -60,15 +61,15 @@ class RetryTask(task.Task):
|
|
|
max_retries = 3
|
|
|
iterations = 0
|
|
|
|
|
|
- def run(self, arg1, arg2, kwarg=1, **kwargs):
|
|
|
+ def run(self, arg1, arg2, kwarg=1, max_retries=None, care=True):
|
|
|
self.__class__.iterations += 1
|
|
|
+ rmax = self.max_retries if max_retries is None else max_retries
|
|
|
|
|
|
- retries = kwargs["task_retries"]
|
|
|
- if retries >= 3:
|
|
|
+ retries = self.request.retries
|
|
|
+ if care and retries >= rmax:
|
|
|
return arg1
|
|
|
else:
|
|
|
- kwargs.update({"kwarg": kwarg})
|
|
|
- return self.retry(args=[arg1, arg2], kwargs=kwargs, countdown=0)
|
|
|
+ return self.retry(countdown=0, max_retries=max_retries)
|
|
|
|
|
|
|
|
|
class RetryTaskNoArgs(task.Task):
|
|
@@ -137,6 +138,12 @@ class TestTaskRetries(unittest.TestCase):
|
|
|
self.assertEqual(result.get(), 0xFF)
|
|
|
self.assertEqual(RetryTask.iterations, 4)
|
|
|
|
|
|
+ RetryTask.max_retries = 3
|
|
|
+ RetryTask.iterations = 0
|
|
|
+ result = RetryTask.apply([0xFF, 0xFFFF], {"max_retries": 10})
|
|
|
+ self.assertEqual(result.get(), 0xFF)
|
|
|
+ self.assertEqual(RetryTask.iterations, 11)
|
|
|
+
|
|
|
def test_retry_no_args(self):
|
|
|
RetryTaskNoArgs.max_retries = 3
|
|
|
RetryTaskNoArgs.iterations = 0
|
|
@@ -183,14 +190,14 @@ class TestTaskRetries(unittest.TestCase):
|
|
|
def test_max_retries_exceeded(self):
|
|
|
RetryTask.max_retries = 2
|
|
|
RetryTask.iterations = 0
|
|
|
- result = RetryTask.apply([0xFF, 0xFFFF])
|
|
|
+ result = RetryTask.apply([0xFF, 0xFFFF], {"care": False})
|
|
|
self.assertRaises(RetryTask.MaxRetriesExceededError,
|
|
|
result.get)
|
|
|
self.assertEqual(RetryTask.iterations, 3)
|
|
|
|
|
|
RetryTask.max_retries = 1
|
|
|
RetryTask.iterations = 0
|
|
|
- result = RetryTask.apply([0xFF, 0xFFFF])
|
|
|
+ result = RetryTask.apply([0xFF, 0xFFFF], {"care": False})
|
|
|
self.assertRaises(RetryTask.MaxRetriesExceededError,
|
|
|
result.get)
|
|
|
self.assertEqual(RetryTask.iterations, 2)
|
|
@@ -318,6 +325,23 @@ class TestCeleryTasks(unittest.TestCase):
|
|
|
publisher = t1.get_publisher()
|
|
|
self.assertTrue(publisher.exchange)
|
|
|
|
|
|
+ def test_context_get(self):
|
|
|
+ request = self.createTaskCls("T1", "c.unittest.t.c.g").request
|
|
|
+ request.foo = 32
|
|
|
+ self.assertEqual(request.get("foo"), 32)
|
|
|
+ self.assertEqual(request.get("bar", 36), 36)
|
|
|
+
|
|
|
+ def test_repr(self):
|
|
|
+ task = self.createTaskCls("T1", "c.unittest.t.repr")
|
|
|
+ self.assertIn("class Task of", repr(task.app.Task))
|
|
|
+
|
|
|
+ def test_after_return(self):
|
|
|
+ task = self.createTaskCls("T1", "c.unittest.t.after_return")()
|
|
|
+ task.backend = Mock()
|
|
|
+ task.request.chord = 123
|
|
|
+ task.after_return("SUCCESS", 1.0, "foobar", (), {}, None)
|
|
|
+ task.backend.on_chord_part_return.assert_called_with(task)
|
|
|
+
|
|
|
def test_send_task_sent_event(self):
|
|
|
T1 = self.createTaskCls("T1", "c.unittest.t.t1")
|
|
|
conn = T1.app.broker_connection()
|