Browse Source

Refactor jail() into celery.execute.ExecuteWrapper

Ask Solem 15 years ago
parent
commit
0b86c43bb5
4 changed files with 200 additions and 119 deletions
  1. 13 0
      celery/exceptions.py
  2. 151 1
      celery/execute.py
  3. 5 1
      celery/tests/test_worker_job.py
  4. 31 117
      celery/worker/job.py

+ 13 - 0
celery/exceptions.py

@@ -0,0 +1,13 @@
+"""celery.exceptions"""
+
+
+class MaxRetriesExceededError(Exception):
+    """The task's maximum retry limit has been exceeded."""
+
+
+class RetryTaskError(Exception):
+    """The task is to be retried later."""
+
+    def __init__(self, message, exc, *args, **kwargs):
+        self.exc = exc
+        super(RetryTaskError, self).__init__(message, exc, *args, **kwargs)

+ 151 - 1
celery/execute.py

@@ -3,10 +3,16 @@ from celery.conf import AMQP_CONNECTION_TIMEOUT
 from celery.result import AsyncResult, EagerResult
 from celery.messaging import TaskPublisher
 from celery.registry import tasks
-from celery.utils import gen_unique_id
+from celery.utils import gen_unique_id, noop
 from functools import partial as curry
 from datetime import datetime, timedelta
 from multiprocessing import get_logger
+from celery.exceptions import RetryTaskError
+from celery.datastructures import ExceptionInfo
+from celery.backends import default_backend
+from celery.loaders import current_loader
+from celery.monitoring import TaskTimerStats
+from celery import signals
 import sys
 import traceback
 import inspect
@@ -164,3 +170,147 @@ def apply(task, args, kwargs, **options):
         status = "FAILURE"
 
     return EagerResult(task_id, ret_value, status, traceback=strtb)
+
+
+class ExecuteWrapper(object):
+    """Wraps the task in a jail, which catches all exceptions, and
+    saves the status and result of the task execution to the task
+    meta backend.
+    
+    If the call was successful, it saves the result to the task result
+    backend, and sets the task status to ``"DONE"``.
+
+    If the call raises :exc:`celery.exceptions.RetryTaskError`, it extracts
+    the original exception, uses that as the result and sets the task status
+    to ``"RETRY"``.
+
+    If the call results in an exception, it saves the exception as the task
+    result, and sets the task status to ``"FAILURE"``.
+
+   
+    :param fun: Callable object to execute.
+    :param task_id: The unique id of the task.
+    :param task_name: Name of the task.
+    :param args: List of positional args to pass on to the function.
+    :param kwargs: Keyword arguments mapping to pass on to the function.
+
+    :returns: the function return value on success, or an
+        :class:`~celery.datastructures.ExceptionInfo` on retry or failure.
+
+    
+    """
+
+    def __init__(self, fun, task_id, task_name, args=None, kwargs=None):
+        self.fun = fun
+        self.task_id = task_id
+        self.task_name = task_name
+        self.args = args or []
+        self.kwargs = kwargs or {}
+
+    def __call__(self, *args, **kwargs):
+        return self.execute()
+
+    def execute(self):
+        # Convenience variables
+        fun = self.fun
+        task_id = self.task_id
+        task_name = self.task_name
+        args = self.args
+        kwargs = self.kwargs
+
+        # Run task loader init handler.
+        current_loader.on_task_init(task_id, fun)
+
+        # Backend process cleanup
+        default_backend.process_cleanup()
+      
+        # Send pre-run signal.
+        signals.task_prerun.send(sender=fun, task_id=task_id, task=fun,
+                                 args=args, kwargs=kwargs)
+
+        retval = None
+        timer_stat = TaskTimerStats.start(task_id, task_name, args, kwargs)
+        try:
+            result = fun(*args, **kwargs)
+        except (SystemExit, KeyboardInterrupt):
+            raise
+        except RetryTaskError, exc:
+            retval = self.handle_retry(exc, sys.exc_info())
+        except Exception, exc:
+            retval = self.handle_failure(exc, sys.exc_info())
+        else:
+            retval = self.handle_success(result)
+        finally:
+            timer_stat.stop()
+
+        # Send post-run signal.
+        signals.task_postrun.send(sender=fun, task_id=task_id, task=fun,
+                                  args=args, kwargs=kwargs, retval=retval)
+
+        return retval
+
+    def handle_success(self, retval):
+        """Handle successful execution.
+
+        Saves the result to the current result store (skipped if the callable
+        has an ``ignore_result`` attribute set to ``True``).
+
+        If the callable has an ``on_success`` function, it is called with
+        ``retval`` as argument.
+
+        :param retval: The return value.
+
+        """
+        if not getattr(self.fun, "ignore_result", False):
+            default_backend.mark_as_done(self.task_id, retval)
+
+        # Run success handler last to be sure the status is saved.
+        success_handler = getattr(self.fun, "on_success", noop)
+        success_handler(retval)
+
+        return retval
+
+    def handle_retry(self, exc, exc_info):
+        """Handle retry exception."""
+        ### Task is to be retried.
+        type_, value_, tb = exc_info
+        strtb = "\n".join(traceback.format_exception(type_, value_, tb))
+
+        # RetryTaskError stores both a small message describing the retry
+        # and the original exception.
+        message, orig_exc = exc.args
+        default_backend.mark_as_retry(self.task_id, orig_exc, strtb)
+
+        # Create a simpler version of the RetryTaskError that stringifies
+        # the original exception instead of including the exception instance.
+        # This is for reporting the retry in logs, e-mail etc, while
+        # guaranteeing pickleability.
+        expanded_msg = "%s: %s" % (message, str(orig_exc))
+        retval = ExceptionInfo((type_,
+                                type_(expanded_msg, None),
+                                tb))
+
+        # Run retry handler last to be sure the status is saved.
+        retry_handler = getattr(self.fun, "on_retry", noop)
+        retry_handler(exc)
+
+        return retval
+
+    def handle_failure(self, exc, exc_info):
+        """Handle exception."""
+        ### Task ended in failure.
+        type_, value_, tb = exc_info
+        strtb = "\n".join(traceback.format_exception(type_, value_, tb))
+
+        # mark_as_failure returns an exception that is guaranteed to
+        # be pickleable.
+        stored_exc = default_backend.mark_as_failure(self.task_id, exc, strtb)
+
+        # wrap exception info + traceback and return it to caller.
+        retval = ExceptionInfo((type_, stored_exc, tb))
+
+        # Run error handler last to be sure the status is stored.
+        error_handler = getattr(self.fun, "on_failure", noop)
+        error_handler(stored_exc)
+
+        return retval

