Parcourir la source

Properly clean up after sys.exc_info

Ask Solem il y a 13 ans
Parent
commit
c57ef46373

+ 37 - 42
celery/task/trace.py

@@ -21,7 +21,6 @@ from __future__ import absolute_import
 import os
 import socket
 import sys
-import traceback
 
 from warnings import warn
 
@@ -71,17 +70,11 @@ def defines_custom_call(task):
 
 
 class TraceInfo(object):
-    __slots__ = ("state", "retval", "exc_info",
-                 "exc_type", "exc_value", "tb", "strtb")
+    __slots__ = ("state", "retval", "tb")
 
-    def __init__(self, state, retval=None, exc_info=None):
+    def __init__(self, state, retval=None):
         self.state = state
         self.retval = retval
-        self.exc_info = exc_info
-        if exc_info:
-            self.exc_type, self.exc_value, self.tb = exc_info
-        else:
-            self.exc_type = self.exc_value = self.tb = None
 
     def handle_error_state(self, task, eager=False):
         store_errors = not eager
@@ -100,35 +93,37 @@ class TraceInfo(object):
         # This is for reporting the retry in logs, email etc, while
         # guaranteeing pickleability.
         req = task.request
-        exc, type_, tb = self.retval, self.exc_type, self.tb
-        message, orig_exc = self.retval.args
-        if store_errors:
-            task.backend.mark_as_retry(req.id, orig_exc, self.strtb)
-        expanded_msg = "%s: %s" % (message, str(orig_exc))
-        einfo = ExceptionInfo((type_, type_(expanded_msg, None), tb))
-        task.on_retry(exc, req.id, req.args, req.kwargs, einfo)
-        return einfo
+        type_, _, tb = sys.exc_info()
+        try:
+            exc = self.retval
+            message, orig_exc = exc.args
+            expanded_msg = "%s: %s" % (message, str(orig_exc))
+            einfo = ExceptionInfo((type_, type_(expanded_msg, None), tb))
+            if store_errors:
+                task.backend.mark_as_retry(req.id, orig_exc, einfo.traceback)
+            task.on_retry(exc, req.id, req.args, req.kwargs, einfo)
+            return einfo
+        finally:
+            del(tb)
 
     def handle_failure(self, task, store_errors=True):
         """Handle exception."""
         req = task.request
-        exc, type_, tb = self.retval, self.exc_type, self.tb
-        if store_errors:
-            task.backend.mark_as_failure(req.id, exc, self.strtb)
-        exc = get_pickleable_exception(exc)
-        einfo = ExceptionInfo((type_, exc, tb))
-        task.on_failure(exc, req.id, req.args, req.kwargs, einfo)
-        signals.task_failure.send(sender=task, task_id=req.id,
-                                  exception=exc, args=req.args,
-                                  kwargs=req.kwargs, traceback=tb,
-                                  einfo=einfo)
-        return einfo
-
-    @property
-    def strtb(self):
-        if self.exc_info:
-            return '\n'.join(traceback.format_exception(*self.exc_info))
-        return ''
+        _, type_, tb = sys.exc_info()
+        try:
+            exc = self.retval
+            einfo = ExceptionInfo((type_, get_pickleable_exception(exc), tb))
+            if store_errors:
+                task.backend.mark_as_failure(req.id, exc, einfo.traceback)
+            task.on_failure(exc, req.id, req.args, req.kwargs, einfo)
+            signals.task_failure.send(sender=task, task_id=req.id,
+                                      exception=exc, args=req.args,
+                                      kwargs=req.kwargs,
+                                      traceback=einfo.traceback,
+                                      einfo=einfo)
+            return einfo
+        finally:
+            del(tb)
 
 
 def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
@@ -184,16 +179,16 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
                 # -*- TRACE -*-
                 try:
                     R = retval = fun(*args, **kwargs)
-                    state, einfo = SUCCESS, None
+                    state = SUCCESS
                 except RetryTaskError, exc:
-                    I = Info(RETRY, exc, sys.exc_info())
-                    state, retval, einfo = I.state, I.retval, I.exc_info
+                    I = Info(RETRY, exc)
+                    state, retval = I.state, I.retval
                     R = I.handle_error_state(task, eager=eager)
                 except Exception, exc:
                     if propagate:
                         raise
-                    I = Info(FAILURE, exc, sys.exc_info())
-                    state, retval, einfo = I.state, I.retval, I.exc_info
+                    I = Info(FAILURE, exc)
+                    state, retval = I.state, I.retval
                     R = I.handle_error_state(task, eager=eager)
                     [subtask(errback).apply_async((uuid, ))
                         for errback in task_request.errbacks or []]
@@ -204,8 +199,8 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
                     # (but deprecated)
                     if propagate:
                         raise
-                    I = Info(FAILURE, None, sys.exc_info())
-                    state, retval, einfo = I.state, I.retval, I.exc_info
+                    I = Info(FAILURE, None)
+                    state, retval = I.state, I.retval
                     R = I.handle_error_state(task, eager=eager)
                     [subtask(errback).apply_async((uuid, ))
                         for errback in task_request.errbacks or []]
@@ -221,7 +216,7 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
                 # -* POST *-
                 if task_request.chord:
                     on_chord_part_return(task)
-                task_after_return(state, retval, uuid, args, kwargs, einfo)
+                task_after_return(state, retval, uuid, args, kwargs, None)
                 send_postrun(sender=task, task_id=uuid, task=task,
                             args=args, kwargs=kwargs, retval=retval)
             finally:

+ 3 - 2
celery/tests/backends/test_amqp.py

@@ -66,7 +66,7 @@ class test_AMQPBackend(Case):
         try:
             raise KeyError("foo")
         except KeyError, exception:
-            einfo = ExceptionInfo(sys.exc_info())
+            einfo = ExceptionInfo()
             tb1.mark_as_failure(tid3, exception, traceback=einfo.traceback)
             self.assertEqual(tb2.get_status(tid3), states.FAILURE)
             self.assertIsInstance(tb2.get_result(tid3), KeyError)
@@ -235,7 +235,8 @@ class test_AMQPBackend(Case):
         expected_results = [(tid, {"status": states.SUCCESS,
                                     "result": i,
                                     "traceback": None,
-                                    "task_id": tid})
+                                    "task_id": tid,
+                                    "children": None})
                                 for i, tid in enumerate(tids)]
         self.assertEqual(sorted(res), sorted(expected_results))
         self.assertDictEqual(b._cache[res[0][0]], res[0][1])

+ 2 - 1
celery/tests/backends/test_cassandra.py

@@ -75,7 +75,8 @@ class test_CassandraBackend(AppCase):
                                 "status": states.SUCCESS,
                                 "result": "1",
                                 "date_done": "date",
-                                "traceback": ""}
+                                "traceback": "",
+                                "children": None}
             x.decode = Mock()
             x.detailed_mode = False
             meta = x._get_task_meta_for("task_id")

+ 2 - 1
celery/tests/backends/test_mongodb.py

@@ -195,7 +195,8 @@ class test_MongoBackend(AppCase):
         mock_get_database.assert_called_once_with()
         mock_database.__getitem__.assert_called_once_with(MONGODB_COLLECTION)
         self.assertEquals(
-            ['status', 'date_done', 'traceback', 'result', 'task_id'],
+            ['status', 'task_id', 'date_done', 'traceback', 'result',
+             'children'],
             ret_val.keys())
 
     @patch("celery.backends.mongodb.MongoBackend._get_database")

+ 1 - 1
celery/tests/concurrency/test_pool.py

@@ -22,7 +22,7 @@ def raise_something(i):
     try:
         raise KeyError("FOO EXCEPTION")
     except KeyError:
-        return ExceptionInfo(sys.exc_info())
+        return ExceptionInfo()
 
 
 class test_TaskPool(Case):

+ 0 - 7
celery/tests/concurrency/test_processes.py

@@ -46,13 +46,6 @@ class Object(object):   # for writeable attributes.
         [setattr(self, k, v) for k, v in kwargs.items()]
 
 
-def to_excinfo(exc):
-    try:
-        raise exc
-    except:
-        return ExceptionInfo(sys.exc_info())
-
-
 class MockResult(object):
 
     def __init__(self, value, pid):

+ 10 - 12
celery/tests/utilities/test_datastructures.py

@@ -96,18 +96,16 @@ class test_ExceptionInfo(Case):
 
         try:
             raise LookupError("The quick brown fox jumps...")
-        except LookupError:
-            exc_info = sys.exc_info()
-
-        einfo = ExceptionInfo(exc_info)
-        self.assertEqual(str(einfo), einfo.traceback)
-        self.assertIsInstance(einfo.exception, LookupError)
-        self.assertTupleEqual(einfo.exception.args,
-                ("The quick brown fox jumps...", ))
-        self.assertTrue(einfo.traceback)
-
-        r = repr(einfo)
-        self.assertTrue(r)
+        except Exception:
+            einfo = ExceptionInfo()
+            self.assertEqual(str(einfo), einfo.traceback)
+            self.assertIsInstance(einfo.exception, LookupError)
+            self.assertTupleEqual(einfo.exception.args,
+                    ("The quick brown fox jumps...", ))
+            self.assertTrue(einfo.traceback)
+
+            r = repr(einfo)
+            self.assertTrue(r)
 
 
 class test_LimitedSet(Case):

+ 27 - 27
celery/tests/worker/test_request.py

@@ -268,7 +268,7 @@ class test_TaskRequest(Case):
         try:
             raise RetryTaskError("foo", KeyError("moofoobar"))
         except:
-            einfo = ExceptionInfo(sys.exc_info())
+            einfo = ExceptionInfo()
             tw.on_failure(einfo)
             self.assertIn("task-retried", tw.eventer.sent)
             tw._does_info = False
@@ -344,7 +344,7 @@ class test_TaskRequest(Case):
             try:
                 raise KeyError("moofoobar")
             except:
-                return ExceptionInfo(sys.exc_info())
+                return ExceptionInfo()
 
         app.mail_admins = mock_mail_admins
         mytask.send_error_emails = True
@@ -452,7 +452,7 @@ class test_TaskRequest(Case):
             try:
                 raise SystemExit()
             except SystemExit:
-                tw.on_success(ExceptionInfo(sys.exc_info()))
+                tw.on_success(ExceptionInfo())
             else:
                 assert False
 
@@ -471,7 +471,7 @@ class test_TaskRequest(Case):
         try:
             raise KeyError("foo")
         except Exception:
-            tw.on_success(ExceptionInfo(sys.exc_info()))
+            tw.on_success(ExceptionInfo())
             self.assertTrue(tw.on_failure.called)
 
     def test_on_success_acks_late(self):
@@ -490,7 +490,7 @@ class test_TaskRequest(Case):
             try:
                 raise WorkerLostError("do re mi")
             except WorkerLostError:
-                return ExceptionInfo(sys.exc_info())
+                return ExceptionInfo()
 
         tw = TaskRequest(mytask.name, uuid(), [1], {"f": "x"})
         exc_info = get_ei()
@@ -516,7 +516,7 @@ class test_TaskRequest(Case):
             try:
                 raise KeyError("foo")
             except KeyError:
-                exc_info = ExceptionInfo(sys.exc_info())
+                exc_info = ExceptionInfo()
                 tw.on_failure(exc_info)
                 self.assertTrue(tw.acknowledged)
         finally:
@@ -568,38 +568,38 @@ class test_TaskRequest(Case):
                                     [], {})
             self.assertIsInstance(res, ExceptionInfo)
 
-    def create_exception(self, exc):
-        try:
-            raise exc
-        except exc.__class__:
-            return sys.exc_info()
-
     def test_worker_task_trace_handle_retry(self):
         from celery.exceptions import RetryTaskError
         tid = uuid()
         mytask.request.update({"id": tid})
+        einfo = tb = None
         try:
-            _, value_, _ = self.create_exception(ValueError("foo"))
-            einfo = self.create_exception(RetryTaskError(str(value_),
-                                          exc=value_))
-            w = TraceInfo(states.RETRY, einfo[1], einfo)
-            w.handle_retry(mytask, store_errors=False)
-            self.assertEqual(mytask.backend.get_status(tid), states.PENDING)
-            w.handle_retry(mytask, store_errors=True)
-            self.assertEqual(mytask.backend.get_status(tid), states.RETRY)
+            raise ValueError("foo")
+        except Exception, exc:
+            try:
+                raise RetryTaskError(str(exc), exc=exc)
+            except RetryTaskError, exc:
+                w = TraceInfo(states.RETRY, exc)
+                w.handle_retry(mytask, store_errors=False)
+                self.assertEqual(mytask.backend.get_status(tid), states.PENDING)
+                w.handle_retry(mytask, store_errors=True)
+                self.assertEqual(mytask.backend.get_status(tid), states.RETRY)
         finally:
             mytask.request.clear()
 
     def test_worker_task_trace_handle_failure(self):
         tid = uuid()
         mytask.request.update({"id": tid})
+        einfo = None
         try:
-            einfo = self.create_exception(ValueError("foo"))
-            w = TraceInfo(states.FAILURE, einfo[1], einfo)
-            w.handle_failure(mytask, store_errors=False)
-            self.assertEqual(mytask.backend.get_status(tid), states.PENDING)
-            w.handle_failure(mytask, store_errors=True)
-            self.assertEqual(mytask.backend.get_status(tid), states.FAILURE)
+            try:
+                raise ValueError("foo")
+            except Exception, exc:
+                w = TraceInfo(states.FAILURE, exc)
+                w.handle_failure(mytask, store_errors=False)
+                self.assertEqual(mytask.backend.get_status(tid), states.PENDING)
+                w.handle_failure(mytask, store_errors=True)
+                self.assertEqual(mytask.backend.get_status(tid), states.FAILURE)
         finally:
             mytask.request.clear()
 
