Browse Source

Optimization: celeryd now handles twice as many messages per/s.

Ask Solem 13 years ago
parent
commit
76bb45c79a

+ 6 - 4
celery/app/task/__init__.py

@@ -15,6 +15,7 @@ from __future__ import absolute_import
 import sys
 import sys
 import threading
 import threading
 
 
+from ... import states
 from ...datastructures import ExceptionInfo
 from ...datastructures import ExceptionInfo
 from ...exceptions import MaxRetriesExceededError, RetryTaskError
 from ...exceptions import MaxRetriesExceededError, RetryTaskError
 from ...execute.trace import trace_task
 from ...execute.trace import trace_task
@@ -621,8 +622,10 @@ class BaseTask(object):
                                   task=task, request=request, propagate=throw)
                                   task=task, request=request, propagate=throw)
         if isinstance(retval, ExceptionInfo):
         if isinstance(retval, ExceptionInfo):
             retval = retval.exception
             retval = retval.exception
-        return EagerResult(task_id, retval, info.state,
-                           traceback=info.strtb)
+        state, tb = states.SUCCESS, ''
+        if info is not None:
+            state, tb = info.state, info.strtb
+        return EagerResult(task_id, retval, state, traceback=tb)
 
 
     @classmethod
     @classmethod
     def AsyncResult(self, task_id):
     def AsyncResult(self, task_id):
@@ -680,8 +683,7 @@ class BaseTask(object):
         The return value of this handler is ignored.
         The return value of this handler is ignored.
 
 
         """
         """
-        if self.request.chord:
-            self.backend.on_chord_part_return(self)
+        pass
 
 
     def on_failure(self, exc, task_id, args, kwargs, einfo):
     def on_failure(self, exc, task_id, args, kwargs, einfo):
         """Error handler.
         """Error handler.

+ 138 - 150
celery/execute/trace.py

@@ -12,6 +12,14 @@
 """
 """
 from __future__ import absolute_import
 from __future__ import absolute_import
 
 
+# ## ---
+# BE WARNED: You are probably going to suffer a heartattack just
+#            by looking at this code!
+#
+# This is the heart of the worker, the inner loop so to speak.
+# It used to be split up into nice little classes and methods,
+# but in the end it only resulted in bad performance, and horrible tracebacks.
+
 import os
 import os
 import socket
 import socket
 import sys
 import sys
@@ -26,11 +34,21 @@ from ..registry import tasks
 from ..utils.serialization import get_pickleable_exception
 from ..utils.serialization import get_pickleable_exception
 
 
 send_prerun = signals.task_prerun.send
 send_prerun = signals.task_prerun.send
+prerun_receivers = signals.task_prerun.receivers
 send_postrun = signals.task_postrun.send
 send_postrun = signals.task_postrun.send
+postrun_receivers = signals.task_postrun.receivers
 SUCCESS = states.SUCCESS
 SUCCESS = states.SUCCESS
 RETRY = states.RETRY
 RETRY = states.RETRY
 FAILURE = states.FAILURE
 FAILURE = states.FAILURE
 EXCEPTION_STATES = states.EXCEPTION_STATES
 EXCEPTION_STATES = states.EXCEPTION_STATES
+_pid = None
+
+
+def getpid():
+    global _pid
+    if _pid is None:
+        _pid = os.getpid()
+    return _pid
 
 
 
 
 class TraceInfo(object):
 class TraceInfo(object):
@@ -46,38 +64,56 @@ class TraceInfo(object):
         else:
         else:
             self.exc_type = self.exc_value = self.tb = None
             self.exc_type = self.exc_value = self.tb = None
 
 
-    @property
-    def strtb(self):
-        if self.exc_info:
-            return "\n".join(traceback.format_exception(*self.exc_info))
-
+    def handle_error_state(self, task, eager=False):
+        store_errors = not eager
+        if task.ignore_result:
+            store_errors = task.store_errors_even_if_ignored
 
 
-def trace(fun, args, kwargs, propagate=False, Info=TraceInfo):
-    """Trace the execution of a function, calling the appropiate callback
-    if the function raises retry, an failure or returned successfully.
+        return {
+            RETRY: self.handle_retry,
+            FAILURE: self.handle_failure,
+        }[self.state](task, store_errors=store_errors)
 
 
-    :keyword propagate: If true, errors will propagate to the caller.
+    def handle_retry(self, task, store_errors=True):
+        """Handle retry exception."""
+        # Create a simpler version of the RetryTaskError that stringifies
+        # the original exception instead of including the exception instance.
+        # This is for reporting the retry in logs, email etc, while
+        # guaranteeing pickleability.
+        req = task.request
+        exc, type_, tb = self.retval, self.exc_type, self.tb
+        message, orig_exc = self.retval.args
+        if store_errors:
+            task.backend.mark_as_retry(req.id, orig_exc, self.strtb)
+        expanded_msg = "%s: %s" % (message, str(orig_exc))
+        einfo = ExceptionInfo((type_, type_(expanded_msg, None), tb))
+        task.on_retry(exc, req.id, req.args, req.kwargs, einfo)
+        return einfo
 
 
-    """
-    try:
-        return Info(states.SUCCESS, retval=fun(*args, **kwargs))
-    except RetryTaskError, exc:
-        return Info(states.RETRY, retval=exc, exc_info=sys.exc_info())
-    except Exception, exc:
-        if propagate:
-            raise
-        return Info(states.FAILURE, retval=exc, exc_info=sys.exc_info())
-    except BaseException, exc:
-        raise
-    except:  # pragma: no cover
-        # For Python2.5 where raising strings are still allowed
-        # (but deprecated)
-        if propagate:
-            raise
-        return Info(states.FAILURE, retval=None, exc_info=sys.exc_info())
+    def handle_failure(self, task, store_errors=True):
+        """Handle exception."""
+        req = task.request
+        exc, type_, tb = self.retval, self.exc_type, self.tb
+        if store_errors:
+            task.backend.mark_as_failure(req.id, exc, self.strtb)
+        exc = get_pickleable_exception(exc)
+        einfo = ExceptionInfo((type_, exc, tb))
+        task.on_failure(exc, req.id, req.args, req.kwargs, einfo)
+        signals.task_failure.send(sender=task, task_id=req.id,
+                                  exception=exc, args=req.args,
+                                  kwargs=req.kwargs, traceback=tb,
+                                  einfo=einfo)
+        return einfo
 
 
+    @property
+    def strtb(self):
+        if self.exc_info:
+            return '\n'.join(traceback.format_exception(*self.exc_info))
+        return ''
 
 
-class TaskTrace(object):
+def trace_task(name, uuid, args, kwargs, task=None, request=None,
+        eager=False, propagate=False, loader=None, hostname=None,
+        store_errors=True, Info=TraceInfo):
     """Wraps the task in a jail, catches all exceptions, and
     """Wraps the task in a jail, catches all exceptions, and
     saves the status and result of the task execution to the task
     saves the status and result of the task execution to the task
     meta backend.
     meta backend.
@@ -104,131 +140,83 @@ class TaskTrace(object):
 
 
     :returns: the evaluated functions return value on success, or
     :returns: the evaluated functions return value on success, or
         the exception instance on failure.
         the exception instance on failure.
-
     """
     """
-
-    def __init__(self, task_name, task_id, args, kwargs, task=None,
-            request=None, propagate=None, propagate_internal=False,
-            eager=False, **_):
-        self.task_id = task_id
-        self.task_name = task_name
-        self.args = args
-        self.kwargs = kwargs
-        self.task = task or tasks[self.task_name]
-        self.request = request or {}
-        self.propagate = propagate
-        self.propagate_internal = propagate_internal
-        self.eager = eager
-        self._trace_handlers = {FAILURE: self.handle_failure,
-                                RETRY: self.handle_retry}
-        self.loader = kwargs.get("loader") or current_app.loader
-        self.hostname = kwargs.get("hostname") or socket.gethostname()
-        self._store_errors = True
-        if eager:
-            self._store_errors = False
-        elif self.task.ignore_result:
-            self._store_errors = self.task.store_errors_even_if_ignored
-
-        # Used by always_eager
-        self.state = None
-        self.strtb = None
-
-    def __call__(self):
+    R = I = None
+    try:
+        task = task or tasks[name]
+        backend = task.backend
+        ignore_result = task.ignore_result
+        loader = loader or current_app.loader
+        hostname = hostname or socket.gethostname()
+        task.request.update(request or {}, args=args,
+                            called_directly=False, kwargs=kwargs)
         try:
         try:
-            task, uuid, args, kwargs, eager = (self.task, self.task_id,
-                                               self.args, self.kwargs,
-                                               self.eager)
-            backend = task.backend
-            ignore_result = task.ignore_result
-            loader = self.loader
-
-            task.request.update(self.request, args=args,
-                                called_directly=False, kwargs=kwargs)
+            # -*- PRE -*-
+            send_prerun(sender=task, task_id=uuid, task=task,
+                        args=args, kwargs=kwargs)
+            loader.on_task_init(uuid, task)
+            if not eager and (task.track_started and not ignore_result):
+                backend.mark_as_started(uuid, pid=getpid(),
+                                        hostname=hostname)
+
+            # -*- TRACE -*-
             try:
             try:
-                # -*- PRE -*-
-                send_prerun(sender=task, task_id=uuid, task=task,
-                            args=args, kwargs=kwargs)
-                loader.on_task_init(uuid, task)
-                if not eager and (task.track_started and not ignore_result):
-                    backend.mark_as_started(uuid, pid=os.getpid(),
-                                            hostname=self.hostname)
-
-                # -*- TRACE -*-
-                # self.info is used by always_eager
-                info = self.info = trace(task, args, kwargs, self.propagate)
-                state, retval, einfo = info.state, info.retval, info.exc_info
-                if eager:
-                    self.state, self.strtb = info.state, info.strtb
-
-                if state == SUCCESS:
-                    task.on_success(retval, uuid, args, kwargs)
-                    if not eager and not ignore_result:
-                        backend.mark_as_done(uuid, retval)
-                    R = retval
-                else:
-                    R = self.handle_trace(info)
-
-                # -* POST *-
-                if state in EXCEPTION_STATES:
-                    einfo = ExceptionInfo(einfo)
-                task.after_return(state, retval, self.task_id,
-                                  self.args, self.kwargs, einfo)
-                send_postrun(sender=task, task_id=uuid, task=task,
-                             args=args, kwargs=kwargs, retval=retval)
-            finally:
-                task.request.clear()
-                if not eager:
-                    try:
-                        backend.process_cleanup()
-                        loader.on_process_cleanup()
-                    except (KeyboardInterrupt, SystemExit, MemoryError):
-                        raise
-                    except Exception, exc:
-                        logger = current_app.log.get_default_logger()
-                        logger.error("Process cleanup failed: %r", exc,
-                                    exc_info=sys.exc_info())
-        except Exception, exc:
-            if self.propagate_internal:
+                R = retval = task(*args, **kwargs)
+                state, einfo = SUCCESS, None
+                task.on_success(retval, uuid, args, kwargs)
+                if not eager and not ignore_result:
+                    backend.mark_as_done(uuid, retval)
+            except RetryTaskError, exc:
+                I = Info(RETRY, exc, sys.exc_info())
+                state, retval, einfo = I.state, I.retval, I.exc_info
+                R = I.handle_error_state(task, eager=eager)
+            except Exception, exc:
+                if propagate:
+                    raise
+                I = Info(FAILURE, exc, sys.exc_info())
+                state, retval, einfo = I.state, I.retval, I.exc_info
+                R = I.handle_error_state(task, eager=eager)
+            except BaseException, exc:
                 raise
                 raise
-            return self.report_internal_error(exc)
-        return R
-    execute = __call__
-
-    def report_internal_error(self, exc):
-        _type, _value, _tb = sys.exc_info()
-        _value = self.task.backend.prepare_exception(exc)
-        exc_info = ExceptionInfo((_type, _value, _tb))
-        warnings.warn("Exception outside body: %s: %s\n%s" % tuple(
-            map(str, (exc.__class__, exc, exc_info.traceback))))
-        return exc_info
-
-    def handle_trace(self, trace):
-        return self._trace_handlers[trace.state](trace.retval, trace.exc_type,
-                                                 trace.tb, trace.strtb)
-
-    def handle_retry(self, exc, type_, tb, strtb):
-        """Handle retry exception."""
-        # Create a simpler version of the RetryTaskError that stringifies
-        # the original exception instead of including the exception instance.
-        # This is for reporting the retry in logs, email etc, while
-        # guaranteeing pickleability.
-        message, orig_exc = exc.args
-        if self._store_errors:
-            self.task.backend.mark_as_retry(self.task_id, orig_exc, strtb)
-        expanded_msg = "%s: %s" % (message, str(orig_exc))
-        einfo = ExceptionInfo((type_, type_(expanded_msg, None), tb))
-        self.task.on_retry(exc, self.task_id, self.args, self.kwargs, einfo)
-        return einfo
+            except:
+                # pragma: no cover
+                # For Python2.5 where raising strings are still allowed
+                # (but deprecated)
+                if propagate:
+                    raise
+                I = Info(FAILURE, None, sys.exc_info())
+                state, retval, einfo = I.state, I.retval, I.exc_info
+                R = I.handle_error_state(task, eager=eager)
+
+            # -* POST *-
+            if task.request.chord:
+                backend.on_chord_part_return(task)
+            task.after_return(state, retval, uuid, args, kwargs, einfo)
+            send_postrun(sender=task, task_id=uuid, task=task,
+                         args=args, kwargs=kwargs, retval=retval)
+        finally:
+            task.request.clear()
+            if not eager:
+                try:
+                    backend.process_cleanup()
+                    loader.on_process_cleanup()
+                except (KeyboardInterrupt, SystemExit, MemoryError):
+                    raise
+                except Exception, exc:
+                    logger = current_app.log.get_default_logger()
+                    logger.error("Process cleanup failed: %r", exc,
+                                 exc_info=sys.exc_info())
+    except Exception, exc:
+        if eager:
+            raise
+        R = report_internal_error(task, exc)
+    return R, I
 
 
-    def handle_failure(self, exc, type_, tb, strtb):
-        """Handle exception."""
-        if self._store_errors:
-            self.task.backend.mark_as_failure(self.task_id, exc, strtb)
-        exc = get_pickleable_exception(exc)
-        einfo = ExceptionInfo((type_, exc, tb))
-        self.task.on_failure(exc, self.task_id, self.args, self.kwargs, einfo)
-        signals.task_failure.send(sender=self.task, task_id=self.task_id,
-                                  exception=exc, args=self.args,
-                                  kwargs=self.kwargs, traceback=tb,
-                                  einfo=einfo)
-        return einfo
+
+def report_internal_error(task, exc):
+    _type, _value, _tb = sys.exc_info()
+    _value = task.backend.prepare_exception(exc)
+    exc_info = ExceptionInfo((_type, _value, _tb))
+    warnings.warn("Exception outside body: %s: %s\n%s" % tuple(
+        map(str, (exc.__class__, exc, exc_info.traceback))))
+    return exc_info

