Browse Source

Task exceptions now has only one extra frame in the stacktrace (before there were dozens)

Ask Solem 13 years ago
parent
commit
ba9d6c839d

+ 6 - 15
celery/app/__init__.py

@@ -170,23 +170,14 @@ class App(base.BaseApp):
         def inner_create_task_cls(**options):
 
             def _create_task_cls(fun):
-                options["app"] = self
-                options.setdefault("accept_magic_kwargs", False)
                 base = options.pop("base", None) or self.Task
 
-                @wraps(fun, assigned=("__module__", "__name__"))
-                def run(self, *args, **kwargs):
-                    return fun(*args, **kwargs)
-
-                # Save the argspec for this task so we can recognize
-                # which default task kwargs we're going to pass to it later.
-                # (this happens in celery.utils.fun_takes_kwargs)
-                run.argspec = getargspec(fun)
-
-                cls_dict = dict(options, run=run,
-                                __module__=fun.__module__,
-                                __doc__=fun.__doc__)
-                T = type(fun.__name__, (base, ), cls_dict)()
+                T = type(fun.__name__, (base, ), dict({
+                        "app": self,
+                        "accept_magic_kwargs": False,
+                        "run": staticmethod(fun),
+                        "__doc__": fun.__doc__,
+                        "__module__": fun.__module__}, **options))()
                 return registry.tasks[T.name]             # global instance.
 
             return _create_task_cls

+ 27 - 9
celery/app/task/__init__.py

@@ -73,11 +73,15 @@ class TaskType(type):
         new = super(TaskType, cls).__new__
         task_module = attrs.get("__module__") or "__main__"
 
-        # Abstract class: abstract attribute should not be inherited.
+        if "__call__" in attrs:
+            # see note about __call__ below.
+            attrs["__defines_call__"] = True
+
+        # - Abstract class: abstract attribute should not be inherited.
         if attrs.pop("abstract", None) or not attrs.get("autoregister", True):
             return new(cls, name, bases, attrs)
 
-        # Automatically generate missing/empty name.
+        # - Automatically generate missing/empty name.
         autoname = False
         if not attrs.get("name"):
             try:
@@ -88,6 +92,22 @@ class TaskType(type):
             attrs["name"] = '.'.join([module_name, name])
             autoname = True
 
+        # - Automatically generate __call__.
+        # If this or none of its bases define __call__, we simply
+        # alias it to the ``run`` method, as
+        # this means we can skip a stacktrace frame :)
+        if not (attrs.get("__call__")
+                or any(getattr(b, "__defines_call__", False) for b in bases)):
+            try:
+                attrs["__call__"] = attrs["run"]
+            except KeyError:
+                # the class does not yet define run,
+                # so we can't optimize this case.
+                def __call__(self, *args, **kwargs):
+                    return self.run(*args, **kwargs)
+                attrs["__call__"] = __call__
+
+        # - Create and register class.
         # Because of the way import happens (recursively)
         # we may or may not be the first time the task tries to register
         # with the framework.  There should only be one class for each task
@@ -249,9 +269,6 @@ class BaseTask(object):
     #: Execution strategy used, or the qualified name of one.
     Strategy = "celery.worker.strategy:default"
 
-    def __call__(self, *args, **kwargs):
-        return self.run(*args, **kwargs)
-
     def __reduce__(self):
         return (_unpickle_task, (self.name, ), None)
 
@@ -597,12 +614,13 @@ class BaseTask(object):
                                         if key in supported_keys)
             kwargs.update(extend_with)
 
-        trace = TaskTrace(task.name, task_id, args, kwargs,
-                          task=task, request=request, propagate=throw)
-        retval = trace.execute()
+        trace = TaskTrace(task.name, task_id, args, kwargs, eager=True,
+                          task=task, request=request, propagate=throw,
+                          propagate_internal=True)
+        retval = trace()
         if isinstance(retval, ExceptionInfo):
             retval = retval.exception
-        return EagerResult(task_id, retval, trace.status,
+        return EagerResult(task_id, retval, trace.state,
                            traceback=trace.strtb)
 
     @classmethod

+ 1 - 1
celery/execute/trace.py

@@ -150,7 +150,7 @@ class TaskTrace(object):
                             args=args, kwargs=kwargs)
                 loader.on_task_init(uuid, task)
                 if not eager and (task.track_started and not ignore_result):
-                    backend.mark_as_started(id, pid=os.getpid(),
+                    backend.mark_as_started(uuid, pid=os.getpid(),
                                             hostname=self.hostname)
 
                 # -*- TRACE -*-

+ 15 - 12
celery/tests/test_worker/test_worker_job.py

@@ -7,6 +7,7 @@ import logging
 import os
 import sys
 import time
+import warnings
 
 from datetime import datetime, timedelta
 
@@ -459,22 +460,24 @@ class test_TaskRequest(unittest.TestCase):
         self.assertEqual(res, 4 ** 4)
 
     def test_execute_safe_catches_exception(self):
-        old_exec = TaskTrace.__call__
+        old_exec = mytask.__call__
+        warnings.resetwarnings()
 
         def _error_exec(self, *args, **kwargs):
             raise KeyError("baz")
 
-        TaskTrace.__call__ = _error_exec
-        try:
-            with catch_warnings(record=True) as log:
-                res = execute_and_trace(mytask.name, uuid(),
-                                        [4], {})
-                self.assertIsInstance(res, ExceptionInfo)
-                self.assertTrue(log)
-                self.assertIn("Exception outside", log[0].message.args[0])
-                self.assertIn("KeyError", log[0].message.args[0])
-        finally:
-            TaskTrace.__call__ = old_exec
+        @task_dec
+        def raising():
+            raise KeyError("baz")
+        raising.request = None
+
+        with catch_warnings(record=True) as log:
+            res = execute_and_trace(raising.name, uuid(),
+                                    [], {})
+            self.assertIsInstance(res, ExceptionInfo)
+            self.assertTrue(log)
+            self.assertIn("Exception outside", log[0].message.args[0])
+            self.assertIn("AttributeError", log[0].message.args[0])
 
     def create_exception(self, exc):
         try: