Browse Source

Move task tracing to celery.execute.trace

Ask Solem 15 years ago
parent
commit
f65ffc1793
3 changed files with 114 additions and 127 deletions
  1. 6 103
      celery/execute/__init__.py
  2. 94 0
      celery/execute/trace.py
  3. 14 24
      celery/worker/job.py

+ 6 - 103
celery/execute.py → celery/execute/__init__.py

@@ -1,22 +1,15 @@
-import sys
-import inspect
-import traceback
 from datetime import datetime, timedelta
 
-from billiard.utils.functional import curry
-
 from celery import conf
-from celery import signals
-from celery.utils import gen_unique_id, noop, fun_takes_kwargs, mattrgetter
+from celery.utils import gen_unique_id, fun_takes_kwargs, mattrgetter
 from celery.result import AsyncResult, EagerResult
+from celery.execute.trace import TaskTrace
 from celery.registry import tasks
 from celery.messaging import TaskPublisher, with_connection
-from celery.exceptions import RetryTaskError
-from celery.datastructures import ExceptionInfo
 
 extract_exec_options = mattrgetter("routing_key", "exchange",
-                                   "immediate", "mandatory",
-                                   "priority", "serializer")
+                                    "immediate", "mandatory",
+                                    "priority", "serializer")
 
 
 @with_connection
@@ -130,8 +123,7 @@ def apply(task, args, kwargs, **options):
     task_id = gen_unique_id()
     retries = options.get("retries", 0)
 