+ 14 - 10
celery/log.py

@@ -24,6 +24,12 @@ from .utils.term import colored
 is_py3k = sys.version_info >= (3, 0)
 is_py3k = sys.version_info >= (3, 0)
 
 
 
 
+def mlevel(level):
+    if level and not isinstance(level, int):
+        return LOG_LEVELS[level.upper()]
+    return level
+
+
 class ColorFormatter(logging.Formatter):
 class ColorFormatter(logging.Formatter):
     #: Loglevel -> Color mapping.
     #: Loglevel -> Color mapping.
     COLORS = colored().names
     COLORS = colored().names
@@ -71,7 +77,7 @@ class Logging(object):
 
 
     def __init__(self, app):
     def __init__(self, app):
         self.app = app
         self.app = app
-        self.loglevel = self.app.conf.CELERYD_LOG_LEVEL
+        self.loglevel = mlevel(self.app.conf.CELERYD_LOG_LEVEL)
         self.format = self.app.conf.CELERYD_LOG_FORMAT
         self.format = self.app.conf.CELERYD_LOG_FORMAT
         self.task_format = self.app.conf.CELERYD_TASK_LOG_FORMAT
         self.task_format = self.app.conf.CELERYD_TASK_LOG_FORMAT
         self.colorize = self.app.conf.CELERYD_LOG_COLOR
         self.colorize = self.app.conf.CELERYD_LOG_COLOR
@@ -92,14 +98,14 @@ class Logging(object):
     def get_task_logger(self, loglevel=None, name=None):
     def get_task_logger(self, loglevel=None, name=None):
         logger = logging.getLogger(name or "celery.task.default")
         logger = logging.getLogger(name or "celery.task.default")
         if loglevel is not None:
         if loglevel is not None:
-            logger.setLevel(loglevel)
+            logger.setLevel(mlevel(loglevel))
         return logger
         return logger
 
 
     def setup_logging_subsystem(self, loglevel=None, logfile=None,
     def setup_logging_subsystem(self, loglevel=None, logfile=None,
             format=None, colorize=None, **kwargs):
             format=None, colorize=None, **kwargs):
         if Logging._setup:
         if Logging._setup:
             return
             return
-        loglevel = loglevel or self.loglevel
+        loglevel = mlevel(loglevel or self.loglevel)
         format = format or self.format
         format = format or self.format
         if colorize is None:
         if colorize is None:
             colorize = self.supports_color(logfile)
             colorize = self.supports_color(logfile)
@@ -120,7 +126,7 @@ class Logging(object):
             mp = mputil.get_logger() if mputil else None
             mp = mputil.get_logger() if mputil else None
             for logger in filter(None, (root, mp)):
             for logger in filter(None, (root, mp)):
                 self._setup_logger(logger, logfile, format, colorize, **kwargs)
                 self._setup_logger(logger, logfile, format, colorize, **kwargs)
-                logger.setLevel(loglevel)
+                logger.setLevel(mlevel(loglevel))
                 signals.after_setup_logger.send(sender=None, logger=logger,
                 signals.after_setup_logger.send(sender=None, logger=logger,
                                         loglevel=loglevel, logfile=logfile,
                                         loglevel=loglevel, logfile=logfile,
                                         format=format, colorize=colorize)
                                         format=format, colorize=colorize)
@@ -144,7 +150,7 @@ class Logging(object):
         """
         """
         logger = logging.getLogger(name)
         logger = logging.getLogger(name)
         if loglevel is not None:
         if loglevel is not None:
-            logger.setLevel(loglevel)
+            logger.setLevel(mlevel(loglevel))
         return logger
         return logger
 
 
     def setup_logger(self, loglevel=None, logfile=None,
     def setup_logger(self, loglevel=None, logfile=None,
@@ -157,7 +163,7 @@ class Logging(object):
         Returns logger object.
         Returns logger object.
 
 
         """
         """
-        loglevel = loglevel or self.loglevel
+        loglevel = mlevel(loglevel or self.loglevel)
         format = format or self.format
         format = format or self.format
         if colorize is None:
         if colorize is None:
             colorize = self.supports_color(logfile)
             colorize = self.supports_color(logfile)
@@ -179,7 +185,7 @@ class Logging(object):
         Returns logger object.
         Returns logger object.
 
 
         """
         """
-        loglevel = loglevel or self.loglevel
+        loglevel = mlevel(loglevel or self.loglevel)
         format = format or self.task_format
         format = format or self.task_format
         if colorize is None:
         if colorize is None:
             colorize = self.supports_color(logfile)
             colorize = self.supports_color(logfile)
@@ -247,9 +253,7 @@ class LoggingProxy(object):
 
 
     def __init__(self, logger, loglevel=None):
     def __init__(self, logger, loglevel=None):
         self.logger = logger
         self.logger = logger
-        self.loglevel = loglevel or self.logger.level or self.loglevel
-        if not isinstance(self.loglevel, int):
-            self.loglevel = LOG_LEVELS[self.loglevel.upper()]
+        self.loglevel = mlevel(loglevel or self.logger.level or self.loglevel)
         self._safewrap_handlers()
         self._safewrap_handlers()
 
 
     def _safewrap_handlers(self):
     def _safewrap_handlers(self):

+ 0 - 2
celery/tests/test_task/__init__.py

@@ -343,10 +343,8 @@ class TestCeleryTasks(unittest.TestCase):
 
 
     def test_after_return(self):
     def test_after_return(self):
         task = self.createTaskCls("T1", "c.unittest.t.after_return")()
         task = self.createTaskCls("T1", "c.unittest.t.after_return")()
-        task.backend = Mock()
         task.request.chord = return_True_task.subtask()
         task.request.chord = return_True_task.subtask()
         task.after_return("SUCCESS", 1.0, "foobar", (), {}, None)
         task.after_return("SUCCESS", 1.0, "foobar", (), {}, None)
-        task.backend.on_chord_part_return.assert_called_with(task)
         task.request.clear()
         task.request.clear()
 
 
     def test_send_task_sent_event(self):
     def test_send_task_sent_event(self):

+ 19 - 6
celery/tests/test_task/test_execute_trace.py

@@ -3,21 +3,34 @@ from __future__ import with_statement
 
 
 import operator
 import operator
 
 
+from celery import current_app
 from celery import states
 from celery import states
 from celery.exceptions import RetryTaskError
 from celery.exceptions import RetryTaskError
-from celery.execute.trace import trace
+from celery.execute.trace import trace_task
 from celery.tests.utils import unittest
 from celery.tests.utils import unittest
 
 
+@current_app.task
+def add(x, y):
+    return x + y
+
+
+@current_app.task
 def raises(exc):
 def raises(exc):
     raise exc
     raise exc
 
 
 
 
+def trace(task, args=(), kwargs={}, propagate=False):
+    return trace_task(task.__name__, "id-1", args, kwargs, task,
+                      propagate=propagate, eager=True)
+
+
+
 class test_trace(unittest.TestCase):
 class test_trace(unittest.TestCase):
 
 
     def test_trace_successful(self):
     def test_trace_successful(self):
-        info = trace(operator.add, (2, 2), {})
-        self.assertEqual(info.state, states.SUCCESS)
-        self.assertEqual(info.retval, 4)
+        retval, info = trace(add, (2, 2), {})
+        self.assertIsNone(info)
+        self.assertEqual(retval, 4)
 
 
     def test_trace_SystemExit(self):
     def test_trace_SystemExit(self):
         with self.assertRaises(SystemExit):
         with self.assertRaises(SystemExit):
@@ -25,13 +38,13 @@ class test_trace(unittest.TestCase):
 
 
     def test_trace_RetryTaskError(self):
     def test_trace_RetryTaskError(self):
         exc = RetryTaskError("foo", "bar")
         exc = RetryTaskError("foo", "bar")
-        info = trace(raises, (exc, ), {})
+        _, info = trace(raises, (exc, ), {})
         self.assertEqual(info.state, states.RETRY)
         self.assertEqual(info.state, states.RETRY)
         self.assertIs(info.retval, exc)
         self.assertIs(info.retval, exc)
 
 
     def test_trace_exception(self):
     def test_trace_exception(self):
         exc = KeyError("foo")
         exc = KeyError("foo")
-        info = trace(raises, (exc, ), {})
+        _, info = trace(raises, (exc, ), {})
         self.assertEqual(info.state, states.FAILURE)
         self.assertEqual(info.state, states.FAILURE)
         self.assertIs(info.retval, exc)
         self.assertIs(info.retval, exc)
 
 

+ 43 - 30
celery/tests/test_worker/test_worker_job.py

@@ -19,12 +19,12 @@ from celery import states
 from celery.app import app_or_default
 from celery.app import app_or_default
 from celery.concurrency.base import BasePool
 from celery.concurrency.base import BasePool
 from celery.datastructures import ExceptionInfo
 from celery.datastructures import ExceptionInfo
-from celery.task import task as task_dec
 from celery.exceptions import (RetryTaskError, NotRegistered,
 from celery.exceptions import (RetryTaskError, NotRegistered,
                                WorkerLostError, InvalidTaskError)
                                WorkerLostError, InvalidTaskError)
-from celery.execute.trace import TaskTrace
+from celery.execute.trace import trace_task, TraceInfo
 from celery.log import setup_logger
 from celery.log import setup_logger
 from celery.result import AsyncResult
 from celery.result import AsyncResult
+from celery.task import task as task_dec
 from celery.task.base import Task
 from celery.task.base import Task
 from celery.utils import uuid
 from celery.utils import uuid
 from celery.utils.encoding import from_utf8, default_encode
 from celery.utils.encoding import from_utf8, default_encode
@@ -41,7 +41,7 @@ some_kwargs_scratchpad = {}
 
 
 
 
 def jail(task_id, task_name, args, kwargs):
 def jail(task_id, task_name, args, kwargs):
-    return TaskTrace(task_name, task_id, args, kwargs)()
+    return trace_task(task_name, task_id, args, kwargs)[0]
 
 
 
 
 def on_ack():
 def on_ack():
@@ -110,7 +110,7 @@ class test_RetryTaskError(unittest.TestCase):
             self.assertEqual(ret.exc, exc)
             self.assertEqual(ret.exc, exc)
 
 
 
 
-class test_TaskTrace(unittest.TestCase):
+class test_trace_task(unittest.TestCase):
 
 
     def test_process_cleanup_fails(self):
     def test_process_cleanup_fails(self):
         backend = mytask.backend
         backend = mytask.backend
@@ -169,17 +169,26 @@ class test_TaskTrace(unittest.TestCase):
             mytask.ignore_result = False
             mytask.ignore_result = False
 
 
     def test_execute_jail_failure(self):
     def test_execute_jail_failure(self):
-        ret = jail(uuid(), mytask_raising.name,
-                   [4], {})
-        self.assertIsInstance(ret, ExceptionInfo)
-        self.assertTupleEqual(ret.exception.args, (4, ))
+        u = uuid()
+        mytask_raising.request.update({"id": u})
+        try:
+            ret = jail(u, mytask_raising.name,
+                    [4], {})
+            self.assertIsInstance(ret, ExceptionInfo)
+            self.assertTupleEqual(ret.exception.args, (4, ))
+        finally:
+            mytask_raising.request.clear()
 
 
     def test_execute_ignore_result(self):
     def test_execute_ignore_result(self):
         task_id = uuid()
         task_id = uuid()
-        ret = jail(id, MyTaskIgnoreResult.name,
-                   [4], {})
-        self.assertEqual(ret, 256)
-        self.assertFalse(AsyncResult(task_id).ready())
+        MyTaskIgnoreResult.request.update({"id": task_id})
+        try:
+            ret = jail(task_id, MyTaskIgnoreResult.name,
+                       [4], {})
+            self.assertEqual(ret, 256)
+            self.assertFalse(AsyncResult(task_id).ready())
+        finally:
+            MyTaskIgnoreResult.request.clear()
 
 
 
 
 class MockEventDispatcher(object):
 class MockEventDispatcher(object):
@@ -488,27 +497,31 @@ class test_TaskRequest(unittest.TestCase):
     def test_worker_task_trace_handle_retry(self):
     def test_worker_task_trace_handle_retry(self):
         from celery.exceptions import RetryTaskError
         from celery.exceptions import RetryTaskError
         tid = uuid()
         tid = uuid()
-        w = TaskTrace(mytask.name, tid, [4], {})
-        type_, value_, tb_ = self.create_exception(ValueError("foo"))
-        type_, value_, tb_ = self.create_exception(RetryTaskError(str(value_),
-                                                                  exc=value_))
-        w._store_errors = False
-        w.handle_retry(value_, type_, tb_, "")
-        self.assertEqual(mytask.backend.get_status(tid), states.PENDING)
-        w._store_errors = True
-        w.handle_retry(value_, type_, tb_, "")
-        self.assertEqual(mytask.backend.get_status(tid), states.RETRY)
+        mytask.request.update({"id": tid})
+        try:
+            _, value_, _ = self.create_exception(ValueError("foo"))
+            einfo = self.create_exception(RetryTaskError(str(value_),
+                                          exc=value_))
+            w = TraceInfo(states.RETRY, einfo[1], einfo)
+            w.handle_retry(mytask, store_errors=False)
+            self.assertEqual(mytask.backend.get_status(tid), states.PENDING)
+            w.handle_retry(mytask, store_errors=True)
+            self.assertEqual(mytask.backend.get_status(tid), states.RETRY)
+        finally:
+            mytask.request.clear()
 
 
     def test_worker_task_trace_handle_failure(self):
     def test_worker_task_trace_handle_failure(self):
         tid = uuid()
         tid = uuid()
-        w = TaskTrace(mytask.name, tid, [4], {})
-        type_, value_, tb_ = self.create_exception(ValueError("foo"))
-        w._store_errors = False
-        w.handle_failure(value_, type_, tb_, "")
-        self.assertEqual(mytask.backend.get_status(tid), states.PENDING)
-        w._store_errors = True
-        w.handle_failure(value_, type_, tb_, "")
-        self.assertEqual(mytask.backend.get_status(tid), states.FAILURE)
+        mytask.request.update({"id": tid})
+        try:
+            einfo = self.create_exception(ValueError("foo"))
+            w = TraceInfo(states.FAILURE, einfo[1], einfo)
+            w.handle_failure(mytask, store_errors=False)
+            self.assertEqual(mytask.backend.get_status(tid), states.PENDING)
+            w.handle_failure(mytask, store_errors=True)
+            self.assertEqual(mytask.backend.get_status(tid), states.FAILURE)
+        finally:
+            mytask.request.clear()
 
 
     def test_task_wrapper_mail_attrs(self):
     def test_task_wrapper_mail_attrs(self):
         tw = TaskRequest(mytask.name, uuid(), [], {})
         tw = TaskRequest(mytask.name, uuid(), [], {})

+ 4 - 0
celery/worker/__init__.py

@@ -341,3 +341,7 @@ class WorkController(object):
 
 
     def on_timer_tick(self, delay):
     def on_timer_tick(self, delay):
         self.timer_debug("Scheduler wake-up! Next eta %s secs." % delay)
         self.timer_debug("Scheduler wake-up! Next eta %s secs." % delay)
+
+    @property
+    def state(self):
+        return state

+ 9 - 10
celery/worker/job.py

@@ -21,7 +21,7 @@ from datetime import datetime
 from .. import exceptions
 from .. import exceptions
 from .. import registry
 from .. import registry
 from ..app import app_or_default
 from ..app import app_or_default
-from ..execute.trace import TaskTrace
+from ..execute.trace import trace_task
 from ..platforms import set_mp_process_title as setps
 from ..platforms import set_mp_process_title as setps
 from ..utils import noop, kwdict, fun_takes_kwargs, truncate_text
 from ..utils import noop, kwdict, fun_takes_kwargs, truncate_text
 from ..utils.encoding import safe_repr, safe_str
 from ..utils.encoding import safe_repr, safe_str
@@ -39,13 +39,13 @@ def execute_and_trace(task_name, *args, **kwargs):
 
 
     It's the same as::
     It's the same as::
 
 
-        >>> TaskTrace(task_name, *args, **kwargs)()
+        >>> trace_task(task_name, *args, **kwargs)[0]
 
 
     """
     """
     hostname = kwargs.get("hostname")
     hostname = kwargs.get("hostname")
     setps("celeryd", task_name, hostname, rate_limit=True)
     setps("celeryd", task_name, hostname, rate_limit=True)
     try:
     try:
-        return TaskTrace(task_name, *args, **kwargs)()
+        return trace_task(task_name, *args, **kwargs)[0]
     finally:
     finally:
         setps("celeryd", "-idle-", hostname, rate_limit=True)
         setps("celeryd", "-idle-", hostname, rate_limit=True)
 
 
@@ -257,7 +257,7 @@ class TaskRequest(object):
         return result
         return result
 
 
     def execute(self, loglevel=None, logfile=None):
     def execute(self, loglevel=None, logfile=None):
-        """Execute the task in a :class:`TaskTrace`.
+        """Execute the task in a :func:`~celery.execute.trace.trace_task`.
 
 
         :keyword loglevel: The loglevel used by the task.
         :keyword loglevel: The loglevel used by the task.
 
 
@@ -272,11 +272,10 @@ class TaskRequest(object):
             self.acknowledge()
             self.acknowledge()
 
 
         instance_attrs = self.get_instance_attrs(loglevel, logfile)
         instance_attrs = self.get_instance_attrs(loglevel, logfile)
-        tracer = TaskTrace(*self._get_tracer_args(loglevel, logfile),
-                           **{"hostname": self.hostname,
-                              "loader": self.app.loader,
-                              "request": instance_attrs})
-        retval = tracer.execute()
+        retval, _ = trace_task(*self._get_tracer_args(loglevel, logfile),
+                               **{"hostname": self.hostname,
+                                  "loader": self.app.loader,
+                                  "request": instance_attrs})
         self.acknowledge()
         self.acknowledge()
         return retval
         return retval
 
 
@@ -446,6 +445,6 @@ class TaskRequest(object):
                 self.task_name, self.task_id, self.args, self.kwargs)
                 self.task_name, self.task_id, self.args, self.kwargs)
 
 
     def _get_tracer_args(self, loglevel=None, logfile=None):
     def _get_tracer_args(self, loglevel=None, logfile=None):
