Browse Source

Better exception and traceback handling.

Ask Solem 13 years ago
parent
commit
9530098d0f

+ 4 - 2
celery/app/task/__init__.py

@@ -21,7 +21,8 @@ from ...exceptions import MaxRetriesExceededError, RetryTaskError
 from ...execute.trace import eager_trace_task
 from ...registry import tasks, _unpickle_task
 from ...result import EagerResult
-from ...utils import fun_takes_kwargs, instantiate, mattrgetter, uuid
+from ...utils import (fun_takes_kwargs, instantiate,
+                      mattrgetter, uuid, maybe_reraise)
 from ...utils.mail import ErrorMail
 
 extract_exec_options = mattrgetter("queue", "routing_key",
@@ -535,6 +536,7 @@ class BaseTask(object):
         # Not in worker or emulated by (apply/always_eager),
         # so just raise the original exception.
         if request.called_directly:
+            maybe_reraise()
             raise exc or RetryTaskError("Task can be retried", None)
 
         if delivery_info:
@@ -551,7 +553,7 @@ class BaseTask(object):
 
         if max_retries is not None and options["retries"] > max_retries:
             if exc:
-                raise
+                maybe_reraise()
             raise self.MaxRetriesExceededError(
                     "Can't retry %s[%s] args:%s kwargs:%s" % (
                         self.name, options["task_id"], args, kwargs))

+ 8 - 2
celery/backends/__init__.py

@@ -1,11 +1,17 @@
 # -*- coding: utf-8 -*-
 from __future__ import absolute_import
 
+import sys
+
 from .. import current_app
 from ..local import Proxy
 from ..utils import get_cls_by_name
 from ..utils.functional import memoize
 
+UNKNOWN_BACKEND = """\
+Unknown result backend: %r.  Did you spell that correctly? (%r)\
+"""
+
 BACKEND_ALIASES = {
     "amqp": "celery.backends.amqp:AMQPBackend",
     "cache": "celery.backends.cache:CacheBackend",
@@ -27,8 +33,8 @@ def get_backend_cls(backend=None, loader=None):
     try:
         return get_cls_by_name(backend, aliases)
     except ValueError, exc:
-        raise ValueError("Unknown result backend: %r.  "
-                         "Did you spell it correctly?  (%s)" % (backend, exc))
+        raise ValueError, ValueError(UNKNOWN_BACKEND % (
+                    backend, exc)), sys.exc_info()[2]
 
 
 # deprecate this.

+ 9 - 5
celery/execute/trace.py

@@ -217,8 +217,12 @@ def eager_trace_task(task, uuid, args, kwargs, request=None, **opts):
 
 def report_internal_error(task, exc):
     _type, _value, _tb = sys.exc_info()
-    _value = task.backend.prepare_exception(exc)
-    exc_info = ExceptionInfo((_type, _value, _tb))
-    warn(RuntimeWarning(
-        "Exception raised outside body: %r:\n%s" % (exc, exc_info.traceback)))
-    return exc_info
+    try:
+        _value = task.backend.prepare_exception(exc)
+        exc_info = ExceptionInfo((_type, _value, _tb))
+        warn(RuntimeWarning(
+            "Exception raised outside body: %r:\n%s" % (
+                exc, exc_info.traceback)))
+        return exc_info
+    finally:
+        del(_tb)

+ 4 - 2
celery/loaders/base.py

@@ -14,6 +14,7 @@ from __future__ import absolute_import
 import importlib
 import os
 import re
+import traceback
 import warnings
 
 from anyjson import deserialize
@@ -194,8 +195,9 @@ class BaseLoader(object):
             if not fail_silently:
                 raise
             warnings.warn(self.mail.SendmailWarning(
-                "Mail could not be sent: %r %r" % (
-                    exc, {"To": to, "Subject": subject})))
+                "Mail could not be sent: %r %r\n%r" % (
+                    exc, {"To": to, "Subject": subject},
+                    traceback.format_stack())))
 
     @property
     def conf(self):

+ 7 - 3
celery/loaders/default.py

@@ -50,9 +50,13 @@ class Loader(BaseLoader):
             self.find_module(configname)
         except NotAPackage:
             if configname.endswith('.py'):
