
Tests now passing with the new execution trace

Ask Solem · 15 years ago
commit 10e6675b29

3 changed files with 66 additions and 64 deletions:

  1. celery/execute.py (+52 −47)
  2. celery/result.py (+1 −1)
  3. celery/worker/job.py (+13 −16)

celery/execute.py (+52 −47)

@@ -169,77 +169,82 @@ def apply(task, args, kwargs, **options):
                             if key in supported_keys)
     kwargs.update(extend_with)
 
-    trace = TaskTrace(task.name, task_id, args, kwargs)
+    trace = TaskTrace(task.name, task_id, args, kwargs, task=task)
     retval = trace.execute()
 
     return EagerResult(task_id, retval, trace.status,
                        traceback=trace.strtb)
 
 
-def trace_execution(fun, args, kwargs, on_retry=noop,
-        on_failure=noop, on_success=noop):
-    """Trace the execution of a function, calling the appropiate callback
-    if the function raises retry, an failure or returned successfully."""
-    try:
-        result = fun(*args, **kwargs)
-    except (SystemExit, KeyboardInterrupt):
-        raise
-    except RetryTaskError, exc:
-        type_, value_, tb = sys.exc_info()
-        strtb = "\n".join(traceback.format_exception(type_, value_, tb))
-        return on_retry(exc, type_, tb, strtb)
-    except Exception, exc:
-        type_, value_, tb = sys.exc_info()
-        strtb = "\n".join(traceback.format_exception(type_, value_, tb))
-        return on_failure(exc, type_, tb, strtb)
-    else:
-        return on_success(result)
+class TraceInfo(object):
+    def __init__(self, status="PENDING", retval=None, exc_info=None):
+        self.status = status
+        self.retval = retval
+        self.exc_info = exc_info
+        self.exc_type = None
+        self.exc_value = None
+        self.tb = None
+        self.strtb = None
+        if self.exc_info:
+            self.exc_type, self.exc_value, self.tb = exc_info
+            self.strtb = "\n".join(traceback.format_exception(*exc_info))
+
+    @classmethod
+    def trace(cls, fun, args, kwargs):
+        """Trace the execution of a function, calling the appropiate callback
+        if the function raises retry, an failure or returned successfully."""
+        try:
+            return cls("DONE", retval=fun(*args, **kwargs))
+        except (SystemExit, KeyboardInterrupt):
+            raise
+        except RetryTaskError, exc:
+            return cls("RETRY", retval=exc, exc_info=sys.exc_info())
+        except Exception, exc:
+            return cls("FAILURE", retval=exc, exc_info=sys.exc_info())
 
 
 class TaskTrace(object):
 
-    def __init__(self, task_name, task_id, args=None, kwargs=None):
+    def __init__(self, task_name, task_id, args, kwargs, task=None):
         self.task_id = task_id
         self.task_name = task_name
-        self.args = args or []
-        self.kwargs = kwargs or {}
+        self.args = args
+        self.kwargs = kwargs
+        self.task = task or tasks[self.task_name]
         self.status = "PENDING"
         self.strtb = None
+        self._trace_handlers = {"FAILURE": self.handle_failure,
+                                "RETRY": self.handle_retry,
+                                "DONE": self.handle_success}
+
+    def __call__(self):
+        return self.execute()
 
     def execute(self):
-        return self._trace()
+        signals.task_prerun.send(sender=self.task, task_id=self.task_id,
+                                 task=self.task, args=self.args,
+                                 kwargs=self.kwargs)
+        retval = self._trace()
 
-    def _trace(self):
-        # Set self.task for handlers. Can't do it in __init__, because it
-        # has to happen in the pool workers process.
-        task = self.task = tasks[self.task_name]
-
-        # Send pre-run signal.
-        signals.task_prerun.send(sender=task, task_id=self.task_id, task=task,
-                                 args=self.args, kwargs=self.kwargs)
-
-        retval = trace_execution(self.task, self.args, self.kwargs,
-                                 on_success=self.handle_success,
-                                 on_retry=self.handle_retry,
-                                 on_failure=self.handle_failure)
-
-        # Send post-run signal.
-        signals.task_postrun.send(sender=task, task_id=self.task_id,
-                                  task=task, args=self.args,
+        signals.task_postrun.send(sender=self.task, task_id=self.task_id,
+                                  task=self.task, args=self.args,
                                   kwargs=self.kwargs, retval=retval)
         return retval
 
+    def _trace(self):
+        trace = TraceInfo.trace(self.task, self.args, self.kwargs)
+        self.status = trace.status
+        self.strtb = trace.strtb
+        handler = self._trace_handlers[trace.status]
+        return handler(trace.retval, trace.exc_type, trace.tb, trace.strtb)
 
-    def handle_success(self, retval):
+    def handle_success(self, retval, *args):
         """Handle successful execution."""
-        self.status = "DONE"
         self.task.on_success(retval, self.task_id, self.args, self.kwargs)
         return retval
 
     def handle_retry(self, exc, type_, tb, strtb):
         """Handle retry exception."""
-        self.status = "RETRY"
-        self.strtb = strtb
         self.task.on_retry(exc, self.task_id, self.args, self.kwargs)
 
         # Create a simpler version of the RetryTaskError that stringifies
@@ -248,11 +253,11 @@ class TaskTrace(object):
         # guaranteeing pickleability.
         message, orig_exc = exc.args
         expanded_msg = "%s: %s" % (message, str(orig_exc))
-        return ExceptionInfo((type_, type_(expanded_msg, None), tb))
+        return ExceptionInfo((type_,
+                              type_(expanded_msg, None),
+                              tb))
 
     def handle_failure(self, exc, type_, tb, strtb):
         """Handle exception."""
-        self.status = "FAILURE"
-        self.strtb = strtb
         self.task.on_failure(exc, self.task_id, self.args, self.kwargs)
         return ExceptionInfo((type_, exc, tb))
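
The shape of this refactor: the callback-style trace_execution function becomes a TraceInfo value object. Tracing now returns a plain result carrying the status, return value, and exception info, and _trace dispatches on trace.status through the _trace_handlers table afterwards, which lets subclasses such as WorkerTaskTrace override individual handlers without re-wiring callbacks. A minimal standalone sketch of the pattern (a hypothetical illustration with a stand-in RetryTaskError, written in Python 3 syntax for readability; not celery's actual module):

    import sys
    import traceback

    class RetryTaskError(Exception):
        """Stand-in for celery's RetryTaskError."""

    class TraceInfo:

        def __init__(self, status, retval=None, exc_info=None):
            self.status = status
            self.retval = retval
            self.strtb = None
            if exc_info is not None:
                # Keep a formatted traceback; live traceback objects
                # can't be pickled across process boundaries.
                self.strtb = "".join(traceback.format_exception(*exc_info))

        @classmethod
        def trace(cls, fun, args, kwargs):
            try:
                return cls("DONE", retval=fun(*args, **kwargs))
            except (SystemExit, KeyboardInterrupt):
                raise
            except RetryTaskError as exc:
                return cls("RETRY", retval=exc, exc_info=sys.exc_info())
            except Exception as exc:
                return cls("FAILURE", retval=exc, exc_info=sys.exc_info())

    info = TraceInfo.trace(lambda x: x * 2, (21,), {})
    assert (info.status, info.retval) == ("DONE", 42)

    info = TraceInfo.trace(lambda: 1 / 0, (), {})
    assert info.status == "FAILURE" and "ZeroDivisionError" in info.strtb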

celery/result.py (+1 −1)

@@ -322,7 +322,7 @@ class EagerResult(BaseAsyncResult):
         if self.status == "DONE":
             return self.result
         elif self.status == "FAILURE":
-            raise self.result
+            raise self.result.exception
 
     @property
     def result(self):
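
This one-line change falls out of the new handle_failure above: a failed eager run now stores the ExceptionInfo wrapper (exception plus formatted traceback) rather than the bare exception, so wait() has to unwrap it before raising. Roughly, assuming a wrapper of this shape (a simplified stand-in, not celery's actual class):

    import traceback

    class ExceptionInfo:
        """Pickleable carrier for an exception and its formatted traceback."""

        def __init__(self, exc_info):
            type_, exception, tb = exc_info
            self.exception = exception          # what wait() re-raises
            self.traceback = "".join(
                traceback.format_exception(type_, exception, tb))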

celery/worker/job.py (+13 −16)

@@ -71,9 +71,6 @@ class WorkerTaskTrace(TaskTrace):
         self.loader = kwargs.pop("loader", current_loader)
         super(WorkerTaskTrace, self).__init__(*args, **kwargs)
 
-    def __call__(self, *args, **kwargs):
-        return self.execute_safe()
-
     def execute_safe(self, *args, **kwargs):
         try:
             return self.execute(*args, **kwargs)
@@ -85,12 +82,8 @@ class WorkerTaskTrace(TaskTrace):
             return ExceptionInfo((type_, exc, tb))
 
     def execute(self):
-        # Set self.task for handlers. Can't do it in __init__, because it
-        # has to happen in the pool workers process.
-        task = self.task = tasks[self.task_name]
-
         # Run task loader init handler.
-        self.loader.on_task_init(self.task_id, task)
+        self.loader.on_task_init(self.task_id, self.task)
 
         # Backend process cleanup
         self.backend.process_cleanup()
@@ -102,7 +95,7 @@ class WorkerTaskTrace(TaskTrace):
         finally:
             timer_stat.stop()
 
-    def handle_success(self, retval):
+    def handle_success(self, retval, *args):
         """Handle successful execution.
 
         Saves the result to the current result store (skipped if the task's
@@ -111,7 +104,7 @@ class WorkerTaskTrace(TaskTrace):
         """
         if not self.task.ignore_result:
             self.backend.mark_as_done(self.task_id, retval)
-        return super(WorkerTaskTrace, self).handle_success(retval)
+        return super(WorkerTaskTrace, self).handle_success(retval, *args)
 
     def handle_retry(self, exc, type_, tb, strtb):
         """Handle retry exception."""
@@ -129,6 +122,10 @@ class WorkerTaskTrace(TaskTrace):
                 stored_exc, type_, tb, strtb)
 
 
+def execute_and_trace(*args, **kwargs):
+    return WorkerTaskTrace(*args, **kwargs).execute_safe()
+
+
 class TaskWrapper(object):
     """Class wrapping a task to be passed around and finally
     executed inside of the worker.
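
execute_and_trace is the piece that makes the pool hand-off work: pool.apply_async pickles its callable and arguments to send them across the process boundary, and a module-level function plus a tuple of plain values pickles cleanly, without dragging instance state across the way pickling the WorkerTaskTrace instance did. It also explains two removals above: WorkerTaskTrace no longer needs its own __call__ (the pool path now calls execute_safe explicitly), and the "can't do it in __init__" comment is gone because the trace object is now constructed inside the pool worker, so the task registry lookup in __init__ already happens in the right process. A tiny sketch of the pickling point (hypothetical, with the trace call stubbed out):

    import pickle

    def execute_and_trace(*args, **kwargs):
        pass  # stands in for WorkerTaskTrace(*args, **kwargs).execute_safe()

    # A module-level function pickles by reference...
    pickle.dumps(execute_and_trace)
    # ...and so does a tuple of plain task arguments.
    pickle.dumps(("tasks.add", "some-task-id", (2, 2), {}))
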
@@ -247,11 +244,10 @@ class TaskWrapper(object):
         kwargs.update(extend_with)
         return kwargs
 
-    def _tracer(self, loglevel=None, logfile=None):
+    def _get_tracer_args(self, loglevel=None, logfile=None):
         """Get the :class:`WorkerTaskTrace` tracer for this task."""
         task_func_kwargs = self.extend_with_default_kwargs(loglevel, logfile)
-        return WorkerTaskTrace(self.task_name, self.task_id,
-                               self.args, task_func_kwargs)
+        return self.task_name, self.task_id, self.args, task_func_kwargs
 
     def _set_executed_bit(self):
         """Set task as executed to make sure it's not executed again."""
@@ -275,7 +271,8 @@ class TaskWrapper(object):
         # acknowledge task as being processed.
         self.on_ack()
 
-        return self._tracer(loglevel, logfile).execute()
+        tracer = WorkerTaskTrace(*self._get_tracer_args(loglevel, logfile))
+        return tracer.execute()
 
     def execute_using_pool(self, pool, loglevel=None, logfile=None):
         """Like :meth:`execute`, but using the :mod:`multiprocessing` pool.
@@ -292,8 +289,8 @@ class TaskWrapper(object):
         # Make sure task has not already been executed.
         self._set_executed_bit()
 
-        wrapper = self._tracer(loglevel, logfile)
-        return pool.apply_async(wrapper,
+        args = self._get_tracer_args(loglevel, logfile)
+        return pool.apply_async(execute_and_trace, args=args,
                 callbacks=[self.on_success], errbacks=[self.on_failure],
                 on_ack=self.on_ack)
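
Taken together, both execution paths now derive from the same _get_tracer_args tuple and differ only in where the tracer object is built. Roughly, the control flow after this commit (names as they appear in the diff, bodies elided):

    TaskWrapper.execute()
        -> WorkerTaskTrace(*self._get_tracer_args(...)).execute()

    TaskWrapper.execute_using_pool()
        -> pool.apply_async(execute_and_trace, args=self._get_tracer_args(...))
            -> (in the pool worker) WorkerTaskTrace(*args).execute_safe()
                -> execute(): loader.on_task_init, backend.process_cleanup
                    -> task_prerun signal
                    -> TraceInfo.trace(task, args, kwargs)
                    -> _trace_handlers[status](retval, exc_type, tb, strtb)
                    -> task_postrun signal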