-        """Get the :class:`TaskTrace` tracer for this task."""
+        """Get the task trace args for this task."""
         task_func_kwargs = self.extend_with_default_kwargs(loglevel, logfile)
         task_func_kwargs = self.extend_with_default_kwargs(loglevel, logfile)
         return self.task_name, self.task_id, self.args, task_func_kwargs
         return self.task_name, self.task_id, self.args, task_func_kwargs

+ 25 - 16
funtests/benchmarks/bench_worker.py

@@ -2,8 +2,12 @@ import os
 import sys
 import sys
 import time
 import time
 
 
-#import anyjson
-#anyjson.force_implementation("cjson")
+import anyjson
+JSONIMP = os.environ.get("JSONIMP")
+if JSONIMP:
+    anyjson.force_implementation(JSONIMP)
+
+print("anyjson implementation: %r" % (anyjson.implementation.name, ))
 
 
 from celery import Celery
 from celery import Celery
 
 
@@ -14,7 +18,7 @@ celery.conf.update(BROKER_TRANSPORT="librabbitmq",
                    BROKER_POOL_LIMIT=10,
                    BROKER_POOL_LIMIT=10,
                    CELERY_PREFETCH_MULTIPLIER=0,
                    CELERY_PREFETCH_MULTIPLIER=0,
                    CELERY_DISABLE_RATE_LIMITS=True,
                    CELERY_DISABLE_RATE_LIMITS=True,
-                   CELERY_DEFAULT_DELIVERY_MODE="transient",
+                   #CELERY_DEFAULT_DELIVERY_MODE="transient",
                    CELERY_QUEUES = {
                    CELERY_QUEUES = {
                        "bench.worker": {
                        "bench.worker": {
                            "exchange": "bench.worker",
                            "exchange": "bench.worker",
@@ -27,16 +31,21 @@ celery.conf.update(BROKER_TRANSPORT="librabbitmq",
                    CELERY_BACKEND=None)
                    CELERY_BACKEND=None)
 
 
 
 
+def tdiff(then):
+    return time.time() - then
+
+
 @celery.task(cur=0, time_start=None, queue="bench.worker")
 @celery.task(cur=0, time_start=None, queue="bench.worker")
 def it(_, n):
 def it(_, n):
     i = it.cur  # use internal counter, as ordering can be skewed
     i = it.cur  # use internal counter, as ordering can be skewed
                 # by previous runs, or the broker.
                 # by previous runs, or the broker.
     if i and not i % 5000:
     if i and not i % 5000:
-        print >> sys.stderr, "(%s so far)" % (i, )
+        print >> sys.stderr, "(%s so far: %ss)" % (i, tdiff(it.subt))
+        it.subt = time.time()
     if not i:
     if not i:
-        it.time_start = time.time()
+        it.subt = it.time_start = time.time()
     elif i == n - 1:
     elif i == n - 1:
-        print("-- process %s tasks: %ss" % (n, time.time() - it.time_start, ))
+        print("-- process %s tasks: %ss" % (n, tdiff(it.time_start), ))
         sys.exit()
         sys.exit()
     it.cur += 1
     it.cur += 1
 
 
@@ -47,12 +56,9 @@ def bench_apply(n=DEFAULT_ITS):
     print("-- apply %s tasks: %ss" % (n, time.time() - time_start, ))
     print("-- apply %s tasks: %ss" % (n, time.time() - time_start, ))
 
 
 
 
-def bench_work(n=DEFAULT_ITS):
-    from celery.worker import WorkController
-    from celery.worker import state
-
-    #import logging
-    #celery.log.setup_logging_subsystem(loglevel=logging.DEBUG)
+def bench_work(n=DEFAULT_ITS, loglevel=None):
+    if loglevel:
+        celery.log.setup_logging_subsystem(loglevel=loglevel)
     worker = celery.WorkController(concurrency=15, pool_cls="solo",
     worker = celery.WorkController(concurrency=15, pool_cls="solo",
                                    queues=["bench.worker"])
                                    queues=["bench.worker"])
 
 
@@ -60,7 +66,7 @@ def bench_work(n=DEFAULT_ITS):
         print("STARTING WORKER")
         print("STARTING WORKER")
         worker.start()
         worker.start()
     except SystemExit:
     except SystemExit:
-        assert sum(state.total_count.values()) == n + 1
+        assert sum(worker.state.total_count.values()) == n + 1
 
 
 
 
 def bench_both(n=DEFAULT_ITS):
 def bench_both(n=DEFAULT_ITS):
@@ -72,9 +78,12 @@ def main(argv=sys.argv):
     if len(argv) < 2:
     if len(argv) < 2:
         print("Usage: %s [apply|work|both]" % (os.path.basename(argv[0]), ))
         print("Usage: %s [apply|work|both]" % (os.path.basename(argv[0]), ))
         return sys.exit(1)
         return sys.exit(1)
-    return {"apply": bench_apply,
-            "work": bench_work,
-            "both": bench_both}[argv[1]]()
+    try:
+        return {"apply": bench_apply,
+                "work": bench_work,
+                "both": bench_both}[argv[1]]()
+    except KeyboardInterrupt:
+        pass
 
 
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":