Parcourir la source

100% coverage for celery.app.task

Ask Solem il y a 14 ans
Parent
commit
53b3240ed0
2 fichiers modifiés avec 35 ajouts et 10 suppressions
  1. 4 3
      celery/app/task/__init__.py
  2. 31 7
      celery/tests/test_task/test_task.py

+ 4 - 3
celery/app/task/__init__.py

@@ -36,9 +36,10 @@ class Context(threading.local):
         self.__dict__.clear()
 
     def get(self, key, default=None):
-        if not hasattr(self, key):
+        try:
+            return getattr(self, key)
+        except AttributeError:
             return default
-        return getattr(self, key)
 
 
 class TaskType(type):
@@ -65,7 +66,7 @@ class TaskType(type):
         if not attrs.get("name"):
             try:
                 module_name = sys.modules[task_module].__name__
-            except KeyError:
+            except KeyError:  # pragma: no cover
                 # Fix for manage.py shell_plus (Issue #366).
                 module_name = task_module
             attrs["name"] = '.'.join([module_name, name])

+ 31 - 7
celery/tests/test_task/test_task.py

@@ -1,6 +1,7 @@
 from datetime import datetime, timedelta
 from functools import wraps
 
+from mock import Mock
 from pyparsing import ParseException
 
 from celery import task
@@ -60,15 +61,15 @@ class RetryTask(task.Task):
     max_retries = 3
     iterations = 0
 
-    def run(self, arg1, arg2, kwarg=1, **kwargs):
+    def run(self, arg1, arg2, kwarg=1, max_retries=None, care=True):
         self.__class__.iterations += 1
+        rmax = self.max_retries if max_retries is None else max_retries
 
-        retries = kwargs["task_retries"]
-        if retries >= 3:
+        retries = self.request.retries
+        if care and retries >= rmax:
             return arg1
         else:
-            kwargs.update({"kwarg": kwarg})
-            return self.retry(args=[arg1, arg2], kwargs=kwargs, countdown=0)
+            return self.retry(countdown=0, max_retries=max_retries)
 
 
 class RetryTaskNoArgs(task.Task):
@@ -137,6 +138,12 @@ class TestTaskRetries(unittest.TestCase):
         self.assertEqual(result.get(), 0xFF)
         self.assertEqual(RetryTask.iterations, 4)
 
+        RetryTask.max_retries = 3
+        RetryTask.iterations = 0
+        result = RetryTask.apply([0xFF, 0xFFFF], {"max_retries": 10})
+        self.assertEqual(result.get(), 0xFF)
+        self.assertEqual(RetryTask.iterations, 11)
+
     def test_retry_no_args(self):
         RetryTaskNoArgs.max_retries = 3
         RetryTaskNoArgs.iterations = 0
@@ -183,14 +190,14 @@ class TestTaskRetries(unittest.TestCase):
     def test_max_retries_exceeded(self):
         RetryTask.max_retries = 2
         RetryTask.iterations = 0
-        result = RetryTask.apply([0xFF, 0xFFFF])
+        result = RetryTask.apply([0xFF, 0xFFFF], {"care": False})
         self.assertRaises(RetryTask.MaxRetriesExceededError,
                           result.get)
         self.assertEqual(RetryTask.iterations, 3)
 
         RetryTask.max_retries = 1
         RetryTask.iterations = 0
-        result = RetryTask.apply([0xFF, 0xFFFF])
+        result = RetryTask.apply([0xFF, 0xFFFF], {"care": False})
         self.assertRaises(RetryTask.MaxRetriesExceededError,
                           result.get)
         self.assertEqual(RetryTask.iterations, 2)
@@ -318,6 +325,23 @@ class TestCeleryTasks(unittest.TestCase):
         publisher = t1.get_publisher()
         self.assertTrue(publisher.exchange)
 
+    def test_context_get(self):
+        request = self.createTaskCls("T1", "c.unittest.t.c.g").request
+        request.foo = 32
+        self.assertEqual(request.get("foo"), 32)
+        self.assertEqual(request.get("bar", 36), 36)
+
+    def test_repr(self):
+        task = self.createTaskCls("T1", "c.unittest.t.repr")
+        self.assertIn("class Task of", repr(task.app.Task))
+
+    def test_after_return(self):
+        task = self.createTaskCls("T1", "c.unittest.t.after_return")()
+        task.backend = Mock()
+        task.request.chord = 123
+        task.after_return("SUCCESS", 1.0, "foobar", (), {}, None)
+        task.backend.on_chord_part_return.assert_called_with(task)
+
     def test_send_task_sent_event(self):
         T1 = self.createTaskCls("T1", "c.unittest.t.t1")
         conn = T1.app.broker_connection()