Browse Source

Task: Store state for exceptions occurring outside of task body. Closes #1582

Ask Solem 10 years ago
parent
commit
2d7904a836

+ 17 - 17
celery/app/trace.py

@@ -140,7 +140,7 @@ class TraceInfo(object):
         self.state = state
         self.retval = retval
 
-    def handle_error_state(self, task, eager=False):
+    def handle_error_state(self, task, req, eager=False):
         store_errors = not eager
         if task.ignore_result:
             store_errors = task.store_errors_even_if_ignored
@@ -148,19 +148,18 @@ class TraceInfo(object):
         return {
             RETRY: self.handle_retry,
             FAILURE: self.handle_failure,
-        }[self.state](task, store_errors=store_errors)
+        }[self.state](task, req, store_errors=store_errors)
 
-    def handle_reject(self, task, **kwargs):
-        self._log_error(task, ExceptionInfo())
+    def handle_reject(self, task, req, **kwargs):
+        self._log_error(task, req, ExceptionInfo())
 
-    def handle_ignore(self, task, **kwargs):
-        self._log_error(task, ExceptionInfo())
+    def handle_ignore(self, task, req, **kwargs):
+        self._log_error(task, req, ExceptionInfo())
 
-    def handle_retry(self, task, store_errors=True):
+    def handle_retry(self, task, req, store_errors=True):
         """Handle retry exception."""
         # the exception raised is the Retry semi-predicate,
         # and it's exc' attribute is the original exception raised (if any).
-        req = task.request
         type_, _, tb = sys.exc_info()
         try:
             reason = self.retval
@@ -180,9 +179,8 @@ class TraceInfo(object):
         finally:
             del(tb)
 
-    def handle_failure(self, task, store_errors=True):
+    def handle_failure(self, task, req, store_errors=True):
         """Handle exception."""
-        req = task.request
         type_, _, tb = sys.exc_info()
         try:
             exc = self.retval
@@ -199,13 +197,12 @@ class TraceInfo(object):
                                       kwargs=req.kwargs,
                                       traceback=tb,
                                       einfo=einfo)
-            self._log_error(task, einfo)
+            self._log_error(task, req, einfo)
             return einfo
         finally:
             del(tb)
 
-    def _log_error(self, task, einfo):
-        req = task.request
+    def _log_error(self, task, req, einfo):
         eobj = einfo.exception = get_pickled_exception(einfo.exception)
         exception, traceback, exc_info, sargs, skwargs = (
             safe_repr(eobj),
@@ -308,7 +305,7 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
         if propagate:
             raise
         I = Info(state, exc)
-        R = I.handle_error_state(task, eager=eager)
+        R = I.handle_error_state(task, request, eager=eager)
         if call_errbacks:
             group(
                 [signature(errback, app=app)
@@ -329,6 +326,7 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
         # we want the main variables (I, and R) to stand out visually from the
         # the rest of the variables, so breaking PEP8 is worth it ;)
         R = I = T = Rstr = retval = state = None
+        task_request = None
         time_start = monotonic()
         try:
             try:
@@ -359,11 +357,11 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
                 except Reject as exc:
                     I, R = Info(REJECTED, exc), ExceptionInfo(internal=True)
                     state, retval = I.state, I.retval
-                    I.handle_reject(task)
+                    I.handle_reject(task, task_request)
                 except Ignore as exc:
                     I, R = Info(IGNORED, exc), ExceptionInfo(internal=True)
                     state, retval = I.state, I.retval
-                    I.handle_ignore(task)
+                    I.handle_ignore(task, task_request)
                 except Retry as exc:
                     I, R, state, retval = on_error(
                         task_request, exc, uuid, RETRY, call_errbacks=False,
@@ -448,6 +446,8 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
             if eager:
                 raise
             R = report_internal_error(task, exc)
+            if task_request is not None:
+                I, _, _, _ = on_error(task_request, exc, uuid)
         return trace_ok_t(R, I, T, Rstr)
 
     return trace_task
@@ -459,7 +459,7 @@ def trace_task(task, uuid, args, kwargs, request={}, **opts):
             task.__trace__ = build_tracer(task.name, task, **opts)
         return task.__trace__(uuid, args, kwargs, request)
     except Exception as exc:
-        return report_internal_error(task, exc)
+        return trace_ok_t(report_internal_error(task, exc), None, 0.0, None)
 
 
 def _trace_task_ret(name, uuid, request, body, content_type,

+ 1 - 1
celery/tests/app/test_app.py

@@ -432,7 +432,7 @@ class test_App(AppCase):
                              {'foo': 'bar'})
 
     def test_compat_setting_CELERY_BACKEND(self):
-
+        self.app.conf.defaults[0]['CELERY_RESULT_BACKEND'] = None
         self.app.config_from_object(Object(CELERY_BACKEND='set_by_us'))
         self.assertEqual(self.app.conf.CELERY_RESULT_BACKEND, 'set_by_us')
 

+ 2 - 2
celery/tests/tasks/test_trace.py

@@ -172,9 +172,9 @@ class test_TraceInfo(TraceCase):
     def test_handle_error_state(self):
         x = self.TI(states.FAILURE)
         x.handle_failure = Mock()
-        x.handle_error_state(self.add_cast)
+        x.handle_error_state(self.add_cast, self.add_cast.request)
         x.handle_failure.assert_called_with(
-            self.add_cast,
+            self.add_cast, self.add_cast.request,
             store_errors=self.add_cast.store_errors_even_if_ignored,
         )
 

+ 12 - 4
celery/tests/worker/test_request.py

@@ -698,11 +698,15 @@ class test_Request(AppCase):
                 raise Retry(str(exc), exc=exc)
             except Retry as exc:
                 w = TraceInfo(states.RETRY, exc)
-                w.handle_retry(self.mytask, store_errors=False)
+                w.handle_retry(
+                    self.mytask, self.mytask.request, store_errors=False,
+                )
                 self.assertEqual(
                     self.mytask.backend.get_status(tid), states.PENDING,
                 )
-                w.handle_retry(self.mytask, store_errors=True)
+                w.handle_retry(
+                    self.mytask, self.mytask.request, store_errors=True,
+                )
                 self.assertEqual(
                     self.mytask.backend.get_status(tid), states.RETRY,
                 )
@@ -718,11 +722,15 @@ class test_Request(AppCase):
                 raise ValueError('foo')
             except Exception as exc:
                 w = TraceInfo(states.FAILURE, exc)
-                w.handle_failure(self.mytask, store_errors=False)
+                w.handle_failure(
+                    self.mytask, self.mytask.request, store_errors=False,
+                )
                 self.assertEqual(
                     self.mytask.backend.get_status(tid), states.PENDING,
                 )
-                w.handle_failure(self.mytask, store_errors=True)
+                w.handle_failure(
+                    self.mytask, self.mytask.request, store_errors=True,
+                )
                 self.assertEqual(
                     self.mytask.backend.get_status(tid), states.FAILURE,
                 )