Browse Source

Task result/exception seems to be properly stored now!

Ask Solem 16 years ago
parent
commit
5271cfc530

+ 11 - 2
celery/backends/base.py

@@ -12,9 +12,18 @@ class BaseBackend(object):
     def __init__(self):
         pass
 
-    def mark_as_done(self, task_id, result):
+    def store_result(self, task_id, result, status):
         raise NotImplementedError(
-                "Backends must implement the mark_as_done method")
+                "Backends must implement the store_result method")
+
+    def mark_as_done(self, task_id, result):
+        return self.store_result(task_id, result, status="DONE")
+    
+    def mark_as_failure(self, task_id, exc):
+        return self.store_result(task_id, exc, status="FAILURE")
+
+    def mark_as_retry(self, task_id, exc):
+        return self.store_result(task_id, exc, status="RETRY")
 
     def get_status(self, task_id):
         raise NotImplementedError(

+ 3 - 3
celery/backends/cache.py

@@ -15,10 +15,10 @@ class Backend(BaseBackend):
     def _cache_key(self, task_id):
         return "celery-task-meta-%s" % task_id
 
-    def mark_as_done(self, task_id, result):
-        """Mark task as done (executed)."""
+    def store_result(self, task_id, result, status):
+        """Store task result and status."""
         result = self.prepare_result(result)
-        meta = {"status": "DONE", "result": pickle.dumps(result)}
+        meta = {"status": status, "result": pickle.dumps(result)}
         cache.set(self._cache_key(task_id), meta)
 
     def get_status(self, task_id):

+ 2 - 2
celery/backends/database.py

@@ -8,10 +8,10 @@ class Backend(BaseBackend):
         super(Backend, self).__init__(*args, **kwargs)
         self._cache = {}
    
-    def mark_as_done(self, task_id, result):
+    def store_result(self, task_id, result, status):
         """Mark task as done (executed)."""
         result = self.prepare_result(result)
-        return TaskMeta.objects.mark_as_done(task_id, result)
+        return TaskMeta.objects.store_result(task_id, result, status)
 
     def is_done(self, task_id):
         """Returns ``True`` if task with ``task_id`` has been executed."""

+ 3 - 3
celery/backends/tyrant.py

@@ -36,10 +36,10 @@ class Backend(BaseBackend):
     def _cache_key(self, task_id):
         return "celery-task-meta-%s" % task_id
 
-    def mark_as_done(self, task_id, result):
-        """Mark task as done (executed)."""
+    def store_result(self, task_id, result, status):
+        """Store task result and status."""
         result = self.prepare_result(result)
-        meta = {"status": "DONE", "result": pickle.dumps(result)}
+        meta = {"status": status, "result": pickle.dumps(result)}
         get_server()[self._cache_key(task_id)] = serialize(meta)
 
     def get_status(self, task_id):

+ 3 - 3
celery/managers.py

@@ -19,12 +19,12 @@ class TaskManager(models.Manager):
     def delete_expired(self):
         self.get_all_expired().delete()
 
-    def mark_as_done(self, task_id, result):
+    def store_result(self, task_id, result, status):
         task, created = self.get_or_create(task_id=task_id, defaults={
-                                            "status": "DONE",
+                                            "status": status,
                                             "result": result})
         if not created:
-            task.status = "DONE"
+            task.status = status
             task.result = result
             task.save()
 

+ 0 - 2
celery/process.py

@@ -1,5 +1,4 @@
 from UserList import UserList
-from celery.task import mark_as_done
 
 
 class ProcessQueue(UserList):
@@ -23,5 +22,4 @@ class ProcessQueue(UserList):
                         "name": task_name,
                         "id": task_id,
                         "return_value": ret_value})
-                    mark_as_done(task_id, ret_value)
             self.data = []

+ 26 - 15
celery/task.py

@@ -8,6 +8,7 @@ from django.core.cache import cache
 from datetime import timedelta
 from celery.backends import default_backend
 import uuid
+import pickle
 import traceback
 
 
@@ -85,6 +86,11 @@ def mark_as_done(task_id, result):
     return default_backend.mark_as_done(task_id, result)
 
 
+def mark_as_failure(task_id, exc):
+    """Mark task as failed (stores the exception as the result)."""
+    return default_backend.mark_as_failure(task_id, exc)
+
+
 def is_done(task_id):
     """Returns ``True`` if task with ``task_id`` has been executed."""
     return default_backend.is_done(task_id)
@@ -138,18 +144,7 @@ class Task(object):
     def __call__(self, *args, **kwargs):
         """The ``__call__`` is called when you do ``Task().run()`` and calls
         the ``run`` method. It also catches any exceptions and logs them."""