-                raise NotAPackage(CONFIG_WITH_SUFFIX % {
-                    "module": configname, "suggest": configname[:-3]})
-                raise NotAPackage(CONFIG_INVALID_NAME % {"module": configname})
+                raise NotAPackage, NotAPackage(
+                        CONFIG_WITH_SUFFIX % {
+                            "module": configname,
+                            "suggest": configname[:-3]}), sys.exc_info()[2]
+            raise NotAPackage, NotAPackage(
+                    CONFIG_INVALID_NAME % {
+                        "module": configname}), sys.exc_info()[2]
         except ImportError:
             warnings.warn(NotConfigured(
                 "No %r module found! Please make sure it exists and "

+ 1 - 1
celery/platforms.py

@@ -102,7 +102,7 @@ class PIDFile(object):
         try:
             self.write_pid()
         except OSError, exc:
-            raise LockFailed(str(exc))
+            raise LockFailed, LockFailed(str(exc)), sys.exc_info()[2]
         return self
     __enter__ = acquire
 

+ 6 - 3
celery/security/certificate.py

@@ -1,8 +1,9 @@
 from __future__ import absolute_import
 from __future__ import with_statement
 
-import os
 import glob
+import os
+import sys
 
 try:
     from OpenSSL import crypto
@@ -20,7 +21,8 @@ class Certificate(object):
         try:
             self._cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert)
         except crypto.Error, exc:
-            raise SecurityError("Invalid certificate: %r" % (exc, ))
+            raise SecurityError, SecurityError(
+                    "Invalid certificate: %r" % (exc, )), sys.exc_info()[2]
 
     def has_expired(self):
         """Check if the certificate has expired."""
@@ -44,7 +46,8 @@ class Certificate(object):
         try:
             crypto.verify(self._cert, signature, data, digest)
         except crypto.Error, exc:
-            raise SecurityError("Bad signature: %r" % (exc, ))
+            raise SecurityError, SecurityError(
+                    "Bad signature: %r" % (exc, )), sys.exc_info()[2]
 
 
 class CertStore(object):

+ 6 - 2
celery/security/key.py

@@ -1,5 +1,7 @@
 from __future__ import absolute_import
 
+import sys
+
 try:
     from OpenSSL import crypto
 except ImportError:
@@ -15,11 +17,13 @@ class PrivateKey(object):
         try:
             self._key = crypto.load_privatekey(crypto.FILETYPE_PEM, key)
         except crypto.Error, exc:
-            raise SecurityError("Invalid private key: %r" % (exc, ))
+            raise SecurityError, SecurityError(
+                    "Invalid private key: %r" % (exc, )), sys.exc_info()[2]
 
     def sign(self, data, digest):
         """sign string containing data."""
         try:
             return crypto.sign(self._key, data, digest)
         except crypto.Error, exc:
-            raise SecurityError("Unable to sign data: %r" % (exc, ))
+            raise SecurityError, SecurityError(
+                    "Unable to sign data: %r" % (exc, )), sys.exc_info()[2]

+ 5 - 2
celery/security/serialization.py

@@ -1,6 +1,7 @@
 from __future__ import absolute_import
 
 import base64
+import sys
 
 from kombu.serialization import registry, encode, decode
 
@@ -44,7 +45,8 @@ class SecureSerializer(object):
                               signature=self._key.sign(body, self._digest),
                               signer=self._cert.get_id())
         except Exception, exc:
-            raise SecurityError("Unable to serialize: %r" % (exc, ))
+            raise SecurityError, SecurityError(
+                    "Unable to serialize: %r" % (exc, )), sys.exc_info()[2]
 
     def deserialize(self, data):
         """deserialize data structure from string"""
@@ -57,7 +59,8 @@ class SecureSerializer(object):
             self._cert_store[signer].verify(body,
                                             signature, self._digest)
         except Exception, exc:
-            raise SecurityError("Unable to deserialize: %r" % (exc, ))
+            raise SecurityError, SecurityError(
+                    "Unable to deserialize: %r" % (exc, )), sys.exc_info()[2]
 
         return decode(body, payload["content_type"],
                             payload["content_encoding"], force=True)

+ 2 - 1
celery/task/http.py

@@ -70,7 +70,8 @@ def extract_response(raw_response):
     try:
         payload = deserialize(raw_response)
     except ValueError, exc:
-        raise InvalidResponseError(str(exc))
+        raise InvalidResponseError, InvalidResponseError(
+                str(exc)), sys.exc_info()[2]
 
     status = payload["status"]
     if status == "success":

+ 22 - 11
celery/tests/test_worker/test_worker_job.py

@@ -290,29 +290,35 @@ class test_TaskRequest(Case):
         def mock_mail_admins(*args, **kwargs):
             mail_sent[0] = True
 
