Explorar el Código

Optimization: celeryd now handles twice as many messages per second.

Ask Solem hace 13 años
padre
commit
76bb45c79a

+ 6 - 4
celery/app/task/__init__.py

@@ -15,6 +15,7 @@ from __future__ import absolute_import
 import sys
 import threading
 
+from ... import states
 from ...datastructures import ExceptionInfo
 from ...exceptions import MaxRetriesExceededError, RetryTaskError
 from ...execute.trace import trace_task
@@ -621,8 +622,10 @@ class BaseTask(object):
                                   task=task, request=request, propagate=throw)
         if isinstance(retval, ExceptionInfo):
             retval = retval.exception
-        return EagerResult(task_id, retval, info.state,
-                           traceback=info.strtb)
+        state, tb = states.SUCCESS, ''
+        if info is not None:
+            state, tb = info.state, info.strtb
+        return EagerResult(task_id, retval, state, traceback=tb)
 
     @classmethod
     def AsyncResult(self, task_id):
@@ -680,8 +683,7 @@ class BaseTask(object):
         The return value of this handler is ignored.
 
         """
-        if self.request.chord:
-            self.backend.on_chord_part_return(self)
+        pass
 
     def on_failure(self, exc, task_id, args, kwargs, einfo):
         """Error handler.

+ 138 - 150
celery/execute/trace.py

@@ -12,6 +12,14 @@
 """
 from __future__ import absolute_import
 
+# ## ---
+# BE WARNED: You are probably going to suffer a heartattack just
+#            by looking at this code!
+#
+# This is the heart of the worker, the inner loop so to speak.
+# It used to be split up into nice little classes and methods,
+# but in the end it only resulted in bad performance, and horrible tracebacks.
+
 import os
 import socket
 import sys
@@ -26,11 +34,21 @@ from ..registry import tasks
 from ..utils.serialization import get_pickleable_exception
 
 send_prerun = signals.task_prerun.send
+prerun_receivers = signals.task_prerun.receivers
 send_postrun = signals.task_postrun.send
+postrun_receivers = signals.task_postrun.receivers
 SUCCESS = states.SUCCESS
 RETRY = states.RETRY
 FAILURE = states.FAILURE
 EXCEPTION_STATES = states.EXCEPTION_STATES
+_pid = None
+
+
+def getpid():
+    global _pid
+    if _pid is None:
+        _pid = os.getpid()
+    return _pid
 
 
 class TraceInfo(object):
@@ -46,38 +64,56 @@ class TraceInfo(object):
         else:
             self.exc_type = self.exc_value = self.tb = None
 
-    @property
-    def strtb(self):
-        if self.exc_info:
-            return "\n".join(traceback.format_exception(*self.exc_info))
-
+    def handle_error_state(self, task, eager=False):
+        store_errors = not eager
+        if task.ignore_result:
+            store_errors = task.store_errors_even_if_ignored
 
-def trace(fun, args, kwargs, propagate=False, Info=TraceInfo):
-    """Trace the execution of a function, calling the appropiate callback
-    if the function raises retry, an failure or returned successfully.
+        return {
+            RETRY: self.handle_retry,
+            FAILURE: self.handle_failure,
+        }[self.state](task, store_errors=store_errors)
 
-    :keyword propagate: If true, errors will propagate to the caller.
+    def handle_retry(self, task, store_errors=True):
+        """Handle retry exception."""
+        # Create a simpler version of the RetryTaskError that stringifies
+        # the original exception instead of including the exception instance.
+        # This is for reporting the retry in logs, email etc, while
+        # guaranteeing pickleability.
+        req = task.request
+        exc, type_, tb = self.retval, self.exc_type, self.tb
+        message, orig_exc = self.retval.args
+        if store_errors:
+            task.backend.mark_as_retry(req.id, orig_exc, self.strtb)
+        expanded_msg = "%s: %s" % (message, str(orig_exc))
+        einfo = ExceptionInfo((type_, type_(expanded_msg, None), tb))
+        task.on_retry(exc, req.id, req.args, req.kwargs, einfo)
+        return einfo
 
-    """
-    try:
-        return Info(states.SUCCESS, retval=fun(*args, **kwargs))
-    except RetryTaskError, exc:
-        return Info(states.RETRY, retval=exc, exc_info=sys.exc_info())
-    except Exception, exc:
-        if propagate:
-            raise
-        return Info(states.FAILURE, retval=exc, exc_info=sys.exc_info())
-    except BaseException, exc:
-        raise
-    except:  # pragma: no cover
-        # For Python2.5 where raising strings are still allowed
-        # (but deprecated)
-        if propagate:
-            raise
-        return Info(states.FAILURE, retval=None, exc_info=sys.exc_info())
+    def handle_failure(self, task, store_errors=True):
+        """Handle exception."""
+        req = task.request
+        exc, type_, tb = self.retval, self.exc_type, self.tb
+        if store_errors:
+            task.backend.mark_as_failure(req.id, exc, self.strtb)
+        exc = get_pickleable_exception(exc)
+        einfo = ExceptionInfo((type_, exc, tb))
+        task.on_failure(exc, req.id, req.args, req.kwargs, einfo)
+        signals.task_failure.send(sender=task, task_id=req.id,
+                                  exception=exc, args=req.args,
+                                  kwargs=req.kwargs, traceback=tb,
+                                  einfo=einfo)
+        return einfo
 
+    @property
+    def strtb(self):
+        if self.exc_info:
+            return '\n'.join(traceback.format_exception(*self.exc_info))
+        return ''
 
-class TaskTrace(object):
+def trace_task(name, uuid, args, kwargs, task=None, request=None,
+        eager=False, propagate=False, loader=None, hostname=None,
+        store_errors=True, Info=TraceInfo):
     """Wraps the task in a jail, catches all exceptions, and
     saves the status and result of the task execution to the task
     meta backend.
@@ -104,131 +140,83 @@ class TaskTrace(object):
 
     :returns: the evaluated functions return value on success, or
         the exception instance on failure.
