Selaa lähdekoodia

Optimizes TaskTrace

Ask Solem 13 vuotta sitten
vanhempi
commit
789f53eb49

+ 166 - 74
celery/execute/trace.py

@@ -12,112 +12,199 @@
 """
 from __future__ import absolute_import
 
+import os
+import socket
 import sys
 import traceback
+import warnings
 
+from .. import current_app
 from .. import states, signals
 from ..datastructures import ExceptionInfo
 from ..exceptions import RetryTaskError
 from ..registry import tasks
+from ..utils.serialization import get_pickleable_exception
+
+send_prerun = signals.task_prerun.send
+send_postrun = signals.task_postrun.send
+SUCCESS = states.SUCCESS
+RETRY = states.RETRY
+FAILURE = states.FAILURE
+EXCEPTION_STATES = states.EXCEPTION_STATES
 
 
 class TraceInfo(object):
+    __slots__ = ("state", "retval", "exc_info",
+                 "exc_type", "exc_value", "tb", "strtb")
 
-    def __init__(self, status=states.PENDING, retval=None, exc_info=None):
-        self.status = status
+    def __init__(self, state, retval=None, exc_info=None):
+        self.state = state
         self.retval = retval
         self.exc_info = exc_info
-        self.exc_type = None
-        self.exc_value = None
-        self.tb = None
-        self.strtb = None
-        if self.exc_info:
+        if exc_info:
             self.exc_type, self.exc_value, self.tb = exc_info
-            self.strtb = "\n".join(traceback.format_exception(*exc_info))
+        else:
+            self.exc_type = self.exc_value = self.tb = None
 
-    @classmethod
-    def trace(cls, fun, args, kwargs, propagate=False):
-        """Trace the execution of a function, calling the appropiate callback
-        if the function raises retry, an failure or returned successfully.
+    @property
+    def strtb(self):
+        if self.exc_info:
+            return "\n".join(traceback.format_exception(*self.exc_info))
 
-        :keyword propagate: If true, errors will propagate to the caller.
 
-        """
-        try:
-            return cls(states.SUCCESS, retval=fun(*args, **kwargs))
-        except RetryTaskError, exc:
-            return cls(states.RETRY, retval=exc, exc_info=sys.exc_info())
-        except Exception, exc:
-            if propagate:
-                raise
-            return cls(states.FAILURE, retval=exc, exc_info=sys.exc_info())
-        except BaseException, exc:
+def trace(fun, args, kwargs, propagate=False, Info=TraceInfo):
+    """Trace the execution of a function, calling the appropriate callback
+    if the function raises retry, a failure or returned successfully.
+
+    :keyword propagate: If true, errors will propagate to the caller.
+
+    """
+    try:
+        return Info(states.SUCCESS, retval=fun(*args, **kwargs))
+    except RetryTaskError, exc:
+        return Info(states.RETRY, retval=exc, exc_info=sys.exc_info())
+    except Exception, exc:
+        if propagate:
             raise
-        except:  # pragma: no cover
-            # For Python2.5 where raising strings are still allowed
-            # (but deprecated)
-            if propagate:
-                raise
-            return cls(states.FAILURE, retval=None, exc_info=sys.exc_info())
+        return Info(states.FAILURE, retval=exc, exc_info=sys.exc_info())
+    except BaseException, exc:
+        raise
+    except:  # pragma: no cover
+        # For Python2.5 where raising strings are still allowed
+        # (but deprecated)
+        if propagate:
+            raise
+        return Info(states.FAILURE, retval=None, exc_info=sys.exc_info())
 
 
 class TaskTrace(object):
+    """Wraps the task in a jail, catches all exceptions, and
+    saves the status and result of the task execution to the task
+    meta backend.
+
+    If the call was successful, it saves the result to the task result
+    backend, and sets the task status to `"SUCCESS"`.
+
+    If the call raises :exc:`~celery.exceptions.RetryTaskError`, it extracts
+    the original exception, uses that as the result and sets the task status
+    to `"RETRY"`.
+
+    If the call results in an exception, it saves the exception as the task
+    result, and sets the task status to `"FAILURE"`.
+
+    :param task_name: The name of the task to execute.
+    :param task_id: The unique id of the task.
+    :param args: List of positional args to pass on to the function.
+    :param kwargs: Keyword arguments mapping to pass on to the function.
+
+    :keyword loader: Custom loader to use, if not specified the current app
+      loader will be used.
+    :keyword hostname: Custom hostname to use, if not specified the system
+      hostname will be used.
+
+    :returns: the evaluated function's return value on success, or
+        the exception instance on failure.
+
+    """
 
     def __init__(self, task_name, task_id, args, kwargs, task=None,
-            request=None, propagate=None, **_):
+            request=None, propagate=None, propagate_internal=False,
+            eager=False, **_):
         self.task_id = task_id
         self.task_name = task_name
         self.args = args
         self.kwargs = kwargs
         self.task = task or tasks[self.task_name]
         self.request = request or {}
-        self.status = states.PENDING
-        self.strtb = None
         self.propagate = propagate
-        self._trace_handlers = {states.FAILURE: self.handle_failure,
-                                states.RETRY: self.handle_retry,
-                                states.SUCCESS: self.handle_success}
+        self.propagate_internal = propagate_internal
+        self.eager = eager
+        self._trace_handlers = {FAILURE: self.handle_failure,
+                                RETRY: self.handle_retry}
+        self.loader = kwargs.get("loader") or current_app.loader
+        self.hostname = kwargs.get("hostname") or socket.gethostname()
+        self._store_errors = True
+        if eager:
+            self._store_errors = False
+        elif self.task.ignore_result:
+            self._store_errors = self.task.store_errors_even_if_ignored
+
+        # Used by always_eager
+        self.state = None
+        self.strtb = None
 
     def __call__(self):
-        return self.execute()
-
-    def execute(self):
-        self.task.request.update(self.request, args=self.args,
-                                 called_directly=False, kwargs=self.kwargs)
-        signals.task_prerun.send(sender=self.task, task_id=self.task_id,
-                                 task=self.task, args=self.args,
-                                 kwargs=self.kwargs)
-        retval = self._trace()
-
-        signals.task_postrun.send(sender=self.task, task_id=self.task_id,
-                                  task=self.task, args=self.args,
-                                  kwargs=self.kwargs, retval=retval)
-        self.task.request.clear()
-        return retval
-
-    def _trace(self):
-        trace = TraceInfo.trace(self.task, self.args, self.kwargs,
-                                propagate=self.propagate)
-        self.status = trace.status
-        self.strtb = trace.strtb
-        handler = self._trace_handlers[trace.status]
-        r = handler(trace.retval, trace.exc_type, trace.tb, trace.strtb)
-        self.handle_after_return(trace.status, trace.retval,
-                                 trace.exc_type, trace.tb, trace.strtb,
-                                 einfo=trace.exc_info)
-        return r
-
-    def handle_after_return(self, status, retval, type_, tb, strtb,
-            einfo=None):
-        if status in states.EXCEPTION_STATES:
-            einfo = ExceptionInfo(einfo)
-        self.task.after_return(status, retval, self.task_id,
-                               self.args, self.kwargs, einfo)
-
-    def handle_success(self, retval, *args):
-        """Handle successful execution."""
-        self.task.on_success(retval, self.task_id, self.args, self.kwargs)
-        return retval
+        try:
+            task, uuid, args, kwargs, eager = (self.task, self.task_id,
+                                               self.args, self.kwargs,
+                                               self.eager)
+            backend = task.backend
+            ignore_result = task.ignore_result
+            loader = self.loader
+
+            task.request.update(self.request, args=args,
+                                called_directly=False, kwargs=kwargs)
+            try:
+                # -*- PRE -*-
+                send_prerun(sender=task, task_id=uuid, task=task,
+                            args=args, kwargs=kwargs)
+                loader.on_task_init(uuid, task)
+                if not eager and (task.track_started and not ignore_result):
+                    backend.mark_as_started(uuid, pid=os.getpid(),
+                                            hostname=self.hostname)
+
+                # -*- TRACE -*-
+                # self.info is used by always_eager
+                info = self.info = trace(task, args, kwargs, self.propagate)
+                state, retval, einfo = info.state, info.retval, info.exc_info
+                if eager:
+                    self.state, self.strtb = info.state, info.strtb
+
+                if state == SUCCESS:
+                    task.on_success(retval, uuid, args, kwargs)
+                    if not eager and not ignore_result:
+                        backend.mark_as_done(uuid, retval)
+                    R = retval
+                else:
+                    R = self.handle_trace(info)
+
+                # -*- POST -*-
+                if state in EXCEPTION_STATES:
+                    einfo = ExceptionInfo(einfo)
+                task.after_return(state, retval, self.task_id,
+                                  self.args, self.kwargs, einfo)
+                send_postrun(sender=task, task_id=uuid, task=task,
+                             args=args, kwargs=kwargs, retval=retval)
+            finally:
+                task.request.clear()
+                if not eager:
+                    try:
+                        backend.process_cleanup()
+                        loader.on_process_cleanup()
+                    except (KeyboardInterrupt, SystemExit, MemoryError):
+                        raise
+                    except Exception, exc:
+                        logger = current_app.log.get_default_logger()
+                        logger.error("Process cleanup failed: %r", exc,
+                                    exc_info=sys.exc_info())
+        except Exception, exc:
+            if self.propagate_internal:
+                raise
+            return self.report_internal_error(exc)
+        return R
+    execute = __call__
+
+    def report_internal_error(self, exc):
+        _type, _value, _tb = sys.exc_info()
+        _value = self.task.backend.prepare_exception(exc)
+        exc_info = ExceptionInfo((_type, _value, _tb))
+        warnings.warn("Exception outside body: %s: %s\n%s" % tuple(
+            map(str, (exc.__class__, exc, exc_info.traceback))))
+        return exc_info
+
+    def handle_trace(self, trace):
+        return self._trace_handlers[trace.state](trace.retval, trace.exc_type,
+                                                 trace.tb, trace.strtb)
 
     def handle_retry(self, exc, type_, tb, strtb):
         """Handle retry exception."""
@@ -126,6 +213,8 @@ class TaskTrace(object):
         # This is for reporting the retry in logs, email etc, while
         # guaranteeing pickleability.
         message, orig_exc = exc.args
+        if self._store_errors:
+            self.task.backend.mark_as_retry(self.task_id, orig_exc, strtb)
         expanded_msg = "%s: %s" % (message, str(orig_exc))
         einfo = ExceptionInfo((type_, type_(expanded_msg, None), tb))
         self.task.on_retry(exc, self.task_id, self.args, self.kwargs, einfo)
@@ -133,6 +222,9 @@ class TaskTrace(object):
 
     def handle_failure(self, exc, type_, tb, strtb):
         """Handle exception."""
+        if self._store_errors:
+            self.task.backend.mark_as_failure(self.task_id, exc, strtb)
+        exc = get_pickleable_exception(exc)
         einfo = ExceptionInfo((type_, exc, tb))
         self.task.on_failure(exc, self.task_id, self.args, self.kwargs, einfo)
         signals.task_failure.send(sender=self.task, task_id=self.task_id,

+ 5 - 8
celery/tests/test_task/test_execute_trace.py

@@ -5,21 +5,18 @@ import operator
 
 from celery import states
 from celery.exceptions import RetryTaskError
-from celery.execute.trace import TraceInfo
+from celery.execute.trace import trace
 from celery.tests.utils import unittest
 
-trace = TraceInfo.trace
-
-
 def raises(exc):
     raise exc
 
 
-class test_TraceInfo(unittest.TestCase):
+class test_trace(unittest.TestCase):
 
     def test_trace_successful(self):
         info = trace(operator.add, (2, 2), {})
-        self.assertEqual(info.status, states.SUCCESS)
+        self.assertEqual(info.state, states.SUCCESS)
         self.assertEqual(info.retval, 4)
 
     def test_trace_SystemExit(self):
@@ -29,13 +26,13 @@ class test_TraceInfo(unittest.TestCase):
     def test_trace_RetryTaskError(self):
         exc = RetryTaskError("foo", "bar")
         info = trace(raises, (exc, ), {})
-        self.assertEqual(info.status, states.RETRY)
+        self.assertEqual(info.state, states.RETRY)
         self.assertIs(info.retval, exc)
 
     def test_trace_exception(self):
         exc = KeyError("foo")
         info = trace(raises, (exc, ), {})
-        self.assertEqual(info.status, states.FAILURE)
+        self.assertEqual(info.state, states.FAILURE)
         self.assertIs(info.retval, exc)
 
     def test_trace_exception_propagate(self):

+ 9 - 8
celery/tests/test_worker/test_worker_job.py

@@ -21,12 +21,13 @@ from celery.datastructures import ExceptionInfo
 from celery.task import task as task_dec
 from celery.exceptions import (RetryTaskError, NotRegistered,
                                WorkerLostError, InvalidTaskError)
+from celery.execute.trace import TaskTrace
 from celery.log import setup_logger
 from celery.result import AsyncResult
 from celery.task.base import Task
 from celery.utils import uuid
 from celery.utils.encoding import from_utf8, default_encode
-from celery.worker.job import WorkerTaskTrace, TaskRequest, execute_and_trace
+from celery.worker.job import TaskRequest, execute_and_trace
 from celery.worker.state import revoked
 
 from celery.tests.compat import catch_warnings
@@ -39,7 +40,7 @@ some_kwargs_scratchpad = {}
 
 
 def jail(task_id, task_name, args, kwargs):
-    return WorkerTaskTrace(task_name, task_id, args, kwargs)()
+    return TaskTrace(task_name, task_id, args, kwargs)()
 
 
 def on_ack():
@@ -108,7 +109,7 @@ class test_RetryTaskError(unittest.TestCase):
             self.assertEqual(ret.exc, exc)
 
 
-class test_WorkerTaskTrace(unittest.TestCase):
+class test_TaskTrace(unittest.TestCase):
 
     def test_process_cleanup_fails(self):
         backend = mytask.backend
@@ -458,12 +459,12 @@ class test_TaskRequest(unittest.TestCase):
         self.assertEqual(res, 4 ** 4)
 
     def test_execute_safe_catches_exception(self):
-        old_exec = WorkerTaskTrace.execute
+        old_exec = TaskTrace.__call__
 
         def _error_exec(self, *args, **kwargs):
             raise KeyError("baz")
 
-        WorkerTaskTrace.execute = _error_exec
+        TaskTrace.__call__ = _error_exec
         try:
             with catch_warnings(record=True) as log:
                 res = execute_and_trace(mytask.name, uuid(),
@@ -473,7 +474,7 @@ class test_TaskRequest(unittest.TestCase):
                 self.assertIn("Exception outside", log[0].message.args[0])
                 self.assertIn("KeyError", log[0].message.args[0])
         finally:
-            WorkerTaskTrace.execute = old_exec
+            TaskTrace.__call__ = old_exec
 
     def create_exception(self, exc):
         try:
@@ -484,7 +485,7 @@ class test_TaskRequest(unittest.TestCase):
     def test_worker_task_trace_handle_retry(self):
         from celery.exceptions import RetryTaskError
         tid = uuid()
-        w = WorkerTaskTrace(mytask.name, tid, [4], {})
+        w = TaskTrace(mytask.name, tid, [4], {})
         type_, value_, tb_ = self.create_exception(ValueError("foo"))
         type_, value_, tb_ = self.create_exception(RetryTaskError(str(value_),
                                                                   exc=value_))
@@ -497,7 +498,7 @@ class test_TaskRequest(unittest.TestCase):
 
     def test_worker_task_trace_handle_failure(self):
         tid = uuid()
-        w = WorkerTaskTrace(mytask.name, tid, [4], {})
+        w = TaskTrace(mytask.name, tid, [4], {})
         type_, value_, tb_ = self.create_exception(ValueError("foo"))
         w._store_errors = False
         w.handle_failure(value_, type_, tb_, "")

+ 9 - 115
celery/worker/job.py

@@ -13,25 +13,19 @@
 from __future__ import absolute_import
 
 import logging
-import os
-import sys
 import time
 import socket
-import warnings
 
 from datetime import datetime
 
-from .. import current_app
 from .. import exceptions
 from .. import registry
 from ..app import app_or_default
-from ..datastructures import ExceptionInfo
 from ..execute.trace import TaskTrace
 from ..platforms import set_mp_process_title as setps
 from ..utils import noop, kwdict, fun_takes_kwargs, truncate_text
-from ..utils.encoding import safe_repr, safe_str, default_encode
+from ..utils.encoding import safe_repr, safe_str
 from ..utils.timeutils import maybe_iso8601, timezone
-from ..utils.serialization import get_pickleable_exception
 
 from . import state
 
@@ -40,118 +34,18 @@ from . import state
 WANTED_DELIVERY_INFO = ("exchange", "routing_key", "consumer_tag", )
 
 
-class WorkerTaskTrace(TaskTrace):
-    """Wraps the task in a jail, catches all exceptions, and
-    saves the status and result of the task execution to the task
-    meta backend.
-
-    If the call was successful, it saves the result to the task result
-    backend, and sets the task status to `"SUCCESS"`.
-
-    If the call raises :exc:`~celery.exceptions.RetryTaskError`, it extracts
-    the original exception, uses that as the result and sets the task status
-    to `"RETRY"`.
-
-    If the call results in an exception, it saves the exception as the task
-    result, and sets the task status to `"FAILURE"`.
-
-    :param task_name: The name of the task to execute.
-    :param task_id: The unique id of the task.
-    :param args: List of positional args to pass on to the function.
-    :param kwargs: Keyword arguments mapping to pass on to the function.
-
-    :keyword loader: Custom loader to use, if not specified the current app
-      loader will be used.
-    :keyword hostname: Custom hostname to use, if not specified the system
-      hostname will be used.
-
-    :returns: the evaluated functions return value on success, or
-        the exception instance on failure.
-
-    """
-
-    #: Current loader.
-    loader = None
-
-    #: Hostname to report as.
-    hostname = None
-
-    def __init__(self, *args, **kwargs):
-        self.loader = kwargs.get("loader") or current_app.loader
-        self.hostname = kwargs.get("hostname") or socket.gethostname()
-        super(WorkerTaskTrace, self).__init__(*args, **kwargs)
-
-        self._store_errors = True
-        if self.task.ignore_result:
-            self._store_errors = self.task.store_errors_even_if_ignored
-        self.super = super(WorkerTaskTrace, self)
-
-    def execute_safe(self, *args, **kwargs):
-        """Same as :meth:`execute`, but catches errors."""
-        try:
-            return self.execute(*args, **kwargs)
-        except Exception, exc:
-            _type, _value, _tb = sys.exc_info()
-            _value = self.task.backend.prepare_exception(exc)
-            exc_info = ExceptionInfo((_type, _value, _tb))
-            warnings.warn("Exception outside body: %s: %s\n%s" % tuple(
-                map(str, (exc.__class__, exc, exc_info.traceback))))
-            return exc_info
-
-    def execute(self):
-        """Execute, trace and store the result of the task."""
-        self.loader.on_task_init(self.task_id, self.task)
-        if self.task.track_started:
-            if not self.task.ignore_result:
-                self.task.backend.mark_as_started(self.task_id,
-                                                  pid=os.getpid(),
-                                                  hostname=self.hostname)
-        try:
-            return super(WorkerTaskTrace, self).execute()
-        finally:
-            try:
-                self.task.backend.process_cleanup()
-                self.loader.on_process_cleanup()
-            except (KeyboardInterrupt, SystemExit, MemoryError):
-                raise
-            except Exception, exc:
-                logger = current_app.log.get_default_logger()
-                logger.error("Process cleanup failed: %r", exc,
-                             exc_info=sys.exc_info())
-
-    def handle_success(self, retval, *args):
-        """Handle successful execution."""
-        if not self.task.ignore_result:
-            self.task.backend.mark_as_done(self.task_id, retval)
-        return self.super.handle_success(retval, *args)
-
-    def handle_retry(self, exc, type_, tb, strtb):
-        """Handle retry exception."""
-        message, orig_exc = exc.args
-        if self._store_errors:
-            self.task.backend.mark_as_retry(self.task_id, orig_exc, strtb)
-        return self.super.handle_retry(exc, type_, tb, strtb)
-
-    def handle_failure(self, exc, type_, tb, strtb):
-        """Handle exception."""
-        if self._store_errors:
-            self.task.backend.mark_as_failure(self.task_id, exc, strtb)
-        exc = get_pickleable_exception(exc)
-        return self.super.handle_failure(exc, type_, tb, strtb)
-
-
 def execute_and_trace(task_name, *args, **kwargs):
     """This is a pickleable method used as a target when applying to pools.
 
     It's the same as::
 
-        >>> WorkerTaskTrace(task_name, *args, **kwargs).execute_safe()
+        >>> TaskTrace(task_name, *args, **kwargs)()
 
     """
     hostname = kwargs.get("hostname")
     setps("celeryd", task_name, hostname, rate_limit=True)
     try:
-        return WorkerTaskTrace(task_name, *args, **kwargs).execute_safe()
+        return TaskTrace(task_name, *args, **kwargs)()
     finally:
         setps("celeryd", "-idle-", hostname, rate_limit=True)
 
@@ -363,7 +257,7 @@ class TaskRequest(object):
         return result
 
     def execute(self, loglevel=None, logfile=None):
-        """Execute the task in a :class:`WorkerTaskTrace`.
+        """Execute the task in a :class:`TaskTrace`.
 
         :keyword loglevel: The loglevel used by the task.
 
@@ -378,10 +272,10 @@ class TaskRequest(object):
             self.acknowledge()
 
         instance_attrs = self.get_instance_attrs(loglevel, logfile)
-        tracer = WorkerTaskTrace(*self._get_tracer_args(loglevel, logfile),
-                                 **{"hostname": self.hostname,
-                                    "loader": self.app.loader,
-                                    "request": instance_attrs})
+        tracer = TaskTrace(*self._get_tracer_args(loglevel, logfile),
+                           **{"hostname": self.hostname,
+                              "loader": self.app.loader,
+                              "request": instance_attrs})
         retval = tracer.execute()
         self.acknowledge()
         return retval
@@ -552,6 +446,6 @@ class TaskRequest(object):
                 self.task_name, self.task_id, self.args, self.kwargs)
 
     def _get_tracer_args(self, loglevel=None, logfile=None):
-        """Get the :class:`WorkerTaskTrace` tracer for this task."""
+        """Get the :class:`TaskTrace` tracer arguments for this task."""
         task_func_kwargs = self.extend_with_default_kwargs(loglevel, logfile)
         return self.task_name, self.task_id, self.args, task_func_kwargs