-        try:
-            retval = self.run(*args, **kwargs)
-        except Exception, e:
-            logger = self.get_logger(**kwargs)
-            logger.critical("Task got exception %s: %s\n%s" % (
-                                e.__class__, e, traceback.format_exc()))
-            self.handle_exception(e, args, kwargs)
-            if self.auto_retry:
-                self.retry(kwargs["task_id"], args, kwargs)
-            return
-        else:
-            return retval
+        return self.run(*args, **kwargs)
 
     def run(self, *args, **kwargs):
         """The actual task. All subclasses of :class:`Task` must define
@@ -175,9 +170,6 @@ class Task(object):
     def retry(self, task_id, args, kwargs):
         retry_queue.put(self.name, task_id, args, kwargs)
 
-    def handle_exception(self, exception, retry_args, retry_kwargs):
-        pass
-
     @classmethod
     def delay(cls, *args, **kwargs):
         """Delay this task for execution by the ``celery`` daemon(s)."""
@@ -299,3 +291,22 @@ class DeleteExpiredTaskMetaTask(PeriodicTask):
         logger.info("Deleting expired task meta objects...")
         default_backend.cleanup()
 tasks.register(DeleteExpiredTaskMetaTask)
+
+class ExecuteRemoteTask(Task):
+    name = "celery.execute_remote"
+
+    def run(self, ser_callable, fargs, fkwargs, **kwargs):
+        callable_ = pickle.loads(ser_callable)
+        return callable_(*fargs, **fkwargs)
+tasks.register(ExecuteRemoteTask)
+
+
+def execute_remote(func, *args, **kwargs):
+    return ExecuteRemoteTask.delay(pickle.dumps(func), args, kwargs)
+
+class SumTask(Task):
+    name = "celery.sum_task"
+
+    def run(self, *numbers, **kwargs):
+        return sum(numbers)
+tasks.register(SumTask)

+ 2 - 2
celery/tests/test_models.py

@@ -33,8 +33,8 @@ class TestModels(unittest.TestCase):
         self.assertEquals(TaskMeta.objects.get_task(m1.task_id).task_id,
                 m1.task_id)
         self.assertFalse(TaskMeta.objects.is_done(m1.task_id))
-        TaskMeta.objects.mark_as_done(m1.task_id, True)
-        TaskMeta.objects.mark_as_done(m2.task_id, True)
+        TaskMeta.objects.store_result(m1.task_id, True, status="DONE")
+        TaskMeta.objects.store_result(m2.task_id, True, status="DONE")
         self.assertTrue(TaskMeta.objects.is_done(m1.task_id))
         self.assertTrue(TaskMeta.objects.is_done(m2.task_id))
 

+ 0 - 9
celery/tests/test_task.py

@@ -47,15 +47,6 @@ class TestCeleryTasks(unittest.TestCase):
         for arg_name, arg_value in kwargs.items():
             self.assertEquals(task_kwargs.get(arg_name), arg_value)
 
-    def test_raising_task(self):
-        rtask = self.createTaskCls("RaisingTask", "c.unittest.t.rtask")
-        rtask.run = raise_exception
-        sio = StringIO()
-
-        taskinstance = rtask()
-        taskinstance(loglevel=logging.INFO, logfile=sio)
-        self.assertTrue(sio.getvalue().find("Task got exception") != -1)
-       
     def test_incomplete_task_cls(self):
         class IncompleteTask(task.Task):
             name = "c.unittest.t.itask"

+ 17 - 2
celery/worker.py

@@ -6,6 +6,7 @@ from celery.log import setup_logger
 from celery.registry import tasks
 from celery.process import ProcessQueue
 from celery.models import PeriodicTaskMeta
+from celery.task import mark_as_done, mark_as_failure
 import multiprocessing
 import simplejson
 import traceback
@@ -22,6 +23,18 @@ class UnknownTask(Exception):
     ignored."""
 
 
+def jail(task_id, callable_, args, kwargs):
+    try:
+        result = callable_(*args, **kwargs)
+        mark_as_done(task_id, result)
+        print("SUCCESS: %s" % result)
+        return result
+    except Exception, exc:
+        mark_as_failure(task_id, exc)
+        print("FAILURE: %s\n%s" % (exc, traceback.format_exc()))
+        return exc
+
+
 class TaskWrapper(object):
     def __init__(self, task_name, task_id, task_func, args, kwargs):
         self.task_name = task_name
@@ -53,11 +66,13 @@ class TaskWrapper(object):
 
     def execute(self, loglevel, logfile):
         task_func_kwargs = self.extend_kwargs_with_logging(loglevel, logfile)
-        return self.task_func(*self.args, **task_func_kwargs)
+        return jail(self.task_id, self.task_func,
+                    self.args, task_func_kwargs)
 
     def execute_using_pool(self, pool, loglevel, logfile):
         task_func_kwargs = self.extend_kwargs_with_logging(loglevel, logfile)
-        return pool.apply_async(self.task_func, self.args, task_func_kwargs)
+        return pool.apply_async(jail, [self.task_id, self.task_func,
+                                       self.args, task_func_kwargs])
 
 
 class EventTimer(object):