+        def get_ei():
+            try:
+                raise KeyError("moofoobar")
+            except:
+                return ExceptionInfo(sys.exc_info())
+
         app.mail_admins = mock_mail_admins
         mytask.send_error_emails = True
         try:
             tw = TaskRequest(mytask.name, uuid(), [1], {"f": "x"})
-            try:
-                raise KeyError("moofoobar")
-            except:
-                einfo = ExceptionInfo(sys.exc_info())
 
+            einfo = get_ei()
             tw.on_failure(einfo)
             self.assertTrue(mail_sent[0])
 
+            einfo = get_ei()
             mail_sent[0] = False
             mytask.send_error_emails = False
             tw.on_failure(einfo)
             self.assertFalse(mail_sent[0])
 
+            einfo = get_ei()
             mail_sent[0] = False
             mytask.send_error_emails = True
             mytask.error_whitelist = [KeyError]
             tw.on_failure(einfo)
             self.assertTrue(mail_sent[0])
 
+            einfo = get_ei()
             mail_sent[0] = False
             mytask.send_error_emails = True
             mytask.error_whitelist = [SyntaxError]
@@ -394,17 +400,22 @@ class test_TaskRequest(Case):
             mytask.acks_late = False
 
     def test_on_failure_WorkerLostError(self):
+
+        def get_ei():
+            try:
+                raise WorkerLostError("do re mi")
+            except WorkerLostError:
+                return ExceptionInfo(sys.exc_info())
+
         tw = TaskRequest(mytask.name, uuid(), [1], {"f": "x"})
-        try:
-            raise WorkerLostError("do re mi")
-        except WorkerLostError:
-            exc_info = ExceptionInfo(sys.exc_info())
-            tw.on_failure(exc_info)
-            self.assertEqual(mytask.backend.get_status(tw.task_id),
-                             states.FAILURE)
+        exc_info = get_ei()
+        tw.on_failure(exc_info)
+        self.assertEqual(mytask.backend.get_status(tw.task_id),
+                         states.FAILURE)
 
         mytask.ignore_result = True
         try:
+            exc_info = get_ei()
             tw = TaskRequest(mytask.name, uuid(), [1], {"f": "x"})
             tw.on_failure(exc_info)
             self.assertEqual(mytask.backend.get_status(tw.task_id),

+ 15 - 1
celery/utils/__init__.py

@@ -288,7 +288,8 @@ def get_cls_by_name(name, aliases={}, imp=None, package=None,
     try:
         module = imp(module_name, package=package, **kwargs)
     except ValueError, exc:
-        raise ValueError("Couldn't import %r: %s" % (name, exc))
+        raise ValueError, ValueError(
+                "Couldn't import %r: %s" % (name, exc)), sys.exc_info()[2]
     return getattr(module, cls_name)
 
 get_symbol_by_name = get_cls_by_name
@@ -431,3 +432,16 @@ def uniq(it):
         if obj not in seen:
             yield obj
             seen.add(obj)
+
+
+
+def maybe_reraise():
+    """Reraise if an exception is currently being handled, or return
+    otherwise."""
+    type_, exc, tb = sys.exc_info()
+    try:
+        if tb:
+            raise type_, exc, tb
+    finally:
+        # see http://docs.python.org/library/sys.html#sys.exc_info
+        del(tb)

+ 2 - 3
celery/utils/timer2.py

@@ -189,10 +189,9 @@ class Timer(Thread):
         try:
             entry()
         except Exception, exc:
-            typ, val, tb = einfo = sys.exc_info()
-            if not self.schedule.handle_error(einfo):
+            if not self.schedule.handle_error(sys.exc_info()):
                 warnings.warn(TimedFunctionFailed(repr(exc))),
-                traceback.print_exception(typ, val, tb)
+                traceback.print_stack()
 
     def _next_entry(self):
         with self.not_empty:

+ 2 - 3
celery/worker/__init__.py

@@ -293,9 +293,8 @@ class WorkController(configurated):
         self._state = self.TERMINATE
         self._shutdown_complete.set()
 
-    def on_timer_error(self, exc_info):
-        _, exc, _ = exc_info
-        self.logger.error("Timer error: %r", exc, exc_info=exc_info)
+    def on_timer_error(self, einfo):
+        self.logger.error("Timer error: %r", einfo[1], exc_info=einfo)
 
     def on_timer_tick(self, delay):
         self.timer_debug("Scheduler wake-up! Next eta %s secs." % delay)