@@ -755,7 +755,7 @@ class test_TaskRequest(Case):
         try:
             raise exception
         except Exception:
-            exc_info = ExceptionInfo(sys.exc_info())
+            exc_info = ExceptionInfo()
             app.conf.CELERY_SEND_TASK_ERROR_EMAILS = True
             try:
                 tw.on_failure(exc_info)

+ 4 - 6
celery/tests/worker/test_worker.py

@@ -864,12 +864,10 @@ class test_WorkController(AppCase):
 
         try:
             raise KeyError("foo")
-        except KeyError:
-            exc_info = sys.exc_info()
-
-        Timers(worker).on_timer_error(exc_info)
-        msg, args = self.logger.error.call_args[0]
-        self.assertIn("KeyError", msg % args)
+        except KeyError, exc:
+            Timers(worker).on_timer_error(exc)
+            msg, args = self.logger.error.call_args[0]
+            self.assertIn("KeyError", msg % args)
 
     def test_on_timer_tick(self):
         worker = WorkController(concurrency=1, loglevel=10)

+ 4 - 4
celery/utils/__init__.py

@@ -157,13 +157,13 @@ def cry():  # pragma: no cover
 def maybe_reraise():
     """Reraise if an exception is currently being handled, or return
     otherwise."""
-    type_, exc, tb = sys.exc_info()
+    exc_info = sys.exc_info()
     try:
-        if tb:
-            raise type_, exc, tb
+        if exc_info[2]:
+            raise exc_info[0], exc_info[1], exc_info[2]
     finally:
         # see http://docs.python.org/library/sys.html#sys.exc_info
-        del(tb)
+        del(exc_info)
 
 
 # - XXX Compat

+ 1 - 1
celery/utils/log.py

@@ -65,7 +65,7 @@ class ColorFormatter(logging.Formatter):
             except Exception, exc:
                 record.msg = "<Unrepresentable %r: %r>" % (
                         type(record.msg), exc)
-                record.exc_info = sys.exc_info()
+                record.exc_info = True
 
         if not is_py3k and "processName" not in record.__dict__:
             # Very ugly, but have to make sure processName is supported

+ 8 - 4
celery/utils/threads.py

@@ -39,10 +39,14 @@ class bgThread(Thread):
     def body(self):
         raise NotImplementedError("subclass responsibility")
 
-    def on_crash(self, exc_info, msg, *fmt, **kwargs):
+    def on_crash(self, msg, *fmt, **kwargs):
         sys.stderr.write((msg + "\n") % fmt)
-        traceback.print_exception(exc_info[0], exc_info[1], exc_info[2],
-                                  None, sys.stderr)
+        exc_info = sys.exc_info()
+        try:
+            traceback.print_exception(exc_info[0], exc_info[1], exc_info[2],
+                                      None, sys.stderr)
+        finally:
+            del(exc_info)
 
     def run(self):
         shutdown = self._is_shutdown
@@ -50,7 +54,7 @@ class bgThread(Thread):
             try:
                 self.body()
             except Exception, exc:
-                self.on_crash(sys.exc_info(), "%r crashed: %r", self.name, exc)
+                self.on_crash("%r crashed: %r", self.name, exc)
                 # exiting by normal means does not work here, so force exit.
                 os._exit(1)
         try:

+ 4 - 8
celery/utils/timer2.py

@@ -103,8 +103,8 @@ class Schedule(object):
             eta = datetime.now()
         try:
             eta = to_timestamp(eta)
-        except OverflowError:
-            if not self.handle_error(sys.exc_info()):
+        except OverflowError, exc:
+            if not self.handle_error(exc):
                 raise
             return
         return self._enter(eta, priority, entry)
@@ -182,12 +182,8 @@ class Timer(Thread):
         try:
             entry()
         except Exception, exc:
-            exc_info = sys.exc_info()
-            try:
-                if not self.schedule.handle_error(exc_info):
-                    logger.error("Error in timer: %r\n", exc, exc_info=True)
-            finally:
-                del(exc_info)
+            if not self.schedule.handle_error(exc):
+                logger.error("Error in timer: %r", exc, exc_info=True)
 
     def _next_entry(self):
         with self.not_empty:

+ 2 - 2
celery/worker/__init__.py

@@ -153,8 +153,8 @@ class Timers(abstract.Component):
                                 on_error=self.on_timer_error,
                                 on_tick=self.on_timer_tick)
 
-    def on_timer_error(self, einfo):
-        logger.error("Timer error: %r", einfo[1], exc_info=einfo)
+    def on_timer_error(self, exc):
+        logger.error("Timer error: %r", exc, exc_info=True)
 
     def on_timer_tick(self, delay):
         logger.debug("Scheduler wake-up! Next eta %s secs.", delay)