+ 5 - 1
celery/tests/test_worker_job.py

@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 import sys
 import unittest
-from celery.worker.job import jail
+from celery.execute import ExecuteWrapper
 from celery.worker.job import TaskWrapper
 from celery.datastructures import ExceptionInfo
 from celery.models import TaskMeta
@@ -18,6 +18,10 @@ import logging
 scratch = {"ACK": False}
 
 
+def jail(task_id, task_name, fun, args, kwargs):
+    return ExecuteWrapper(fun, task_id, task_name, args, kwargs)()
+
+
 def on_ack():
     scratch["ACK"] = True
 

+ 31 - 117
celery/worker/job.py

@@ -4,15 +4,10 @@ Jobs Executable by the Worker Server.
 
 """
 from celery.registry import tasks, NotRegistered
-from celery.datastructures import ExceptionInfo
-from celery.backends import default_backend
-from celery.loaders import current_loader
+from celery.execute import ExecuteWrapper
+from celery.utils import noop
 from django.core.mail import mail_admins
-from celery.monitoring import TaskTimerStats
-from celery.task.base import RetryTaskError
-from celery import signals
 import multiprocessing
-import traceback
 import socket
 import sys
 
@@ -35,89 +30,6 @@ celeryd at %%(hostname)s.
 """ % {"EMAIL_SIGNATURE_SEP": EMAIL_SIGNATURE_SEP}
 
 
