فهرست منبع

Task exceptions now has only one extra frame in the stacktrace (before there were dozens)

Ask Solem 13 سال پیش
والد
کامیت
ba9d6c839d
4فایلهای تغییر یافته به همراه49 افزوده شده و 37 حذف شده
  1. 6 15
      celery/app/__init__.py
  2. 27 9
      celery/app/task/__init__.py
  3. 1 1
      celery/execute/trace.py
  4. 15 12
      celery/tests/test_worker/test_worker_job.py

+ 6 - 15
celery/app/__init__.py

@@ -170,23 +170,14 @@ class App(base.BaseApp):
         def inner_create_task_cls(**options):
 
             def _create_task_cls(fun):
-                options["app"] = self
-                options.setdefault("accept_magic_kwargs", False)
                 base = options.pop("base", None) or self.Task
 
-                @wraps(fun, assigned=("__module__", "__name__"))
-                def run(self, *args, **kwargs):
-                    return fun(*args, **kwargs)
-
-                # Save the argspec for this task so we can recognize
-                # which default task kwargs we're going to pass to it later.
-                # (this happens in celery.utils.fun_takes_kwargs)
-                run.argspec = getargspec(fun)
-
-                cls_dict = dict(options, run=run,
-                                __module__=fun.__module__,
-                                __doc__=fun.__doc__)
-                T = type(fun.__name__, (base, ), cls_dict)()
+                T = type(fun.__name__, (base, ), dict({
+                        "app": self,
+                        "accept_magic_kwargs": False,
+                        "run": staticmethod(fun),
+                        "__doc__": fun.__doc__,
+                        "__module__": fun.__module__}, **options))()
                 return registry.tasks[T.name]             # global instance.
 
             return _create_task_cls

+ 27 - 9
celery/app/task/__init__.py

@@ -73,11 +73,15 @@ class TaskType(type):
         new = super(TaskType, cls).__new__
         task_module = attrs.get("__module__") or "__main__"
 
-        # Abstract class: abstract attribute should not be inherited.
+        if "__call__" in attrs:
+            # see note about __call__ below.
+            attrs["__defines_call__"] = True
+
+        # - Abstract class: abstract attribute should not be inherited.
         if attrs.pop("abstract", None) or not attrs.get("autoregister", True):
             return new(cls, name, bases, attrs)
 
-        # Automatically generate missing/empty name.
+        # - Automatically generate missing/empty name.
         autoname = False
         if not attrs.get("name"):
             try:
@@ -88,6 +92,22 @@ class TaskType(type):
             attrs["name"] = '.'.join([module_name, name])
             autoname = True
 
+        # - Automatically generate __call__.
+        # If this or none of its bases define __call__, we simply
+        # alias it to the ``run`` method, as
+        # this means we can skip a stacktrace frame :)
+        if not (attrs.get("__call__")
+                or any(getattr(b, "__defines_call__", False) for b in bases)):
+            try:
+                attrs["__call__"] = attrs["run"]
+            except KeyError:
+                # the class does not yet define run,
+                # so we can't optimize this case.
+                def __call__(self, *args, **kwargs):
+                    return self.run(*args, **kwargs)
+                attrs["__call__"] = __call__
+
+        # - Create and register class.
         # Because of the way import happens (recursively)
         # we may or may not be the first time the task tries to register
         # with the framework.  There should only be one class for each task
@@ -249,9 +269,6 @@ class BaseTask(object):
     #: Execution strategy used, or the qualified name of one.
     Strategy = "celery.worker.strategy:default"
 
-    def __call__(self, *args, **kwargs):
-        return self.run(*args, **kwargs)
-
     def __reduce__(self):
         return (_unpickle_task, (self.name, ), None)
 
@@ -597,12 +614,13 @@ class BaseTask(object):
                                         if key in supported_keys)
             kwargs.update(extend_with)
 
-        trace = TaskTrace(task.name, task_id, args, kwargs,
-                          task=task, request=request, propagate=throw)
-        retval = trace.execute()
+        trace = TaskTrace(task.name, task_id, args, kwargs, eager=True,
+                          task=task, request=request, propagate=throw,
+                          propagate_internal=True)
+        retval = trace()
         if isinstance(retval, ExceptionInfo):
             retval = retval.exception
-        return EagerResult(task_id, retval, trace.status,
+        return EagerResult(task_id, retval, trace.state,
                            traceback=trace.strtb)
 
     @classmethod

+ 1 - 1
celery/execute/trace.py

@@ -150,7 +150,7 @@ class TaskTrace(object):
                             args=args, kwargs=kwargs)
                 loader.on_task_init(uuid, task)
                 if not eager and (task.track_started and not ignore_result):
-                    backend.mark_as_started(id, pid=os.getpid(),
+                    backend.mark_as_started(uuid, pid=os.getpid(),
                                             hostname=self.hostname)
 
                 # -*- TRACE -*-

+ 15 - 12
celery/tests/test_worker/test_worker_job.py

@@ -7,6 +7,7 @@ import logging
 import os
 import sys
 import time
+import warnings
 
 from datetime import datetime, timedelta
 
@@ -459,22 +460,24 @@ class test_TaskRequest(unittest.TestCase):
         self.assertEqual(res, 4 ** 4)
 
     def test_execute_safe_catches_exception(self):
-        old_exec = TaskTrace.__call__
+        old_exec = mytask.__call__
+        warnings.resetwarnings()
 
         def _error_exec(self, *args, **kwargs):
             raise KeyError("baz")
 
-        TaskTrace.__call__ = _error_exec
-        try:
-            with catch_warnings(record=True) as log:
-                res = execute_and_trace(mytask.name, uuid(),
-                                        [4], {})
-                self.assertIsInstance(res, ExceptionInfo)
-                self.assertTrue(log)
-                self.assertIn("Exception outside", log[0].message.args[0])
-                self.assertIn("KeyError", log[0].message.args[0])
-        finally:
-            TaskTrace.__call__ = old_exec
+        @task_dec
+        def raising():
+            raise KeyError("baz")
+        raising.request = None
+
+        with catch_warnings(record=True) as log:
+            res = execute_and_trace(raising.name, uuid(),
+                                    [], {})
+            self.assertIsInstance(res, ExceptionInfo)
+            self.assertTrue(log)
+            self.assertIn("Exception outside", log[0].message.args[0])
+            self.assertIn("AttributeError", log[0].message.args[0])
 
     def create_exception(self, exc):
         try: