
Painting myself into a corner here, but I know the way out, just have to finish it later.

Ask Solem, 15 years ago · commit 8744a0f2df
3 changed files with 160 additions and 148 deletions
  1. celery/execute.py (+56 -137)
  2. celery/tests/test_worker_job.py (+2 -3)
  3. celery/worker/job.py (+102 -8)
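
The commit splits the old all-in-one ExecuteWrapper into two layers: a minimal TaskTrace in celery/execute.py that runs the task and dispatches the outcome to handle_success/handle_retry/handle_failure, and a WorkerTaskTrace subclass in celery/worker/job.py that adds the worker-only concerns on top (the loader's task-init handler, backend result storage and process cleanup, timer stats, and the execute_safe guard around the whole run). Eager execution via apply() now uses the plain TaskTrace, so it no longer touches the result backend.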

celery/execute.py (+56 -137)

@@ -8,13 +8,9 @@ from celery.utils.functional import curry
 from datetime import datetime, timedelta
 from celery.exceptions import RetryTaskError
 from celery.datastructures import ExceptionInfo
-from celery.backends import default_backend
-from celery.loaders import current_loader
-from celery.monitoring import TaskTimerStats
 from celery import signals
 import sys
 import inspect
-import warnings
 import traceback
 
 
@@ -159,8 +155,7 @@ def apply(task, args, kwargs, **options):
     task_id = gen_unique_id()
     retries = options.get("retries", 0)
 
-    # If it's a Task class we need to have to instance
-    # for it to be callable.
+    # If it's a Task class we need to instantiate it, so it's callable.
     task = inspect.isclass(task) and task() or task
 
     default_kwargs = {"task_name": task.name,
@@ -169,171 +164,95 @@ def apply(task, args, kwargs, **options):
                       "task_is_eager": True,
                       "logfile": None,
                       "loglevel": 0}
-    fun = getattr(task, "run", task)
-    supported_keys = fun_takes_kwargs(fun, default_kwargs)
+    supported_keys = fun_takes_kwargs(task.run, default_kwargs)
     extend_with = dict((key, val) for key, val in default_kwargs.items()
                             if key in supported_keys)
     kwargs.update(extend_with)
 
+    trace = TaskTrace(task.name, task_id, args, kwargs)
+    retval = trace.execute()
+
+    return EagerResult(task_id, retval, trace.status,
+                       traceback=trace.strtb)
+
+
+def trace_execution(fun, args, kwargs, on_retry=noop,
+        on_failure=noop, on_success=noop):
+    """Trace the execution of a function, calling the appropiate callback
+    if the function raises retry, an failure or returned successfully."""
     try:
-        ret_value = task(*args, **kwargs)
-        status = "DONE"
-        strtb = None
+        result = fun(*args, **kwargs)
+    except (SystemExit, KeyboardInterrupt):
+        raise
+    except RetryTaskError, exc:
+        type_, value_, tb = sys.exc_info()
+        strtb = "\n".join(traceback.format_exception(type_, value_, tb))
+        return on_retry(exc, type_, tb, strtb)
     except Exception, exc:
         type_, value_, tb = sys.exc_info()
         strtb = "\n".join(traceback.format_exception(type_, value_, tb))
-        ret_value = exc
-        status = "FAILURE"
-
-    return EagerResult(task_id, ret_value, status, traceback=strtb)
-
-
-class ExecuteWrapper(object):
-    """Wraps the task in a jail, which catches all exceptions, and
-    saves the status and result of the task execution to the task
-    meta backend.
-
-    If the call was successful, it saves the result to the task result
-    backend, and sets the task status to ``"DONE"``.
-
-    If the call raises :exc:`celery.exceptions.RetryTaskError`, it extracts
-    the original exception, uses that as the result and sets the task status
-    to ``"RETRY"``.
+        return on_failure(exc, type_, tb, strtb)
+    else:
+        return on_success(result)
 
-    If the call results in an exception, it saves the exception as the task
-    result, and sets the task status to ``"FAILURE"``.
 
-    :param task_name: The name of the task to execute.
-    :param task_id: The unique id of the task.
-    :param args: List of positional args to pass on to the function.
-    :param kwargs: Keyword arguments mapping to pass on to the function.
-
-    :returns: the function return value on success, or
-        the exception instance on failure.
-
-    """
+class TaskTrace(object):
 
     def __init__(self, task_name, task_id, args=None, kwargs=None):
         self.task_id = task_id
         self.task_name = task_name
         self.args = args or []
         self.kwargs = kwargs or {}
-
-    def __call__(self, *args, **kwargs):
-        return self.execute_safe()
-
-    def execute_safe(self, *args, **kwargs):
-        try:
-            return self.execute(*args, **kwargs)
-        except Exception, exc:
-            type_, value_, tb = sys.exc_info()
-            exc = default_backend.prepare_exception(exc)
-            warnings.warn("Exception happend outside of task body: %s: %s" % (
-                str(exc.__class__), str(exc)))
-            return ExceptionInfo((type_, exc, tb))
+        self.status = "PENDING"
+        self.strtb = None
 
     def execute(self):
-        # Convenience variables
-        task_id = self.task_id
-        task_name = self.task_name
-        args = self.args
-        kwargs = self.kwargs
-        fun = tasks[task_name]
-        self.fun = fun # Set fun for handlers.
+        return self._trace()
 
-        # Run task loader init handler.
-        current_loader.on_task_init(task_id, fun)
-
-        # Backend process cleanup
-        default_backend.process_cleanup()
+    def _trace(self):
+        # Set self.task for handlers. Can't do it in __init__, because it
+        # has to happen in the pool worker's process.
+        task = self.task = tasks[self.task_name]
 
         # Send pre-run signal.
-        signals.task_prerun.send(sender=fun, task_id=task_id, task=fun,
-                                 args=args, kwargs=kwargs)
-
-        retval = None
-        timer_stat = TaskTimerStats.start(task_id, task_name, args, kwargs)
-        try:
-            result = fun(*args, **kwargs)
-        except (SystemExit, KeyboardInterrupt):
-            raise
-        except RetryTaskError, exc:
-            retval = self.handle_retry(exc, sys.exc_info())
-        except Exception, exc:
-            retval = self.handle_failure(exc, sys.exc_info())
-        else:
-            retval = self.handle_success(result)
-        finally:
-            timer_stat.stop()
+        signals.task_prerun.send(sender=task, task_id=self.task_id, task=task,
+                                 args=self.args, kwargs=self.kwargs)
 
-        # Send post-run signal.
-        signals.task_postrun.send(sender=fun, task_id=task_id, task=fun,
-                                  args=args, kwargs=kwargs, retval=retval)
+        retval = trace_execution(self.task, self.args, self.kwargs,
+                                 on_success=self.handle_success,
+                                 on_retry=self.handle_retry,
+                                 on_failure=self.handle_failure)
 
+        # Send post-run signal.
+        signals.task_postrun.send(sender=task, task_id=self.task_id,
+                                  task=task, args=self.args,
+                                  kwargs=self.kwargs, retval=retval)
         return retval
 
-    def handle_success(self, retval):
-        """Handle successful execution.
-
-        Saves the result to the current result store (skipped if the callable
-            has a ``ignore_result`` attribute set to ``True``).
-
-        If the callable has a ``on_success`` function, it as called with
-        ``retval`` as argument.
-
-        :param retval: The return value.
-
-        """
-        if not getattr(self.fun, "ignore_result", False):
-            default_backend.mark_as_done(self.task_id, retval)
-
-        # Run success handler last to be sure the status is saved.
-        success_handler = getattr(self.fun, "on_success", noop)
-        success_handler(retval, self.task_id, self.args, self.kwargs)
 
+    def handle_success(self, retval):
+        """Handle successful execution."""
+        self.status = "DONE"
+        self.task.on_success(retval, self.task_id, self.args, self.kwargs)
         return retval
 
-    def handle_retry(self, exc, exc_info):
+    def handle_retry(self, exc, type_, tb, strtb):
         """Handle retry exception."""
-        ### Task is to be retried.
-        type_, value_, tb = exc_info
-        strtb = "\n".join(traceback.format_exception(type_, value_, tb))
-
-        # RetryTaskError stores both a small message describing the retry
-        # and the original exception.
-        message, orig_exc = exc.args
-        default_backend.mark_as_retry(self.task_id, orig_exc, strtb)
+        self.status = "RETRY"
+        self.strtb = strtb
+        self.task.on_retry(exc, self.task_id, self.args, self.kwargs)
 
         # Create a simpler version of the RetryTaskError that stringifies
         # the original exception instead of including the exception instance.
         # This is for reporting the retry in logs, e-mail etc, while
         # guaranteeing pickleability.
+        message, orig_exc = exc.args
         expanded_msg = "%s: %s" % (message, str(orig_exc))
-        retval = ExceptionInfo((type_,
-                                type_(expanded_msg, None),
-                                tb))
+        return ExceptionInfo((type_, type_(expanded_msg, None), tb))
 
-        # Run retry handler last to be sure the status is saved.
-        retry_handler = getattr(self.fun, "on_retry", noop)
-        retry_handler(exc, self.task_id, self.args, self.kwargs)
-
-        return retval
-
-    def handle_failure(self, exc, exc_info):
+    def handle_failure(self, exc, type_, tb, strtb):
         """Handle exception."""
-        ### Task ended in failure.
-        type_, value_, tb = exc_info
-        strtb = "\n".join(traceback.format_exception(type_, value_, tb))
-
-        # mark_as_failure returns an exception that is guaranteed to
-        # be pickleable.
-        stored_exc = default_backend.mark_as_failure(self.task_id, exc, strtb)
-
-        # wrap exception info + traceback and return it to caller.
-        retval = ExceptionInfo((type_, stored_exc, tb))
-
-        # Run error handler last to be sure the status is stored.
-        error_handler = getattr(self.fun, "on_failure", noop)
-        error_handler(stored_exc, self.task_id, self.args, self.kwargs)
-
-        return retval
+        self.status = "FAILURE"
+        self.strtb = strtb
+        self.task.on_failure(exc, self.task_id, self.args, self.kwargs)
+        return ExceptionInfo((type_, exc, tb))
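
The new trace_execution() helper is usable on its own: it re-raises SystemExit and KeyboardInterrupt, and routes every other outcome to one of three callbacks. A minimal sketch of the protocol, assuming this revision of celery.execute is importable; the callbacks and the add function are illustrative stand-ins, not celery code:

    from celery.execute import trace_execution

    def on_success(result):
        print "returned: %r" % (result, )
        return result

    def on_failure(exc, type_, tb, strtb):
        # Failure (and retry) callbacks receive the exception, its class,
        # the traceback object and the pre-formatted traceback string.
        print "raised: %s" % strtb
        return exc

    def add(x, y):
        return x + y

    trace_execution(add, [2, 2], {}, on_success=on_success,
                    on_failure=on_failure)    # calls on_success(4)
    trace_execution(add, [2, None], {}, on_success=on_success,
                    on_failure=on_failure)    # calls on_failure(TypeError, ...)

on_retry is left at its noop default here; a RetryTaskError raised by the function would be routed to it with the same signature as on_failure.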

celery/tests/test_worker_job.py (+2 -3)

@@ -1,8 +1,7 @@
 # -*- coding: utf-8 -*-
 import sys
 import unittest
-from celery.execute import ExecuteWrapper
-from celery.worker.job import TaskWrapper
+from celery.worker.job import WorkerTaskTrace, TaskWrapper
 from celery.datastructures import ExceptionInfo
 from celery.models import TaskMeta
 from celery.registry import tasks, NotRegistered
@@ -21,7 +20,7 @@ some_kwargs_scratchpad = {}
 
 
 def jail(task_id, task_name, args, kwargs):
-    return ExecuteWrapper(task_name, task_id, args, kwargs)()
+    return WorkerTaskTrace(task_name, task_id, args, kwargs)()
 
 
 def on_ack():
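
The jail() helper now goes through WorkerTaskTrace.__call__, the same entry point the pool uses. A sketch of the call chain, where "tasks.add" stands in for any registered task name and gen_unique_id is assumed to live in celery.utils, as used by celery/execute.py:

    from celery.worker.job import WorkerTaskTrace
    from celery.utils import gen_unique_id

    trace = WorkerTaskTrace("tasks.add", gen_unique_id(), [2, 2], {})
    retval = trace()    # __call__ -> execute_safe -> execute -> _trace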

celery/worker/job.py (+102 -8)

@@ -5,11 +5,17 @@ Jobs Executable by the Worker Server.
 """
 from celery.registry import tasks
 from celery.exceptions import NotRegistered
-from celery.execute import ExecuteWrapper
+from celery.execute import TaskTrace
 from celery.utils import noop, fun_takes_kwargs
 from celery.log import get_default_logger
+from celery.monitoring import TaskTimerStats
 from django.core.mail import mail_admins
+from celery.loaders import current_loader
+from celery.backends import default_backend
+from celery.datastructures import ExceptionInfo
+import sys
 import socket
+import warnings
 
 # pep8.py borks on a inline signature separator and
 # says "trailing whitespace" ;)
@@ -35,6 +41,94 @@ class AlreadyExecutedError(Exception):
     world-wide state."""
 
 
+class WorkerTaskTrace(TaskTrace):
+    """Wraps the task in a jail, catches all exceptions, and
+    saves the status and result of the task execution to the task
+    meta backend.
+
+    If the call was successful, it saves the result to the task result
+    backend, and sets the task status to ``"DONE"``.
+
+    If the call raises :exc:`celery.exceptions.RetryTaskError`, it extracts
+    the original exception, uses that as the result and sets the task status
+    to ``"RETRY"``.
+
+    If the call results in an exception, it saves the exception as the task
+    result, and sets the task status to ``"FAILURE"``.
+
+    :param task_name: The name of the task to execute.
+    :param task_id: The unique id of the task.
+    :param args: List of positional args to pass on to the function.
+    :param kwargs: Keyword arguments mapping to pass on to the function.
+
+    :returns: the function return value on success, or
+        the exception instance on failure.
+
+    """
+
+    def __init__(self, *args, **kwargs):
+        self.backend = kwargs.pop("backend", default_backend)
+        self.loader = kwargs.pop("loader", current_loader)
+        super(WorkerTaskTrace, self).__init__(*args, **kwargs)
+
+    def __call__(self, *args, **kwargs):
+        return self.execute_safe()
+
+    def execute_safe(self, *args, **kwargs):
+        try:
+            return self.execute(*args, **kwargs)
+        except Exception, exc:
+            type_, value_, tb = sys.exc_info()
+            exc = self.backend.prepare_exception(exc)
+            warnings.warn("Exception happend outside of task body: %s: %s" % (
+                str(exc.__class__), str(exc)))
+            return ExceptionInfo((type_, exc, tb))
+
+    def execute(self):
+        # Set self.task for handlers. Can't do it in __init__, because it
+        # has to happen in the pool worker's process.
+        task = self.task = tasks[self.task_name]
+
+        # Run task loader init handler.
+        self.loader.on_task_init(self.task_id, task)
+
+        # Backend process cleanup
+        self.backend.process_cleanup()
+
+        timer_stat = TaskTimerStats.start(self.task_id, self.task_name,
+                                          self.args, self.kwargs)
+        try:
+            return self._trace()
+        finally:
+            timer_stat.stop()
+
+    def handle_success(self, retval):
+        """Handle successful execution.
+
+        Saves the result to the current result store (skipped if the task's
+        ``ignore_result`` attribute is set to ``True``).
+
+        """
+        if not self.task.ignore_result:
+            self.backend.mark_as_done(self.task_id, retval)
+        return super(WorkerTaskTrace, self).handle_success(retval)
+
+    def handle_retry(self, exc, type_, tb, strtb):
+        """Handle retry exception."""
+        message, orig_exc = exc.args
+        self.backend.mark_as_retry(self.task_id, orig_exc, strtb)
+        return super(WorkerTaskTrace, self).handle_retry(exc, type_,
+                                                         tb, strtb)
+
+    def handle_failure(self, exc, type_, tb, strtb):
+        """Handle exception."""
+        # mark_as_failure returns an exception that is guaranteed to
+        # be pickleable.
+        stored_exc = self.backend.mark_as_failure(self.task_id, exc, strtb)
+        return super(WorkerTaskTrace, self).handle_failure(
+                stored_exc, type_, tb, strtb)
+
+
 class TaskWrapper(object):
     """Class wrapping a task to be passed around and finally
     executed inside of the worker.
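
The handlers are now a template-method extension point: TaskTrace decides what happened, and subclasses decide what to do about it. A hypothetical subclass (not part of this commit) that logs failures before WorkerTaskTrace stores them in the backend:

    from celery.worker.job import WorkerTaskTrace

    class DebugTaskTrace(WorkerTaskTrace):
        """Hypothetical tracer: print failures before they are stored."""

        def handle_failure(self, exc, type_, tb, strtb):
            print "task %s failed:\n%s" % (self.task_id, strtb)
            # WorkerTaskTrace marks the task as FAILURE in the backend,
            # then TaskTrace runs the task's on_failure handler and wraps
            # the exception in a pickleable ExceptionInfo.
            return super(DebugTaskTrace, self).handle_failure(
                    exc, type_, tb, strtb)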
@@ -153,11 +247,11 @@ class TaskWrapper(object):
         kwargs.update(extend_with)
         return kwargs
 
-    def _executeable(self, loglevel=None, logfile=None):
-        """Get the :class:`celery.execute.ExecuteWrapper` for this task."""
+    def _tracer(self, loglevel=None, logfile=None):
+        """Get the :class:`WorkerTaskTrace` tracer for this task."""
         task_func_kwargs = self.extend_with_default_kwargs(loglevel, logfile)
-        return ExecuteWrapper(self.task_name, self.task_id,
-                              self.args, task_func_kwargs)
+        return WorkerTaskTrace(self.task_name, self.task_id,
+                               self.args, task_func_kwargs)
 
     def _set_executed_bit(self):
         """Set task as executed to make sure it's not executed again."""
@@ -168,7 +262,7 @@ class TaskWrapper(object):
         self.executed = True
 
     def execute(self, loglevel=None, logfile=None):
-        """Execute the task in a :class:`celery.execute.ExecuteWrapper`.
+        """Execute the task in a :class:`WorkerTaskTrace`.
 
         :keyword loglevel: The loglevel used by the task.
 
@@ -181,7 +275,7 @@ class TaskWrapper(object):
         # acknowledge task as being processed.
         self.on_ack()
 
-        return self._executeable(loglevel, logfile)()
+        return self._tracer(loglevel, logfile).execute()
 
     def execute_using_pool(self, pool, loglevel=None, logfile=None):
         """Like :meth:`execute`, but using the :mod:`multiprocessing` pool.
@@ -198,7 +292,7 @@ class TaskWrapper(object):
         # Make sure task has not already been executed.
         self._set_executed_bit()
 
-        wrapper = self._executeable(loglevel, logfile)
+        wrapper = self._tracer(loglevel, logfile)
         return pool.apply_async(wrapper,
                 callbacks=[self.on_success], errbacks=[self.on_failure],
                 on_ack=self.on_ack)
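
The execute_safe() wrapper exists because of how the tracer crosses the pool boundary: the pool hands the callable's return value to the callbacks, so failures have to come back as pickleable return values (ExceptionInfo), never as raised exceptions. A toy illustration of the same pattern using the stdlib pool (not celery's pool, which adds the callbacks/errbacks/on_ack plumbing seen above):

    import multiprocessing

    def safe_call(fun, args):
        # Mimic execute_safe: never let an exception escape across the
        # pool boundary; hand it back as a (pickleable) value instead.
        try:
            return fun(*args)
        except (SystemExit, KeyboardInterrupt):
            raise
        except Exception, exc:
            return exc

    def div(x, y):
        return x / y

    if __name__ == "__main__":
        pool = multiprocessing.Pool(2)
        print pool.apply(safe_call, (div, (4, 2)))    # 2
        print pool.apply(safe_call, (div, (4, 0)))    # ZeroDivisionError(...)
        pool.close()
        pool.join()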