-    # If it's a Task class we need to instantiate it, so it's callable.
-    task = inspect.isclass(task) and task() or task
+    task = tasks[task.name] # Make sure we get the instance, not class.
 
     default_kwargs = {"task_name": task.name,
                       "task_id": task_id,
@@ -146,93 +138,4 @@ def apply(task, args, kwargs, **options):
 
     trace = TaskTrace(task.name, task_id, args, kwargs, task=task)
     retval = trace.execute()
-
-    return EagerResult(task_id, retval, trace.status,
-                       traceback=trace.strtb)
-
-
-class TraceInfo(object):
-    def __init__(self, status="PENDING", retval=None, exc_info=None):
-        self.status = status
-        self.retval = retval
-        self.exc_info = exc_info
-        self.exc_type = None
-        self.exc_value = None
-        self.tb = None
-        self.strtb = None
-        if self.exc_info:
-            self.exc_type, self.exc_value, self.tb = exc_info
-            self.strtb = "\n".join(traceback.format_exception(*exc_info))
-
-    @classmethod
-    def trace(cls, fun, args, kwargs):
-        """Trace the execution of a function, calling the appropiate callback
-        if the function raises retry, an failure or returned successfully."""
-        try:
-            return cls("SUCCESS", retval=fun(*args, **kwargs))
-        except (SystemExit, KeyboardInterrupt):
-            raise
-        except RetryTaskError, exc:
-            return cls("RETRY", retval=exc, exc_info=sys.exc_info())
-        except Exception, exc:
-            return cls("FAILURE", retval=exc, exc_info=sys.exc_info())
-
-
-class TaskTrace(object):
-
-    def __init__(self, task_name, task_id, args, kwargs, task=None):
-        self.task_id = task_id
-        self.task_name = task_name
-        self.args = args
-        self.kwargs = kwargs
-        self.task = task or tasks[self.task_name]
-        self.status = "PENDING"
-        self.strtb = None
-        self._trace_handlers = {"FAILURE": self.handle_failure,
-                                "RETRY": self.handle_retry,
-                                "SUCCESS": self.handle_success}
-
-    def __call__(self):
-        return self.execute()
-
-    def execute(self):
-        signals.task_prerun.send(sender=self.task, task_id=self.task_id,
-                                 task=self.task, args=self.args,
-                                 kwargs=self.kwargs)
-        retval = self._trace()
-
-        signals.task_postrun.send(sender=self.task, task_id=self.task_id,
-                                  task=self.task, args=self.args,
-                                  kwargs=self.kwargs, retval=retval)
-        return retval
-
-    def _trace(self):
-        trace = TraceInfo.trace(self.task, self.args, self.kwargs)
-        self.status = trace.status
-        self.strtb = trace.strtb
-        handler = self._trace_handlers[trace.status]
-        return handler(trace.retval, trace.exc_type, trace.tb, trace.strtb)
-
-    def handle_success(self, retval, *args):
-        """Handle successful execution."""
-        self.task.on_success(retval, self.task_id, self.args, self.kwargs)
-        return retval
-
-    def handle_retry(self, exc, type_, tb, strtb):
-        """Handle retry exception."""
-        self.task.on_retry(exc, self.task_id, self.args, self.kwargs)
-
-        # Create a simpler version of the RetryTaskError that stringifies
-        # the original exception instead of including the exception instance.
-        # This is for reporting the retry in logs, e-mail etc, while
-        # guaranteeing pickleability.
-        message, orig_exc = exc.args
-        expanded_msg = "%s: %s" % (message, str(orig_exc))
-        return ExceptionInfo((type_,
-                              type_(expanded_msg, None),
-                              tb))
-
-    def handle_failure(self, exc, type_, tb, strtb):
-        """Handle exception."""
-        self.task.on_failure(exc, self.task_id, self.args, self.kwargs)
-        return ExceptionInfo((type_, exc, tb))
+    return EagerResult(task_id, retval, trace.status, traceback=trace.strtb)

+ 94 - 0
celery/execute/trace.py

@@ -0,0 +1,94 @@
+import sys
+import traceback
+
+from celery import signals
+from celery.registry import tasks
+from celery.exceptions import RetryTaskError
+from celery.datastructures import ExceptionInfo
+
+
+class TraceInfo(object):
+    def __init__(self, status="PENDING", retval=None, exc_info=None):
+        self.status = status
+        self.retval = retval
+        self.exc_info = exc_info
+        self.exc_type = None
+        self.exc_value = None
+        self.tb = None
+        self.strtb = None
+        if self.exc_info:
+            self.exc_type, self.exc_value, self.tb = exc_info
+            self.strtb = "\n".join(traceback.format_exception(*exc_info))
+
+    @classmethod
+    def trace(cls, fun, args, kwargs):
+        """Trace the execution of a function, calling the appropiate callback
+        if the function raises retry, an failure or returned successfully."""
+        try:
+            return cls("SUCCESS", retval=fun(*args, **kwargs))
+        except (SystemExit, KeyboardInterrupt):
+            raise
+        except RetryTaskError, exc:
+            return cls("RETRY", retval=exc, exc_info=sys.exc_info())
+        except Exception, exc:
+            return cls("FAILURE", retval=exc, exc_info=sys.exc_info())
+
+
+class TaskTrace(object):
+
+    def __init__(self, task_name, task_id, args, kwargs, task=None):
+        self.task_id = task_id
+        self.task_name = task_name
+        self.args = args
+        self.kwargs = kwargs
+        self.task = task or tasks[self.task_name]
+        self.status = "PENDING"
+        self.strtb = None
+        self._trace_handlers = {"FAILURE": self.handle_failure,
+                                "RETRY": self.handle_retry,
+                                "SUCCESS": self.handle_success}
+
+    def __call__(self):
+        return self.execute()
+
+    def execute(self):
+        signals.task_prerun.send(sender=self.task, task_id=self.task_id,
+                                 task=self.task, args=self.args,
+                                 kwargs=self.kwargs)
+        retval = self._trace()
+
+        signals.task_postrun.send(sender=self.task, task_id=self.task_id,
+                                  task=self.task, args=self.args,
+                                  kwargs=self.kwargs, retval=retval)
+        return retval
+
+    def _trace(self):
+        trace = TraceInfo.trace(self.task, self.args, self.kwargs)
+        self.status = trace.status
+        self.strtb = trace.strtb
+        handler = self._trace_handlers[trace.status]
+        return handler(trace.retval, trace.exc_type, trace.tb, trace.strtb)
+
+    def handle_success(self, retval, *args):
+        """Handle successful execution."""
+        self.task.on_success(retval, self.task_id, self.args, self.kwargs)
+        return retval
+
+    def handle_retry(self, exc, type_, tb, strtb):
+        """Handle retry exception."""
+        self.task.on_retry(exc, self.task_id, self.args, self.kwargs)
+
+        # Create a simpler version of the RetryTaskError that stringifies
+        # the original exception instead of including the exception instance.
+        # This is for reporting the retry in logs, e-mail etc, while
+        # guaranteeing pickleability.
+        message, orig_exc = exc.args
+        expanded_msg = "%s: %s" % (message, str(orig_exc))
+        return ExceptionInfo((type_,
+                              type_(expanded_msg, None),
+                              tb))
+
+    def handle_failure(self, exc, type_, tb, strtb):
+        """Handle exception."""
+        self.task.on_failure(exc, self.task_id, self.args, self.kwargs)
+        return ExceptionInfo((type_, exc, tb))

+ 14 - 24
celery/worker/job.py

@@ -15,7 +15,7 @@ from celery import platform
 from celery.log import get_default_logger
 from celery.utils import noop, fun_takes_kwargs
 from celery.loaders import current_loader
-from celery.execute import TaskTrace
+from celery.execute.trace import TaskTrace
 from celery.registry import tasks
 from celery.datastructures import ExceptionInfo
 
@@ -75,56 +75,46 @@ class WorkerTaskTrace(TaskTrace):
         self._store_errors = True
         if self.task.ignore_result:
             self._store_errors = conf.STORE_ERRORS_EVEN_IF_IGNORED
+        self.super = super(WorkerTaskTrace, self)
 
     def execute_safe(self, *args, **kwargs):
+        """Same as :meth:`execute`, but catches errors."""
         try:
             return self.execute(*args, **kwargs)
         except Exception, exc:
-            type_, value_, tb = sys.exc_info()
-            exc = self.task.backend.prepare_exception(exc)
-            warnings.warn("Exception happend outside of task body: %s: %s" % (
-                str(exc.__class__), str(exc)))
-            return ExceptionInfo((type_, exc, tb))
+            exc_info = sys.exc_info()
+            exc_info[1] = self.task_backend.prepare_exception(exc)
+            exc_info = ExceptionInfo(exc_info)
+            warnings.warn("Exception outside body: %s: %s\n%s" % tuple(
+                map(str, (exc.__class__, exc, exc_info.traceback))))
+            return exc_info
 
     def execute(self):
-        # Run task loader init handler.
+        """Execute, trace and store the result of the task."""
         self.loader.on_task_init(self.task_id, self.task)
-
-        # Backend process cleanup
         self.task.backend.process_cleanup()
-
         return self._trace()
 
     def handle_success(self, retval, *args):
-        """Handle successful execution.
-
-        Saves the result to the current result store (skipped if the task's
-            ``ignore_result`` attribute is set to ``True``).
-
-        """
+        """Handle successful execution."""
         if not self.task.ignore_result:
             self.task.backend.mark_as_done(self.task_id, retval)
-        return super(WorkerTaskTrace, self).handle_success(retval, *args)
+        return self.super.handle_success(retval, *args)
 
     def handle_retry(self, exc, type_, tb, strtb):
         """Handle retry exception."""
         message, orig_exc = exc.args
         if self._store_errors:
             self.task.backend.mark_as_retry(self.task_id, orig_exc, strtb)
-        return super(WorkerTaskTrace, self).handle_retry(exc, type_,
-                                                         tb, strtb)
+        return self.super.handle_retry(exc, type_, tb, strtb)
 
     def handle_failure(self, exc, type_, tb, strtb):
         """Handle exception."""
         if self._store_errors:
-            # mark_as_failure returns an exception that is guaranteed to
-            # be pickleable.
             exc = self.task.backend.mark_as_failure(self.task_id, exc, strtb)
         else:
             exc = self.task.backend.prepare_exception(exc)
-
-        return super(WorkerTaskTrace, self).handle_failure(
-                exc, type_, tb, strtb)
+        return self.super.handle_failure(exc, type_, tb, strtb)
 
 
 def execute_and_trace(task_name, *args, **kwargs):