소스 검색

Task retries seems to be working with tests passing. Closes #7

Ask Solem 15 년 전
부모
커밋
41a38bb25f
5개의 변경된 파일64개의 추가작업 그리고 5개의 파일을 삭제
  1. 10 1
      celery/execute.py
  2. 11 2
      celery/task/base.py
  3. 40 0
      celery/tests/test_task.py
  4. 1 0
      celery/tests/test_worker_job.py
  5. 2 2
      celery/worker/job.py

+ 10 - 1
celery/execute.py

@@ -6,6 +6,7 @@ from celery.registry import tasks
 from celery.utils import gen_unique_id
 from functools import partial as curry
 from datetime import datetime, timedelta
+from multiprocessing import get_logger
 import inspect
 
 
@@ -127,7 +128,7 @@ def delay_task(task_name, *args, **kwargs):
     return apply_async(task, args, kwargs)
 
 
-def apply(task, args, kwargs, **ignored):
+def apply(task, args, kwargs, **options):
     """Apply the task locally.
 
     This will block until the task completes, and returns a
@@ -137,11 +138,19 @@ def apply(task, args, kwargs, **ignored):
     args = args or []
     kwargs = kwargs or {}
     task_id = gen_unique_id()
+    retries = options.get("retries", 0)
 
     # If it's a Task class we need to have to instance
     # for it to be callable.
     task = inspect.isclass(task) and task() or task
 
+    kwargs.update({"task_name": task.name,
+                   "task_id": task_id,
+                   "task_retries": retries,
+                   "task_is_eager": True,
+                   "logfile": None,
+                   "loglevel": 0})
+
     try:
         ret_value = task(*args, **kwargs)
         status = "DONE"

+ 11 - 2
celery/task/base.py

@@ -2,7 +2,7 @@ from carrot.connection import DjangoBrokerConnection
 from celery.conf import AMQP_CONNECTION_TIMEOUT
 from celery.messaging import TaskPublisher, TaskConsumer
 from celery.log import setup_logger
-from celery.result import TaskSetResult
+from celery.result import TaskSetResult, EagerResult
 from celery.execute import apply_async, delay_task, apply
 from celery.utils import gen_unique_id, get_full_cls_name
 from datetime import timedelta
@@ -258,11 +258,20 @@ class Task(object):
         options["task_id"] = kwargs.pop("task_id", None)
         options["countdown"] = options.get("countdown",
                                            self.default_retry_delay)
-        exc = exc or MaxRetriesExceededError(
+        exc = exc or self.MaxRetriesExceededError(
                 "Can't retry %s[%s] args:%s kwargs:%s" % (
                     self.name, options["task_id"], args, kwargs))
         if options["retries"] > self.max_retries:
             raise exc
+
+        # If task was executed eagerly using apply(),
+        # then the retry must also be executed eagerly.
+        if kwargs.get("task_is_eager", False):
+            result = self.apply(args=args, kwargs=kwargs, **options)
+            if isinstance(result, EagerResult):
+                return result.get()
+            return result
+
         return self.apply_async(args=args, kwargs=kwargs, **options)
         
     @classmethod

+ 40 - 0
celery/tests/test_task.py

@@ -39,6 +39,46 @@ class RaisingTask(task.Task):
         raise KeyError("foo")
 
 
+class RetryTask(task.Task):
+    max_retries = 3
+    iterations = 0
+
+    def run(self, arg1, arg2, kwarg=1, **kwargs):
+        self.__class__.iterations += 1 
+
+        retries = kwargs["task_retries"]
+        if retries >= 3:
+            return arg1
+        else:
+            kwargs.update({"kwarg": kwargs})
+            return self.retry(args=[arg1, arg2], kwargs=kwargs, countdown=0)
+
+
+class TestTaskRetries(unittest.TestCase):
+
+    def test_retry(self):
+        RetryTask.max_retries = 3
+        RetryTask.iterations = 0
+        result = RetryTask.apply([0xFF, 0xFFFF])
+        self.assertEquals(result.get(), 0xFF)
+        self.assertEquals
+
+    def test_max_retries_exceeded(self):
+        RetryTask.max_retries = 2
+        RetryTask.iterations = 0
+        result = RetryTask.apply([0xFF, 0xFFFF])
+        self.assertRaises(RetryTask.MaxRetriesExceededError, 
+                          result.get)
+        self.assertEquals(RetryTask.iterations, 3)
+
+        RetryTask.max_retries = 1
+        RetryTask.iterations = 0
+        result = RetryTask.apply([0xFF, 0xFFFF])
+        self.assertRaises(RetryTask.MaxRetriesExceededError, 
+                          result.get)
+        self.assertEquals(RetryTask.iterations, 2)
+        
+
 class TestCeleryTasks(unittest.TestCase):
 
     def createTaskCls(self, cls_name, task_name=None):

+ 1 - 0
celery/tests/test_worker_job.py

@@ -210,6 +210,7 @@ class TestTaskWrapper(unittest.TestCase):
             "logfile": "some_logfile",
             "loglevel": 10,
             "task_id": tw.task_id,
+            "task_retries": 0,
             "task_name": tw.task_name})
 
     def test_on_failure(self):

+ 2 - 2
celery/worker/job.py

@@ -182,7 +182,7 @@ class TaskWrapper(object):
         task_id = message_data["id"]
         args = message_data["args"]
         kwargs = message_data["kwargs"]
-        retries = message_data["retries"]
+        retries = message_data.get("retries", 0)
 
         # Convert any unicode keys in the keyword arguments to ascii.
         kwargs = dict((key.encode("utf-8"), value)
@@ -204,7 +204,7 @@ class TaskWrapper(object):
         task_func_kwargs = {"logfile": logfile,
                             "loglevel": loglevel,
                             "task_id": self.task_id,
-                            "task_name": self.task_name}
+                            "task_name": self.task_name,
                             "task_retries": self.retries}
         kwargs.update(task_func_kwargs)
         return kwargs