-
     """
-
-    def __init__(self, task_name, task_id, args, kwargs, task=None,
-            request=None, propagate=None, propagate_internal=False,
-            eager=False, **_):
-        self.task_id = task_id
-        self.task_name = task_name
-        self.args = args
-        self.kwargs = kwargs
-        self.task = task or tasks[self.task_name]
-        self.request = request or {}
-        self.propagate = propagate
-        self.propagate_internal = propagate_internal
-        self.eager = eager
-        self._trace_handlers = {FAILURE: self.handle_failure,
-                                RETRY: self.handle_retry}
-        self.loader = kwargs.get("loader") or current_app.loader
-        self.hostname = kwargs.get("hostname") or socket.gethostname()
-        self._store_errors = True
-        if eager:
-            self._store_errors = False
-        elif self.task.ignore_result:
-            self._store_errors = self.task.store_errors_even_if_ignored
-
-        # Used by always_eager
-        self.state = None
-        self.strtb = None
-
-    def __call__(self):
+    R = I = None
+    try:
+        task = task or tasks[name]
+        backend = task.backend
+        ignore_result = task.ignore_result
+        loader = loader or current_app.loader
+        hostname = hostname or socket.gethostname()
+        task.request.update(request or {}, args=args,
+                            called_directly=False, kwargs=kwargs)
         try:
-            task, uuid, args, kwargs, eager = (self.task, self.task_id,
-                                               self.args, self.kwargs,
-                                               self.eager)
-            backend = task.backend
-            ignore_result = task.ignore_result
-            loader = self.loader
-
-            task.request.update(self.request, args=args,
-                                called_directly=False, kwargs=kwargs)
+            # -*- PRE -*-
+            send_prerun(sender=task, task_id=uuid, task=task,
+                        args=args, kwargs=kwargs)
+            loader.on_task_init(uuid, task)
+            if not eager and (task.track_started and not ignore_result):
+                backend.mark_as_started(uuid, pid=getpid(),
+                                        hostname=hostname)
+
+            # -*- TRACE -*-
             try:
-                # -*- PRE -*-
-                send_prerun(sender=task, task_id=uuid, task=task,
-                            args=args, kwargs=kwargs)
-                loader.on_task_init(uuid, task)
-                if not eager and (task.track_started and not ignore_result):
-                    backend.mark_as_started(uuid, pid=os.getpid(),
-                                            hostname=self.hostname)
-
-                # -*- TRACE -*-
-                # self.info is used by always_eager
-                info = self.info = trace(task, args, kwargs, self.propagate)
-                state, retval, einfo = info.state, info.retval, info.exc_info
-                if eager:
-                    self.state, self.strtb = info.state, info.strtb
-
-                if state == SUCCESS:
-                    task.on_success(retval, uuid, args, kwargs)
-                    if not eager and not ignore_result:
-                        backend.mark_as_done(uuid, retval)
-                    R = retval
-                else:
-                    R = self.handle_trace(info)
-
-                # -* POST *-
-                if state in EXCEPTION_STATES:
-                    einfo = ExceptionInfo(einfo)
-                task.after_return(state, retval, self.task_id,
-                                  self.args, self.kwargs, einfo)
-                send_postrun(sender=task, task_id=uuid, task=task,
-                             args=args, kwargs=kwargs, retval=retval)
-            finally:
-                task.request.clear()
-                if not eager:
-                    try:
-                        backend.process_cleanup()
-                        loader.on_process_cleanup()
-                    except (KeyboardInterrupt, SystemExit, MemoryError):
-                        raise
-                    except Exception, exc:
-                        logger = current_app.log.get_default_logger()
-                        logger.error("Process cleanup failed: %r", exc,
-                                    exc_info=sys.exc_info())
-        except Exception, exc:
-            if self.propagate_internal:
+                R = retval = task(*args, **kwargs)
+                state, einfo = SUCCESS, None
+                task.on_success(retval, uuid, args, kwargs)
+                if not eager and not ignore_result:
+                    backend.mark_as_done(uuid, retval)
+            except RetryTaskError, exc:
+                I = Info(RETRY, exc, sys.exc_info())
+                state, retval, einfo = I.state, I.retval, I.exc_info
+                R = I.handle_error_state(task, eager=eager)
+            except Exception, exc:
+                if propagate:
+                    raise
+                I = Info(FAILURE, exc, sys.exc_info())
+                state, retval, einfo = I.state, I.retval, I.exc_info
+                R = I.handle_error_state(task, eager=eager)
+            except BaseException, exc:
                 raise
-            return self.report_internal_error(exc)
-        return R
-    execute = __call__
-
-    def report_internal_error(self, exc):
-        _type, _value, _tb = sys.exc_info()
-        _value = self.task.backend.prepare_exception(exc)
-        exc_info = ExceptionInfo((_type, _value, _tb))
-        warnings.warn("Exception outside body: %s: %s\n%s" % tuple(
-            map(str, (exc.__class__, exc, exc_info.traceback))))
-        return exc_info
-
-    def handle_trace(self, trace):
-        return self._trace_handlers[trace.state](trace.retval, trace.exc_type,
-                                                 trace.tb, trace.strtb)
-
-    def handle_retry(self, exc, type_, tb, strtb):
-        """Handle retry exception."""
-        # Create a simpler version of the RetryTaskError that stringifies
-        # the original exception instead of including the exception instance.
-        # This is for reporting the retry in logs, email etc, while
-        # guaranteeing pickleability.
-        message, orig_exc = exc.args
-        if self._store_errors:
-            self.task.backend.mark_as_retry(self.task_id, orig_exc, strtb)
-        expanded_msg = "%s: %s" % (message, str(orig_exc))
-        einfo = ExceptionInfo((type_, type_(expanded_msg, None), tb))
-        self.task.on_retry(exc, self.task_id, self.args, self.kwargs, einfo)
-        return einfo
+            except:
+                # pragma: no cover
+                # For Python2.5 where raising strings are still allowed
+                # (but deprecated)
+                if propagate:
+                    raise
+                I = Info(FAILURE, None, sys.exc_info())
+                state, retval, einfo = I.state, I.retval, I.exc_info
+                R = I.handle_error_state(task, eager=eager)
+
+            # -* POST *-
+            if task.request.chord:
+                backend.on_chord_part_return(task)
+            task.after_return(state, retval, uuid, args, kwargs, einfo)
+            send_postrun(sender=task, task_id=uuid, task=task,
+                         args=args, kwargs=kwargs, retval=retval)
+        finally:
+            task.request.clear()
+            if not eager:
+                try:
+                    backend.process_cleanup()
+                    loader.on_process_cleanup()
+                except (KeyboardInterrupt, SystemExit, MemoryError):
+                    raise
+                except Exception, exc:
+                    logger = current_app.log.get_default_logger()
+                    logger.error("Process cleanup failed: %r", exc,
+                                 exc_info=sys.exc_info())
+    except Exception, exc:
+        if eager:
+            raise
+        R = report_internal_error(task, exc)
+    return R, I
 
-    def handle_failure(self, exc, type_, tb, strtb):
-        """Handle exception."""
-        if self._store_errors:
-            self.task.backend.mark_as_failure(self.task_id, exc, strtb)
-        exc = get_pickleable_exception(exc)
-        einfo = ExceptionInfo((type_, exc, tb))
-        self.task.on_failure(exc, self.task_id, self.args, self.kwargs, einfo)
-        signals.task_failure.send(sender=self.task, task_id=self.task_id,
-                                  exception=exc, args=self.args,
-                                  kwargs=self.kwargs, traceback=tb,
-                                  einfo=einfo)
-        return einfo
+
+def report_internal_error(task, exc):
+    _type, _value, _tb = sys.exc_info()
+    _value = task.backend.prepare_exception(exc)
+    exc_info = ExceptionInfo((_type, _value, _tb))
+    warnings.warn("Exception outside body: %s: %s\n%s" % tuple(
+        map(str, (exc.__class__, exc, exc_info.traceback))))
+    return exc_info

+ 14 - 10
celery/log.py

@@ -24,6 +24,12 @@ from .utils.term import colored
 is_py3k = sys.version_info >= (3, 0)
 
 
+def mlevel(level):
+    if level and not isinstance(level, int):
+        return LOG_LEVELS[level.upper()]
+    return level
+
+
 class ColorFormatter(logging.Formatter):
     #: Loglevel -> Color mapping.
     COLORS = colored().names
@@ -71,7 +77,7 @@ class Logging(object):
 
     def __init__(self, app):
         self.app = app
-        self.loglevel = self.app.conf.CELERYD_LOG_LEVEL
+        self.loglevel = mlevel(self.app.conf.CELERYD_LOG_LEVEL)
         self.format = self.app.conf.CELERYD_LOG_FORMAT
         self.task_format = self.app.conf.CELERYD_TASK_LOG_FORMAT
         self.colorize = self.app.conf.CELERYD_LOG_COLOR
@@ -92,14 +98,14 @@ class Logging(object):
     def get_task_logger(self, loglevel=None, name=None):
         logger = logging.getLogger(name or "celery.task.default")
         if loglevel is not None:
-            logger.setLevel(loglevel)
+            logger.setLevel(mlevel(loglevel))
         return logger
 
     def setup_logging_subsystem(self, loglevel=None, logfile=None,
             format=None, colorize=None, **kwargs):
         if Logging._setup:
             return
-        loglevel = loglevel or self.loglevel
+        loglevel = mlevel(loglevel or self.loglevel)
         format = format or self.format
         if colorize is None:
             colorize = self.supports_color(logfile)
@@ -120,7 +126,7 @@ class Logging(object):
             mp = mputil.get_logger() if mputil else None
             for logger in filter(None, (root, mp)):
                 self._setup_logger(logger, logfile, format, colorize, **kwargs)
-                logger.setLevel(loglevel)
+                logger.setLevel(mlevel(loglevel))
                 signals.after_setup_logger.send(sender=None, logger=logger,
                                         loglevel=loglevel, logfile=logfile,
                                         format=format, colorize=colorize)
@@ -144,7 +150,7 @@ class Logging(object):
         """
         logger = logging.getLogger(name)
         if loglevel is not None:
-            logger.setLevel(loglevel)
+            logger.setLevel(mlevel(loglevel))
         return logger
 
     def setup_logger(self, loglevel=None, logfile=None,
@@ -157,7 +163,7 @@ class Logging(object):
         Returns logger object.
 
         """
-        loglevel = loglevel or self.loglevel
+        loglevel = mlevel(loglevel or self.loglevel)
         format = format or self.format
         if colorize is None:
             colorize = self.supports_color(logfile)
@@ -179,7 +185,7 @@ class Logging(object):
         Returns logger object.
 
         """
-        loglevel = loglevel or self.loglevel
+        loglevel = mlevel(loglevel or self.loglevel)
         format = format or self.task_format
         if colorize is None:
             colorize = self.supports_color(logfile)
@@ -247,9 +253,7 @@ class LoggingProxy(object):
 
     def __init__(self, logger, loglevel=None):
         self.logger = logger
-        self.loglevel = loglevel or self.logger.level or self.loglevel
-        if not isinstance(self.loglevel, int):
-            self.loglevel = LOG_LEVELS[self.loglevel.upper()]
+        self.loglevel = mlevel(loglevel or self.logger.level or self.loglevel)
         self._safewrap_handlers()
 
     def _safewrap_handlers(self):

+ 0 - 2
celery/tests/test_task/__init__.py

@@ -343,10 +343,8 @@ class TestCeleryTasks(unittest.TestCase):
 
     def test_after_return(self):
         task = self.createTaskCls("T1", "c.unittest.t.after_return")()
-        task.backend = Mock()
         task.request.chord = return_True_task.subtask()
         task.after_return("SUCCESS", 1.0, "foobar", (), {}, None)
-        task.backend.on_chord_part_return.assert_called_with(task)
         task.request.clear()
 
     def test_send_task_sent_event(self):

+ 19 - 6
celery/tests/test_task/test_execute_trace.py

@@ -3,21 +3,34 @@ from __future__ import with_statement
 
 import operator
 
+from celery import current_app
 from celery import states
 from celery.exceptions import RetryTaskError
-from celery.execute.trace import trace
+from celery.execute.trace import trace_task
 from celery.tests.utils import unittest
 
+@current_app.task
+def add(x, y):
+    return x + y
+
+
+@current_app.task
 def raises(exc):
     raise exc
 
 
+def trace(task, args=(), kwargs={}, propagate=False):
+    return trace_task(task.__name__, "id-1", args, kwargs, task,
+                      propagate=propagate, eager=True)
+
+
+
 class test_trace(unittest.TestCase):
 
     def test_trace_successful(self):
-        info = trace(operator.add, (2, 2), {})
-        self.assertEqual(info.state, states.SUCCESS)
-        self.assertEqual(info.retval, 4)
+        retval, info = trace(add, (2, 2), {})
+        self.assertIsNone(info)
+        self.assertEqual(retval, 4)
 
     def test_trace_SystemExit(self):
         with self.assertRaises(SystemExit):
@@ -25,13 +38,13 @@ class test_trace(unittest.TestCase):
 
     def test_trace_RetryTaskError(self):
         exc = RetryTaskError("foo", "bar")
-        info = trace(raises, (exc, ), {})
+        _, info = trace(raises, (exc, ), {})
         self.assertEqual(info.state, states.RETRY)
         self.assertIs(info.retval, exc)
 
     def test_trace_exception(self):
         exc = KeyError("foo")
-        info = trace(raises, (exc, ), {})
+        _, info = trace(raises, (exc, ), {})
         self.assertEqual(info.state, states.FAILURE)
         self.assertIs(info.retval, exc)
 

+ 43 - 30
celery/tests/test_worker/test_worker_job.py

@@ -19,12 +19,12 @@ from celery import states
 from celery.app import app_or_default
 from celery.concurrency.base import BasePool
 from celery.datastructures import ExceptionInfo
-from celery.task import task as task_dec
 from celery.exceptions import (RetryTaskError, NotRegistered,
                                WorkerLostError, InvalidTaskError)
-from celery.execute.trace import TaskTrace
+from celery.execute.trace import trace_task, TraceInfo
 from celery.log import setup_logger
 from celery.result import AsyncResult
+from celery.task import task as task_dec
 from celery.task.base import Task
 from celery.utils import uuid
 from celery.utils.encoding import from_utf8, default_encode
@@ -41,7 +41,7 @@ some_kwargs_scratchpad = {}
 
 
 def jail(task_id, task_name, args, kwargs):
-    return TaskTrace(task_name, task_id, args, kwargs)()
+    return trace_task(task_name, task_id, args, kwargs)[0]
 
 
 def on_ack():
@@ -110,7 +110,7 @@ class test_RetryTaskError(unittest.TestCase):
             self.assertEqual(ret.exc, exc)
 
 
-class test_TaskTrace(unittest.TestCase):
+class test_trace_task(unittest.TestCase):
 
     def test_process_cleanup_fails(self):
         backend = mytask.backend
@@ -169,17 +169,26 @@ class test_TaskTrace(unittest.TestCase):
             mytask.ignore_result = False
 
     def test_execute_jail_failure(self):
-        ret = jail(uuid(), mytask_raising.name,
-                   [4], {})
-        self.assertIsInstance(ret, ExceptionInfo)
-        self.assertTupleEqual(ret.exception.args, (4, ))
+        u = uuid()
+        mytask_raising.request.update({"id": u})
+        try:
+            ret = jail(u, mytask_raising.name,
+                    [4], {})
+            self.assertIsInstance(ret, ExceptionInfo)
+            self.assertTupleEqual(ret.exception.args, (4, ))
+        finally:
+            mytask_raising.request.clear()
 
     def test_execute_ignore_result(self):
         task_id = uuid()
-        ret = jail(id, MyTaskIgnoreResult.name,
-                   [4], {})
-        self.assertEqual(ret, 256)
-        self.assertFalse(AsyncResult(task_id).ready())
+        MyTaskIgnoreResult.request.update({"id": task_id})
+        try:
+            ret = jail(task_id, MyTaskIgnoreResult.name,
+                       [4], {})
+            self.assertEqual(ret, 256)
+            self.assertFalse(AsyncResult(task_id).ready())
+        finally:
+            MyTaskIgnoreResult.request.clear()
 
 
 class MockEventDispatcher(object):
@@ -488,27 +497,31 @@ class test_TaskRequest(unittest.TestCase):
     def test_worker_task_trace_handle_retry(self):
         from celery.exceptions import RetryTaskError
         tid = uuid()
-        w = TaskTrace(mytask.name, tid, [4], {})
-        type_, value_, tb_ = self.create_exception(ValueError("foo"))
-        type_, value_, tb_ = self.create_exception(RetryTaskError(str(value_),
-                                                                  exc=value_))
-        w._store_errors = False
-        w.handle_retry(value_, type_, tb_, "")
-        self.assertEqual(mytask.backend.get_status(tid), states.PENDING)
-        w._store_errors = True
-        w.handle_retry(value_, type_, tb_, "")
-        self.assertEqual(mytask.backend.get_status(tid), states.RETRY)
+        mytask.request.update({"id": tid})
+        try:
+            _, value_, _ = self.create_exception(ValueError("foo"))
+            einfo = self.create_exception(RetryTaskError(str(value_),
+                                          exc=value_))
+            w = TraceInfo(states.RETRY, einfo[1], einfo)
+            w.handle_retry(mytask, store_errors=False)
+            self.assertEqual(mytask.backend.get_status(tid), states.PENDING)
+            w.handle_retry(mytask, store_errors=True)
+            self.assertEqual(mytask.backend.get_status(tid), states.RETRY)
+        finally:
+            mytask.request.clear()
 
     def test_worker_task_trace_handle_failure(self):
         tid = uuid()
-        w = TaskTrace(mytask.name, tid, [4], {})
-        type_, value_, tb_ = self.create_exception(ValueError("foo"))
-        w._store_errors = False
-        w.handle_failure(value_, type_, tb_, "")
-        self.assertEqual(mytask.backend.get_status(tid), states.PENDING)
-        w._store_errors = True
-        w.handle_failure(value_, type_, tb_, "")
-        self.assertEqual(mytask.backend.get_status(tid), states.FAILURE)
+        mytask.request.update({"id": tid})
+        try:
+            einfo = self.create_exception(ValueError("foo"))
+            w = TraceInfo(states.FAILURE, einfo[1], einfo)
+            w.handle_failure(mytask, store_errors=False)
+            self.assertEqual(mytask.backend.get_status(tid), states.PENDING)
+            w.handle_failure(mytask, store_errors=True)
+            self.assertEqual(mytask.backend.get_status(tid), states.FAILURE)
+        finally:
+            mytask.request.clear()
 
     def test_task_wrapper_mail_attrs(self):
         tw = TaskRequest(mytask.name, uuid(), [], {})

+ 4 - 0
celery/worker/__init__.py

@@ -341,3 +341,7 @@ class WorkController(object):
 
     def on_timer_tick(self, delay):
         self.timer_debug("Scheduler wake-up! Next eta %s secs." % delay)
+
+    @property
+    def state(self):
+        return state

+ 9 - 10
celery/worker/job.py

@@ -21,7 +21,7 @@ from datetime import datetime
 from .. import exceptions
 from .. import registry
 from ..app import app_or_default
-from ..execute.trace import TaskTrace
+from ..execute.trace import trace_task
 from ..platforms import set_mp_process_title as setps
 from ..utils import noop, kwdict, fun_takes_kwargs, truncate_text
 from ..utils.encoding import safe_repr, safe_str
@@ -39,13 +39,13 @@ def execute_and_trace(task_name, *args, **kwargs):
 
     It's the same as::
 
-        >>> TaskTrace(task_name, *args, **kwargs)()
+        >>> trace_task(task_name, *args, **kwargs)[0]
 
     """
     hostname = kwargs.get("hostname")
     setps("celeryd", task_name, hostname, rate_limit=True)
     try:
-        return TaskTrace(task_name, *args, **kwargs)()
+        return trace_task(task_name, *args, **kwargs)[0]
     finally:
         setps("celeryd", "-idle-", hostname, rate_limit=True)
 
@@ -257,7 +257,7 @@ class TaskRequest(object):
         return result
 
     def execute(self, loglevel=None, logfile=None):
-        """Execute the task in a :class:`TaskTrace`.
+        """Execute the task in a :func:`~celery.execute.trace.trace_task`.
 
         :keyword loglevel: The loglevel used by the task.
 
@@ -272,11 +272,10 @@ class TaskRequest(object):
             self.acknowledge()
 
         instance_attrs = self.get_instance_attrs(loglevel, logfile)
-        tracer = TaskTrace(*self._get_tracer_args(loglevel, logfile),
-                           **{"hostname": self.hostname,
-                              "loader": self.app.loader,
-                              "request": instance_attrs})
-        retval = tracer.execute()
+        retval, _ = trace_task(*self._get_tracer_args(loglevel, logfile),
+                               **{"hostname": self.hostname,
+                                  "loader": self.app.loader,
+                                  "request": instance_attrs})
         self.acknowledge()
         return retval
 
@@ -446,6 +445,6 @@ class TaskRequest(object):
                 self.task_name, self.task_id, self.args, self.kwargs)
 
     def _get_tracer_args(self, loglevel=None, logfile=None):
-        """Get the :class:`TaskTrace` tracer for this task."""
+        """Get the task trace args for this task."""
         task_func_kwargs = self.extend_with_default_kwargs(loglevel, logfile)
         return self.task_name, self.task_id, self.args, task_func_kwargs

+ 25 - 16
funtests/benchmarks/bench_worker.py

@@ -2,8 +2,12 @@ import os
 import sys
 import time
 
-#import anyjson
-#anyjson.force_implementation("cjson")
+import anyjson
+JSONIMP = os.environ.get("JSONIMP")
+if JSONIMP:
+    anyjson.force_implementation(JSONIMP)
+
+print("anyjson implementation: %r" % (anyjson.implementation.name, ))
 
 from celery import Celery
 
@@ -14,7 +18,7 @@ celery.conf.update(BROKER_TRANSPORT="librabbitmq",
                    BROKER_POOL_LIMIT=10,
                    CELERY_PREFETCH_MULTIPLIER=0,
                    CELERY_DISABLE_RATE_LIMITS=True,
-                   CELERY_DEFAULT_DELIVERY_MODE="transient",
+                   #CELERY_DEFAULT_DELIVERY_MODE="transient",
                    CELERY_QUEUES = {
                        "bench.worker": {
                            "exchange": "bench.worker",
@@ -27,16 +31,21 @@ celery.conf.update(BROKER_TRANSPORT="librabbitmq",
                    CELERY_BACKEND=None)
 
 
+def tdiff(then):
+    return time.time() - then
+
+
 @celery.task(cur=0, time_start=None, queue="bench.worker")
 def it(_, n):
     i = it.cur  # use internal counter, as ordering can be skewed
                 # by previous runs, or the broker.
     if i and not i % 5000:
-        print >> sys.stderr, "(%s so far)" % (i, )
+        print >> sys.stderr, "(%s so far: %ss)" % (i, tdiff(it.subt))
+        it.subt = time.time()
     if not i:
-        it.time_start = time.time()
+        it.subt = it.time_start = time.time()
     elif i == n - 1:
-        print("-- process %s tasks: %ss" % (n, time.time() - it.time_start, ))
+        print("-- process %s tasks: %ss" % (n, tdiff(it.time_start), ))
         sys.exit()
     it.cur += 1
 
@@ -47,12 +56,9 @@ def bench_apply(n=DEFAULT_ITS):
     print("-- apply %s tasks: %ss" % (n, time.time() - time_start, ))
 
 
-def bench_work(n=DEFAULT_ITS):
-    from celery.worker import WorkController
-    from celery.worker import state
-
-    #import logging
-    #celery.log.setup_logging_subsystem(loglevel=logging.DEBUG)
+def bench_work(n=DEFAULT_ITS, loglevel=None):
+    if loglevel:
+        celery.log.setup_logging_subsystem(loglevel=loglevel)
     worker = celery.WorkController(concurrency=15, pool_cls="solo",
                                    queues=["bench.worker"])
 
@@ -60,7 +66,7 @@ def bench_work(n=DEFAULT_ITS):
         print("STARTING WORKER")
         worker.start()
     except SystemExit:
-        assert sum(state.total_count.values()) == n + 1
+        assert sum(worker.state.total_count.values()) == n + 1
 
 
 def bench_both(n=DEFAULT_ITS):
@@ -72,9 +78,12 @@ def main(argv=sys.argv):
     if len(argv) < 2:
         print("Usage: %s [apply|work|both]" % (os.path.basename(argv[0]), ))
         return sys.exit(1)
-    return {"apply": bench_apply,
-            "work": bench_work,
-            "both": bench_both}[argv[1]]()
+    try:
+        return {"apply": bench_apply,
+                "work": bench_work,
+                "both": bench_both}[argv[1]]()
+    except KeyboardInterrupt:
+        pass
 
 
 if __name__ == "__main__":