Bladeren bron

Renames celery.task.trace -> celery.app.trace. Closes #1446

Ask Solem 11 jaren geleden
bovenliggende
commit
250631f5c5

+ 1 - 1
celery/app/task.py

@@ -622,7 +622,7 @@ class Task(object):
 
         """
         # trace imports Task, so need to import inline.
-        from celery.task.trace import eager_trace_task
+        from celery.app.trace import eager_trace_task
 
         app = self._get_app()
         args = args or ()

+ 384 - 0
celery/app/trace.py

@@ -0,0 +1,384 @@
+# -*- coding: utf-8 -*-
+"""
+    celery.app.trace
+    ~~~~~~~~~~~~~~~~
+
+    This module defines how the task execution is traced:
+    errors are recorded, handlers are applied and so on.
+
+"""
+from __future__ import absolute_import
+
+# ## ---
+# This is the heart of the worker, the inner loop so to speak.
+# It used to be split up into nice little classes and methods,
+# but in the end it only resulted in bad performance and horrible tracebacks,
+# so instead we now use one closure per task class.
+
+import os
+import socket
+import sys
+
+from warnings import warn
+
+from billiard.einfo import ExceptionInfo
+from kombu.utils import kwdict
+
+from celery import current_app
+from celery import states, signals
+from celery._state import _task_stack
+from celery.app import set_default_app
+from celery.app.task import Task as BaseTask, Context
+from celery.exceptions import Ignore, RetryTaskError
+from celery.utils.log import get_logger
+from celery.utils.objects import mro_lookup
+from celery.utils.serialization import (
+    get_pickleable_exception,
+    get_pickleable_etype,
+)
+
+_logger = get_logger(__name__)
+
+send_prerun = signals.task_prerun.send
+send_postrun = signals.task_postrun.send
+send_success = signals.task_success.send
+STARTED = states.STARTED
+SUCCESS = states.SUCCESS
+IGNORED = states.IGNORED
+RETRY = states.RETRY
+FAILURE = states.FAILURE
+EXCEPTION_STATES = states.EXCEPTION_STATES
+IGNORE_STATES = frozenset([IGNORED, RETRY])
+
+#: set by :func:`setup_worker_optimizations`
+_tasks = None
+_patched = {}
+
+
+def task_has_custom(task, attr):
+    """Returns true if the task or one of its bases
+    defines ``attr`` (excluding the one in BaseTask)."""
+    return mro_lookup(task.__class__, attr, stop=(BaseTask, object),
+                      monkey_patched=['celery.app.task'])
+
+
+class TraceInfo(object):
+    __slots__ = ('state', 'retval')
+
+    def __init__(self, state, retval=None):
+        self.state = state
+        self.retval = retval
+
+    def handle_error_state(self, task, eager=False):
+        store_errors = not eager
+        if task.ignore_result:
+            store_errors = task.store_errors_even_if_ignored
+
+        return {
+            RETRY: self.handle_retry,
+            FAILURE: self.handle_failure,
+        }[self.state](task, store_errors=store_errors)
+
+    def handle_retry(self, task, store_errors=True):
+        """Handle retry exception."""
+        # the exception raised is the RetryTaskError semi-predicate,
+        # and it's exc' attribute is the original exception raised (if any).
+        req = task.request
+        type_, _, tb = sys.exc_info()
+        try:
+            reason = self.retval
+            einfo = ExceptionInfo((type_, reason, tb))
+            if store_errors:
+                task.backend.mark_as_retry(req.id, reason.exc, einfo.traceback)
+            task.on_retry(reason.exc, req.id, req.args, req.kwargs, einfo)
+            signals.task_retry.send(sender=task, request=req,
+                                    reason=reason, einfo=einfo)
+            return einfo
+        finally:
+            del(tb)
+
+    def handle_failure(self, task, store_errors=True):
+        """Handle exception."""
+        req = task.request
+        type_, _, tb = sys.exc_info()
+        try:
+            exc = self.retval
+            einfo = ExceptionInfo()
+            einfo.exception = get_pickleable_exception(einfo.exception)
+            einfo.type = get_pickleable_etype(einfo.type)
+            if store_errors:
+                task.backend.mark_as_failure(req.id, exc, einfo.traceback)
+            task.on_failure(exc, req.id, req.args, req.kwargs, einfo)
+            signals.task_failure.send(sender=task, task_id=req.id,
+                                      exception=exc, args=req.args,
+                                      kwargs=req.kwargs,
+                                      traceback=tb,
+                                      einfo=einfo)
+            return einfo
+        finally:
+            del(tb)
+
+
+def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
+                 Info=TraceInfo, eager=False, propagate=False,
+                 IGNORE_STATES=IGNORE_STATES):
+    """Returns a function that traces task execution; catches all
+    exceptions and updates result backend with the state and result
+
+    If the call was successful, it saves the result to the task result
+    backend, and sets the task status to `"SUCCESS"`.
+
+    If the call raises :exc:`~celery.exceptions.RetryTaskError`, it extracts
+    the original exception, uses that as the result and sets the task state
+    to `"RETRY"`.
+
+    If the call results in an exception, it saves the exception as the task
+    result, and sets the task state to `"FAILURE"`.
+
+    Returns a function that takes the following arguments:
+
+        :param uuid: The id of the task.
+        :param args: List of positional args to pass on to the function.
+        :param kwargs: Keyword arguments mapping to pass on to the function.
+        :keyword request: Request dict.
+
+    """
+    # If the task doesn't define a custom __call__ method
+    # we optimize it away by simply calling the run method directly,
+    # saving the extra method call and a line less in the stack trace.
+    fun = task if task_has_custom(task, '__call__') else task.run
+
+    loader = loader or current_app.loader
+    backend = task.backend
+    ignore_result = task.ignore_result
+    track_started = task.track_started
+    track_started = not eager and (task.track_started and not ignore_result)
+    publish_result = not eager and not ignore_result
+    hostname = hostname or socket.gethostname()
+
+    loader_task_init = loader.on_task_init
+    loader_cleanup = loader.on_process_cleanup
+
+    task_on_success = None
+    task_after_return = None
+    if task_has_custom(task, 'on_success'):
+        task_on_success = task.on_success
+    if task_has_custom(task, 'after_return'):
+        task_after_return = task.after_return
+
+    store_result = backend.store_result
+    backend_cleanup = backend.process_cleanup
+
+    pid = os.getpid()
+
+    request_stack = task.request_stack
+    push_request = request_stack.push
+    pop_request = request_stack.pop
+    push_task = _task_stack.push
+    pop_task = _task_stack.pop
+    on_chord_part_return = backend.on_chord_part_return
+
+    prerun_receivers = signals.task_prerun.receivers
+    postrun_receivers = signals.task_postrun.receivers
+    success_receivers = signals.task_success.receivers
+
+    from celery import canvas
+    subtask = canvas.subtask
+
+    def trace_task(uuid, args, kwargs, request=None):
+        R = I = None
+        kwargs = kwdict(kwargs)
+        try:
+            push_task(task)
+            task_request = Context(request or {}, args=args,
+                                   called_directly=False, kwargs=kwargs)
+            push_request(task_request)
+            try:
+                # -*- PRE -*-
+                if prerun_receivers:
+                    send_prerun(sender=task, task_id=uuid, task=task,
+                                args=args, kwargs=kwargs)
+                loader_task_init(uuid, task)
+                if track_started:
+                    store_result(uuid, {'pid': pid,
+                                        'hostname': hostname}, STARTED)
+
+                # -*- TRACE -*-
+                try:
+                    R = retval = fun(*args, **kwargs)
+                    state = SUCCESS
+                except Ignore as exc:
+                    I, R = Info(IGNORED, exc), ExceptionInfo(internal=True)
+                    state, retval = I.state, I.retval
+                except RetryTaskError as exc:
+                    I = Info(RETRY, exc)
+                    state, retval = I.state, I.retval
+                    R = I.handle_error_state(task, eager=eager)
+                except Exception as exc:
+                    if propagate:
+                        raise
+                    I = Info(FAILURE, exc)
+                    state, retval = I.state, I.retval
+                    R = I.handle_error_state(task, eager=eager)
+                    [subtask(errback).apply_async((uuid, ))
+                        for errback in task_request.errbacks or []]
+                except BaseException as exc:
+                    raise
+                else:
+                    # callback tasks must be applied before the result is
+                    # stored, so that result.children is populated.
+                    [subtask(callback).apply_async((retval, ))
+                        for callback in task_request.callbacks or []]
+                    if publish_result:
+                        store_result(uuid, retval, SUCCESS)
+                    if task_on_success:
+                        task_on_success(retval, uuid, args, kwargs)
+                    if success_receivers:
+                        send_success(sender=task, result=retval)
+
+                # -* POST *-
+                if state not in IGNORE_STATES:
+                    if task_request.chord:
+                        on_chord_part_return(task)
+                    if task_after_return:
+                        task_after_return(
+                            state, retval, uuid, args, kwargs, None,
+                        )
+                    if postrun_receivers:
+                        send_postrun(sender=task, task_id=uuid, task=task,
+                                     args=args, kwargs=kwargs,
+                                     retval=retval, state=state)
+            finally:
+                pop_task()
+                pop_request()
+                if not eager:
+                    try:
+                        backend_cleanup()
+                        loader_cleanup()
+                    except (KeyboardInterrupt, SystemExit, MemoryError):
+                        raise
+                    except Exception as exc:
+                        _logger.error('Process cleanup failed: %r', exc,
+                                      exc_info=True)
+        except MemoryError:
+            raise
+        except Exception as exc:
+            if eager:
+                raise
+            R = report_internal_error(task, exc)
+        return R, I
+
+    return trace_task
+
+
+def trace_task(task, uuid, args, kwargs, request={}, **opts):
+    try:
+        if task.__trace__ is None:
+            task.__trace__ = build_tracer(task.name, task, **opts)
+        return task.__trace__(uuid, args, kwargs, request)[0]
+    except Exception as exc:
+        return report_internal_error(task, exc)
+
+
+def _trace_task_ret(name, uuid, args, kwargs, request={}, **opts):
+    return trace_task(current_app.tasks[name],
+                      uuid, args, kwargs, request, **opts)
+trace_task_ret = _trace_task_ret
+
+
+def _fast_trace_task(task, uuid, args, kwargs, request={}):
+    # setup_worker_optimizations will point trace_task_ret to here,
+    # so this is the function used in the worker.
+    return _tasks[task].__trace__(uuid, args, kwargs, request)[0]
+
+
+def eager_trace_task(task, uuid, args, kwargs, request=None, **opts):
+    opts.setdefault('eager', True)
+    return build_tracer(task.name, task, **opts)(
+        uuid, args, kwargs, request)
+
+
+def report_internal_error(task, exc):
+    _type, _value, _tb = sys.exc_info()
+    try:
+        _value = task.backend.prepare_exception(exc)
+        exc_info = ExceptionInfo((_type, _value, _tb), internal=True)
+        warn(RuntimeWarning(
+            'Exception raised outside body: {0!r}:\n{1}'.format(
+                exc, exc_info.traceback)))
+        return exc_info
+    finally:
+        del(_tb)
+
+
+def setup_worker_optimizations(app):
+    global _tasks
+    global trace_task_ret
+
+    # make sure custom Task.__call__ methods that calls super
+    # will not mess up the request/task stack.
+    _install_stack_protection()
+
+    # all new threads start without a current app, so if an app is not
+    # passed on to the thread it will fall back to the "default app",
+    # which then could be the wrong app.  So for the worker
+    # we set this to always return our app.  This is a hack,
+    # and means that only a single app can be used for workers
+    # running in the same process.
+    app.set_current()
+    set_default_app(app)
+
+    # evaluate all task classes by finalizing the app.
+    app.finalize()
+
+    # set fast shortcut to task registry
+    _tasks = app._tasks
+
+    trace_task_ret = _fast_trace_task
+    from celery.worker import job as job_module
+    job_module.trace_task_ret = _fast_trace_task
+    job_module.__optimize__()
+
+
+def reset_worker_optimizations():
+    global trace_task_ret
+    trace_task_ret = _trace_task_ret
+    try:
+        delattr(BaseTask, '_stackprotected')
+    except AttributeError:
+        pass
+    try:
+        BaseTask.__call__ = _patched.pop('BaseTask.__call__')
+    except KeyError:
+        pass
+    from celery.worker import job as job_module
+    job_module.trace_task_ret = _trace_task_ret
+
+
+def _install_stack_protection():
+    # Patches BaseTask.__call__ in the worker to handle the edge case
+    # where people override it and also call super.
+    #
+    # - The worker optimizes away BaseTask.__call__ and instead
+    #   calls task.run directly.
+    # - so with the addition of current_task and the request stack
+    #   BaseTask.__call__ now pushes to those stacks so that
+    #   they work when tasks are called directly.
+    #
+    # The worker only optimizes away __call__ in the case
+    # where it has not been overridden, so the request/task stack
+    # will blow if a custom task class defines __call__ and also
+    # calls super().
+    if not getattr(BaseTask, '_stackprotected', False):
+        _patched['BaseTask.__call__'] = orig = BaseTask.__call__
+
+        def __protected_call__(self, *args, **kwargs):
+            stack = self.request_stack
+            req = stack.top
+            if req and not req._protected and \
+                    len(stack) == 1 and not req.called_directly:
+                req._protected = 1
+                return self.run(*args, **kwargs)
+            return orig(self, *args, **kwargs)
+        BaseTask.__call__ = __protected_call__
+        BaseTask._stackprotected = True

+ 1 - 1
celery/apps/worker.py

@@ -27,7 +27,7 @@ from celery import VERSION_BANNER, platforms, signals
 from celery.exceptions import SystemTerminate
 from celery.five import string, string_t
 from celery.loaders.app import AppLoader
-from celery.task import trace
+from celery.app import trace
 from celery.utils import cry, isatty
 from celery.utils.imports import qualname
 from celery.utils.log import get_logger, in_sighandler, set_in_sighandler

+ 2 - 2
celery/concurrency/processes.py

@@ -38,9 +38,9 @@ from kombu.utils.eventio import SELECT_BAD_FD
 from celery import platforms
 from celery import signals
 from celery._state import set_default_app
+from celery.app import trace
 from celery.concurrency.base import BasePool
 from celery.five import Counter, items, values
-from celery.task import trace
 from celery.utils.log import get_logger
 from celery.worker.hub import READ, WRITE, ERR
 
@@ -97,7 +97,7 @@ def process_initializer(app, hostname):
         set_default_app(app)
         app.finalize()
         trace._tasks = app._tasks  # enables fast_trace_task optimization.
-    from celery.task.trace import build_tracer
+    from celery.app.trace import build_tracer
     for name, task in items(app.tasks):
         task.__trace__ = build_tracer(name, task, app.loader, hostname)
     signals.worker_process_init.send(sender=None)

+ 3 - 380
celery/task/trace.py

@@ -1,384 +1,7 @@
-# -*- coding: utf-8 -*-
-"""
-    celery.task.trace
-    ~~~~~~~~~~~~~~~~~~~~
-
-    This module defines how the task execution is traced:
-    errors are recorded, handlers are applied and so on.
-
-"""
+"""This module has moved to celery.app.trace."""
 from __future__ import absolute_import
 
-# ## ---
-# This is the heart of the worker, the inner loop so to speak.
-# It used to be split up into nice little classes and methods,
-# but in the end it only resulted in bad performance and horrible tracebacks,
-# so instead we now use one closure per task class.
-
-import os
-import socket
 import sys
 
-from warnings import warn
-
-from billiard.einfo import ExceptionInfo
-from kombu.utils import kwdict
-
-from celery import current_app
-from celery import states, signals
-from celery._state import _task_stack
-from celery.app import set_default_app
-from celery.app.task import Task as BaseTask, Context
-from celery.exceptions import Ignore, RetryTaskError
-from celery.utils.log import get_logger
-from celery.utils.objects import mro_lookup
-from celery.utils.serialization import (
-    get_pickleable_exception,
-    get_pickleable_etype,
-)
-
-_logger = get_logger(__name__)
-
-send_prerun = signals.task_prerun.send
-send_postrun = signals.task_postrun.send
-send_success = signals.task_success.send
-STARTED = states.STARTED
-SUCCESS = states.SUCCESS
-IGNORED = states.IGNORED
-RETRY = states.RETRY
-FAILURE = states.FAILURE
-EXCEPTION_STATES = states.EXCEPTION_STATES
-IGNORE_STATES = frozenset([IGNORED, RETRY])
-
-#: set by :func:`setup_worker_optimizations`
-_tasks = None
-_patched = {}
-
-
-def task_has_custom(task, attr):
-    """Returns true if the task or one of its bases
-    defines ``attr`` (excluding the one in BaseTask)."""
-    return mro_lookup(task.__class__, attr, stop=(BaseTask, object),
-                      monkey_patched=['celery.app.task'])
-
-
-class TraceInfo(object):
-    __slots__ = ('state', 'retval')
-
-    def __init__(self, state, retval=None):
-        self.state = state
-        self.retval = retval
-
-    def handle_error_state(self, task, eager=False):
-        store_errors = not eager
-        if task.ignore_result:
-            store_errors = task.store_errors_even_if_ignored
-
-        return {
-            RETRY: self.handle_retry,
-            FAILURE: self.handle_failure,
-        }[self.state](task, store_errors=store_errors)
-
-    def handle_retry(self, task, store_errors=True):
-        """Handle retry exception."""
-        # the exception raised is the RetryTaskError semi-predicate,
-        # and it's exc' attribute is the original exception raised (if any).
-        req = task.request
-        type_, _, tb = sys.exc_info()
-        try:
-            reason = self.retval
-            einfo = ExceptionInfo((type_, reason, tb))
-            if store_errors:
-                task.backend.mark_as_retry(req.id, reason.exc, einfo.traceback)
-            task.on_retry(reason.exc, req.id, req.args, req.kwargs, einfo)
-            signals.task_retry.send(sender=task, request=req,
-                                    reason=reason, einfo=einfo)
-            return einfo
-        finally:
-            del(tb)
-
-    def handle_failure(self, task, store_errors=True):
-        """Handle exception."""
-        req = task.request
-        type_, _, tb = sys.exc_info()
-        try:
-            exc = self.retval
-            einfo = ExceptionInfo()
-            einfo.exception = get_pickleable_exception(einfo.exception)
-            einfo.type = get_pickleable_etype(einfo.type)
-            if store_errors:
-                task.backend.mark_as_failure(req.id, exc, einfo.traceback)
-            task.on_failure(exc, req.id, req.args, req.kwargs, einfo)
-            signals.task_failure.send(sender=task, task_id=req.id,
-                                      exception=exc, args=req.args,
-                                      kwargs=req.kwargs,
-                                      traceback=tb,
-                                      einfo=einfo)
-            return einfo
-        finally:
-            del(tb)
-
-
-def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
-                 Info=TraceInfo, eager=False, propagate=False,
-                 IGNORE_STATES=IGNORE_STATES):
-    """Returns a function that traces task execution; catches all
-    exceptions and updates result backend with the state and result
-
-    If the call was successful, it saves the result to the task result
-    backend, and sets the task status to `"SUCCESS"`.
-
-    If the call raises :exc:`~celery.exceptions.RetryTaskError`, it extracts
-    the original exception, uses that as the result and sets the task state
-    to `"RETRY"`.
-
-    If the call results in an exception, it saves the exception as the task
-    result, and sets the task state to `"FAILURE"`.
-
-    Returns a function that takes the following arguments:
-
-        :param uuid: The id of the task.
-        :param args: List of positional args to pass on to the function.
-        :param kwargs: Keyword arguments mapping to pass on to the function.
-        :keyword request: Request dict.
-
-    """
-    # If the task doesn't define a custom __call__ method
-    # we optimize it away by simply calling the run method directly,
-    # saving the extra method call and a line less in the stack trace.
-    fun = task if task_has_custom(task, '__call__') else task.run
-
-    loader = loader or current_app.loader
-    backend = task.backend
-    ignore_result = task.ignore_result
-    track_started = task.track_started
-    track_started = not eager and (task.track_started and not ignore_result)
-    publish_result = not eager and not ignore_result
-    hostname = hostname or socket.gethostname()
-
-    loader_task_init = loader.on_task_init
-    loader_cleanup = loader.on_process_cleanup
-
-    task_on_success = None
-    task_after_return = None
-    if task_has_custom(task, 'on_success'):
-        task_on_success = task.on_success
-    if task_has_custom(task, 'after_return'):
-        task_after_return = task.after_return
-
-    store_result = backend.store_result
-    backend_cleanup = backend.process_cleanup
-
-    pid = os.getpid()
-
-    request_stack = task.request_stack
-    push_request = request_stack.push
-    pop_request = request_stack.pop
-    push_task = _task_stack.push
-    pop_task = _task_stack.pop
-    on_chord_part_return = backend.on_chord_part_return
-
-    prerun_receivers = signals.task_prerun.receivers
-    postrun_receivers = signals.task_postrun.receivers
-    success_receivers = signals.task_success.receivers
-
-    from celery import canvas
-    subtask = canvas.subtask
-
-    def trace_task(uuid, args, kwargs, request=None):
-        R = I = None
-        kwargs = kwdict(kwargs)
-        try:
-            push_task(task)
-            task_request = Context(request or {}, args=args,
-                                   called_directly=False, kwargs=kwargs)
-            push_request(task_request)
-            try:
-                # -*- PRE -*-
-                if prerun_receivers:
-                    send_prerun(sender=task, task_id=uuid, task=task,
-                                args=args, kwargs=kwargs)
-                loader_task_init(uuid, task)
-                if track_started:
-                    store_result(uuid, {'pid': pid,
-                                        'hostname': hostname}, STARTED)
-
-                # -*- TRACE -*-
-                try:
-                    R = retval = fun(*args, **kwargs)
-                    state = SUCCESS
-                except Ignore as exc:
-                    I, R = Info(IGNORED, exc), ExceptionInfo(internal=True)
-                    state, retval = I.state, I.retval
-                except RetryTaskError as exc:
-                    I = Info(RETRY, exc)
-                    state, retval = I.state, I.retval
-                    R = I.handle_error_state(task, eager=eager)
-                except Exception as exc:
-                    if propagate:
-                        raise
-                    I = Info(FAILURE, exc)
-                    state, retval = I.state, I.retval
-                    R = I.handle_error_state(task, eager=eager)
-                    [subtask(errback).apply_async((uuid, ))
-                        for errback in task_request.errbacks or []]
-                except BaseException as exc:
-                    raise
-                else:
-                    # callback tasks must be applied before the result is
-                    # stored, so that result.children is populated.
-                    [subtask(callback).apply_async((retval, ))
-                        for callback in task_request.callbacks or []]
-                    if publish_result:
-                        store_result(uuid, retval, SUCCESS)
-                    if task_on_success:
-                        task_on_success(retval, uuid, args, kwargs)
-                    if success_receivers:
-                        send_success(sender=task, result=retval)
-
-                # -* POST *-
-                if state not in IGNORE_STATES:
-                    if task_request.chord:
-                        on_chord_part_return(task)
-                    if task_after_return:
-                        task_after_return(
-                            state, retval, uuid, args, kwargs, None,
-                        )
-                    if postrun_receivers:
-                        send_postrun(sender=task, task_id=uuid, task=task,
-                                     args=args, kwargs=kwargs,
-                                     retval=retval, state=state)
-            finally:
-                pop_task()
-                pop_request()
-                if not eager:
-                    try:
-                        backend_cleanup()
-                        loader_cleanup()
-                    except (KeyboardInterrupt, SystemExit, MemoryError):
-                        raise
-                    except Exception as exc:
-                        _logger.error('Process cleanup failed: %r', exc,
-                                      exc_info=True)
-        except MemoryError:
-            raise
-        except Exception as exc:
-            if eager:
-                raise
-            R = report_internal_error(task, exc)
-        return R, I
-
-    return trace_task
-
-
-def trace_task(task, uuid, args, kwargs, request={}, **opts):
-    try:
-        if task.__trace__ is None:
-            task.__trace__ = build_tracer(task.name, task, **opts)
-        return task.__trace__(uuid, args, kwargs, request)[0]
-    except Exception as exc:
-        return report_internal_error(task, exc)
-
-
-def _trace_task_ret(name, uuid, args, kwargs, request={}, **opts):
-    return trace_task(current_app.tasks[name],
-                      uuid, args, kwargs, request, **opts)
-trace_task_ret = _trace_task_ret
-
-
-def _fast_trace_task(task, uuid, args, kwargs, request={}):
-    # setup_worker_optimizations will point trace_task_ret to here,
-    # so this is the function used in the worker.
-    return _tasks[task].__trace__(uuid, args, kwargs, request)[0]
-
-
-def eager_trace_task(task, uuid, args, kwargs, request=None, **opts):
-    opts.setdefault('eager', True)
-    return build_tracer(task.name, task, **opts)(
-        uuid, args, kwargs, request)
-
-
-def report_internal_error(task, exc):
-    _type, _value, _tb = sys.exc_info()
-    try:
-        _value = task.backend.prepare_exception(exc)
-        exc_info = ExceptionInfo((_type, _value, _tb), internal=True)
-        warn(RuntimeWarning(
-            'Exception raised outside body: {0!r}:\n{1}'.format(
-                exc, exc_info.traceback)))
-        return exc_info
-    finally:
-        del(_tb)
-
-
-def setup_worker_optimizations(app):
-    global _tasks
-    global trace_task_ret
-
-    # make sure custom Task.__call__ methods that calls super
-    # will not mess up the request/task stack.
-    _install_stack_protection()
-
-    # all new threads start without a current app, so if an app is not
-    # passed on to the thread it will fall back to the "default app",
-    # which then could be the wrong app.  So for the worker
-    # we set this to always return our app.  This is a hack,
-    # and means that only a single app can be used for workers
-    # running in the same process.
-    app.set_current()
-    set_default_app(app)
-
-    # evaluate all task classes by finalizing the app.
-    app.finalize()
-
-    # set fast shortcut to task registry
-    _tasks = app._tasks
-
-    trace_task_ret = _fast_trace_task
-    from celery.worker import job as job_module
-    job_module.trace_task_ret = _fast_trace_task
-    job_module.__optimize__()
-
-
-def reset_worker_optimizations():
-    global trace_task_ret
-    trace_task_ret = _trace_task_ret
-    try:
-        delattr(BaseTask, '_stackprotected')
-    except AttributeError:
-        pass
-    try:
-        BaseTask.__call__ = _patched.pop('BaseTask.__call__')
-    except KeyError:
-        pass
-    from celery.worker import job as job_module
-    job_module.trace_task_ret = _trace_task_ret
-
-
-def _install_stack_protection():
-    # Patches BaseTask.__call__ in the worker to handle the edge case
-    # where people override it and also call super.
-    #
-    # - The worker optimizes away BaseTask.__call__ and instead
-    #   calls task.run directly.
-    # - so with the addition of current_task and the request stack
-    #   BaseTask.__call__ now pushes to those stacks so that
-    #   they work when tasks are called directly.
-    #
-    # The worker only optimizes away __call__ in the case
-    # where it has not been overridden, so the request/task stack
-    # will blow if a custom task class defines __call__ and also
-    # calls super().
-    if not getattr(BaseTask, '_stackprotected', False):
-        _patched['BaseTask.__call__'] = orig = BaseTask.__call__
-
-        def __protected_call__(self, *args, **kwargs):
-            stack = self.request_stack
-            req = stack.top
-            if req and not req._protected and \
-                    len(stack) == 1 and not req.called_directly:
-                req._protected = 1
-                return self.run(*args, **kwargs)
-            return orig(self, *args, **kwargs)
-        BaseTask.__call__ = __protected_call__
-        BaseTask._stackprotected = True
+from celery.app import trace
+sys.modules[__name__] = trace

+ 1 - 1
celery/tests/bin/test_worker.py

@@ -16,10 +16,10 @@ from celery import Celery
 from celery import platforms
 from celery import signals
 from celery import current_app
+from celery.app import trace
 from celery.apps import worker as cd
 from celery.bin.worker import worker, main as worker_main
 from celery.exceptions import ImproperlyConfigured, SystemTerminate
-from celery.task import trace
 from celery.utils.log import ensure_process_aware_logger
 from celery.worker import state
 

+ 3 - 3
celery/tests/tasks/test_trace.py

@@ -6,7 +6,7 @@ from celery import uuid
 from celery import signals
 from celery import states
 from celery.exceptions import RetryTaskError, Ignore
-from celery.task.trace import (
+from celery.app.trace import (
     TraceInfo,
     eager_trace_task,
     trace_task,
@@ -144,8 +144,8 @@ class test_trace(TraceCase):
         with self.assertRaises(KeyError):
             trace(self.raises, (KeyError('foo'), ), {}, propagate=True)
 
-    @patch('celery.task.trace.build_tracer')
-    @patch('celery.task.trace.report_internal_error')
+    @patch('celery.app.trace.build_tracer')
+    @patch('celery.app.trace.report_internal_error')
     def test_outside_body_error(self, report_internal_error, build_tracer):
         tracer = Mock()
         tracer.side_effect = KeyError('foo')

+ 11 - 11
celery/tests/worker/test_request.py

@@ -17,6 +17,15 @@ from mock import Mock, patch
 from nose import SkipTest
 
 from celery import states
+from celery.app.trace import (
+    trace_task,
+    _trace_task_ret,
+    TraceInfo,
+    mro_lookup,
+    build_tracer,
+    setup_worker_optimizations,
+    reset_worker_optimizations,
+)
 from celery.concurrency.base import BasePool
 from celery.exceptions import (
     RetryTaskError,
@@ -27,15 +36,6 @@ from celery.exceptions import (
     Ignore,
 )
 from celery.five import keys
-from celery.task.trace import (
-    trace_task,
-    _trace_task_ret,
-    TraceInfo,
-    mro_lookup,
-    build_tracer,
-    setup_worker_optimizations,
-    reset_worker_optimizations,
-)
 from celery.result import AsyncResult
 from celery.signals import task_revoked
 from celery.task import task as task_dec
@@ -164,7 +164,7 @@ class test_RetryTaskError(AppCase):
 
 class test_trace_task(AppCase):
 
-    @patch('celery.task.trace._logger')
+    @patch('celery.app.trace._logger')
     def test_process_cleanup_fails(self, _logger):
         backend = mytask.backend
         mytask.backend = Mock()
@@ -689,7 +689,7 @@ class test_Request(AppCase):
             mytask.ignore_result = False
 
     def test_fast_trace_task(self):
-        from celery.task import trace
+        from celery.app import trace
         setup_worker_optimizations(self.app)
         self.assertIs(trace.trace_task_ret, trace._fast_trace_task)
         try:

+ 1 - 1
celery/tests/worker/test_worker.py

@@ -904,7 +904,7 @@ class test_WorkController(AppCase):
             'celeryd', hostname='awesome.worker.com',
         )
 
-        with patch('celery.task.trace.setup_worker_optimizations') as swo:
+        with patch('celery.app.trace.setup_worker_optimizations') as swo:
             os.environ['FORKED_BY_MULTIPROCESSING'] = "1"
             try:
                 process_initializer(app, 'luke.worker.com')

+ 1 - 1
celery/worker/consumer.py

@@ -32,10 +32,10 @@ from kombu.utils.limits import TokenBucket
 
 from celery import bootsteps
 from celery.app import app_or_default
+from celery.app.trace import build_tracer
 from celery.canvas import subtask
 from celery.exceptions import InvalidTaskError
 from celery.five import items, values
-from celery.task.trace import build_tracer
 from celery.utils.functional import noop
 from celery.utils.log import get_logger
 from celery.utils.text import truncate

+ 3 - 3
celery/worker/job.py

@@ -22,6 +22,7 @@ from kombu.utils.encoding import safe_repr, safe_str
 
 from celery import signals
 from celery.app import app_or_default
+from celery.app.trace import trace_task, trace_task_ret
 from celery.exceptions import (
     Ignore, TaskRevokedError, InvalidTaskError,
     SoftTimeLimitExceeded, TimeLimitExceeded,
@@ -29,7 +30,6 @@ from celery.exceptions import (
 )
 from celery.five import items
 from celery.platforms import signals as _signals
-from celery.task.trace import trace_task, trace_task_ret
 from celery.utils import fun_takes_kwargs
 from celery.utils.functional import noop
 from celery.utils.log import get_logger
@@ -49,7 +49,7 @@ _does_debug = False
 
 
 def __optimize__():
-    # this is also called by celery.task.trace.setup_worker_optimizations
+    # this is also called by celery.app.trace.setup_worker_optimizations
     global _does_debug
     global _does_info
     _does_debug = logger.isEnabledFor(logging.DEBUG)
@@ -232,7 +232,7 @@ class Request(object):
         return result
 
     def execute(self, loglevel=None, logfile=None):
-        """Execute the task in a :func:`~celery.task.trace.trace_task`.
+        """Execute the task in a :func:`~celery.app.trace.trace_task`.
 
         :keyword loglevel: The loglevel used by the task.
         :keyword logfile: The logfile used by the task.

+ 3 - 3
docs/internals/reference/celery.task.trace.rst → docs/internals/reference/celery.app.trace.rst

@@ -1,11 +1,11 @@
 ==========================================
- celery.task.trace
+ celery.app.trace
 ==========================================
 
 .. contents::
     :local:
-.. currentmodule:: celery.task.trace
+.. currentmodule:: celery.app.trace
 
-.. automodule:: celery.task.trace
+.. automodule:: celery.app.trace
     :members:
     :undoc-members:

+ 1 - 1
docs/internals/reference/index.rst

@@ -32,7 +32,7 @@
     celery.backends.mongodb
     celery.backends.redis
     celery.backends.cassandra
-    celery.task.trace
+    celery.app.trace
     celery.app.annotations
     celery.app.routes
     celery.datastructures

+ 1 - 1
funtests/benchmarks/trace.py

@@ -33,7 +33,7 @@ x.update_strategies()
 name = T.name
 ts = time()
 from celery.datastructures import AttributeDict
-from celery.task.trace import trace_task_ret
+from celery.app.trace import trace_task_ret
 request = AttributeDict(
                 {'called_directly': False,
                  'callbacks': [],