-def jail(task_id, task_name, func, args, kwargs):
-    """Wraps the task in a jail, which catches all exceptions, and
-    saves the status and result of the task execution to the task
-    meta backend.
-
-    If the call was successful, it saves the result to the task result
-    backend, and sets the task status to ``"DONE"``.
-
-    If the call raises :exc:`celery.task.base.RetryTaskError`, it extracts
-    the original exception, uses that as the result and sets the task status
-    to ``"RETRY"``.
-
-    If the call results in an exception, it saves the exception as the task
-    result, and sets the task status to ``"FAILURE"``.
-
-    :param task_id: The id of the task.
-    :param task_name: The name of the task.
-    :param func: Callable object to execute.
-    :param args: List of positional args to pass on to the function.
-    :param kwargs: Keyword arguments mapping to pass on to the function.
-
-    :returns: the function return value on success, or
-        the exception instance on failure.
-
-    """
-    ignore_result = getattr(func, "ignore_result", False)
-    timer_stat = TaskTimerStats.start(task_id, task_name, args, kwargs)
-
-    # Run task loader init handler.
-    current_loader.on_task_init(task_id, func)
-    signals.task_prerun.send(sender=func, task_id=task_id, task=func,
-                             args=args, kwargs=kwargs)
-
-    # Backend process cleanup
-    default_backend.process_cleanup()
-
-    try:
-        result = func(*args, **kwargs)
-    except (SystemExit, KeyboardInterrupt):
-        raise
-    except RetryTaskError, exc:
-        ### Task is to be retried.
-        type_, value_, tb = sys.exc_info()
-        strtb = "\n".join(traceback.format_exception(type_, value_, tb))
-
-        # RetryTaskError stores both a small message describing the retry
-        # and the original exception.
-        message, orig_exc = exc.args
-        default_backend.mark_as_retry(task_id, orig_exc, strtb)
-
-        # Create a simpler version of the RetryTaskError that stringifies
-        # the original exception instead of including the exception instance.
-        # This is for reporting the retry in logs, e-mail etc, while
-        # guaranteeing pickleability.
-        expanded_msg = "%s: %s" % (message, str(orig_exc))
-        retval = ExceptionInfo((type_,
-                                type_(expanded_msg, None),
-                                tb))
-    except Exception, exc:
-        ### Task ended in failure.
-        type_, value_, tb = sys.exc_info()
-        strtb = "\n".join(traceback.format_exception(type_, value_, tb))
-
-        # mark_as_failure returns an exception that is guaranteed to
-        # be pickleable.
-        stored_exc = default_backend.mark_as_failure(task_id, exc, strtb)
-
-        # wrap exception info + traceback and return it to caller.
-        retval = ExceptionInfo((type_, stored_exc, tb))
-    else:
-        ### Task executed successfully.
-        if not ignore_result:
-            default_backend.mark_as_done(task_id, result)
-        retval = result
-    finally:
-        timer_stat.stop()
-
-    signals.task_postrun.send(sender=func, task_id=task_id, task=func,
-                              args=args, kwargs=kwargs, retval=retval)
-
-    return retval
-
-
 class TaskWrapper(object):
     """Class wrapping a task to be run.
 
@@ -166,7 +78,7 @@ class TaskWrapper(object):
     fail_email_body = TASK_FAIL_EMAIL_BODY
 
     def __init__(self, task_name, task_id, task_func, args, kwargs,
-            on_ack=None, retries=0, **opts):
+            on_ack=noop, retries=0, **opts):
         self.task_name = task_name
         self.task_id = task_id
         self.task_func = task_func
@@ -229,22 +141,43 @@ class TaskWrapper(object):
         kwargs.update(task_func_kwargs)
         return kwargs
 
+    def _executeable(self, loglevel=None, logfile=None):
+        """Get the :class:`celery.execute.ExecuteWrapper` for this task."""
+        task_func_kwargs = self.extend_with_default_kwargs(loglevel, logfile)
+        return ExecuteWrapper(self.task_func, self.task_id, self.task_name,
+                              self.args, task_func_kwargs)
+
     def execute(self, loglevel=None, logfile=None):
-        """Execute the task in a :func:`jail` and store return value
-        and status in the task meta backend.
+        """Execute the task in a :class:`celery.execute.ExecuteWrapper`.
 
         :keyword loglevel: The loglevel used by the task.
 
         :keyword logfile: The logfile used by the task.
 
         """
-        task_func_kwargs = self.extend_with_default_kwargs(loglevel, logfile)
         # acknowledge task as being processed.
-        if self.on_ack:
-            self.on_ack()
-        return jail(self.task_id, self.task_name, self.task_func,
-                    self.args, task_func_kwargs)
+        self.on_ack()
+
+        return self._executeable(loglevel, logfile)()
+
+    def execute_using_pool(self, pool, loglevel=None, logfile=None):
+        """Like :meth:`execute`, but using the :mod:`multiprocessing` pool.
+
+        :param pool: A :class:`multiprocessing.Pool` instance.
+
+        :keyword loglevel: The loglevel used by the task.
+
+        :keyword logfile: The logfile used by the task.
 
+        :returns: :class:`multiprocessing.AsyncResult` instance.
+
+        """
+        wrapper = self._executeable(loglevel, logfile)
+        return pool.apply_async(wrapper,
+                callbacks=[self.on_success], errbacks=[self.on_failure],
+                on_ack=self.on_ack,
+                meta={"task_id": self.task_id, "task_name": self.task_name})
+    
     def on_success(self, ret_value, meta):
         """The handler used if the task was successfully processed (
         without raising an exception)."""
@@ -281,22 +214,3 @@ class TaskWrapper(object):
             body = self.fail_email_body.strip() % context
             mail_admins(subject, body, fail_silently=True)
 
-    def execute_using_pool(self, pool, loglevel=None, logfile=None):
-        """Like :meth:`execute`, but using the :mod:`multiprocessing` pool.
-
-        :param pool: A :class:`multiprocessing.Pool` instance.
-
-        :keyword loglevel: The loglevel used by the task.
-
-        :keyword logfile: The logfile used by the task.
-
-        :returns :class:`multiprocessing.AsyncResult` instance.
-
-        """
-        task_func_kwargs = self.extend_with_default_kwargs(loglevel, logfile)
-        jail_args = [self.task_id, self.task_name, self.task_func,
-                     self.args, task_func_kwargs]
-        return pool.apply_async(jail, args=jail_args,
-                callbacks=[self.on_success], errbacks=[self.on_failure],
-                on_ack=self.on_ack,
-                meta={"task_id": self.task_id, "task_name": self.task_name})