|  | @@ -26,6 +26,7 @@ from kombu.utils import kwdict
 | 
	
		
			
				|  |  |  from celery import current_app
 | 
	
		
			
				|  |  |  from celery import states, signals
 | 
	
		
			
				|  |  |  from celery._state import _task_stack, default_app
 | 
	
		
			
				|  |  | +from celery.app import set_default_app
 | 
	
		
			
				|  |  |  from celery.app.task import Task as BaseTask, Context
 | 
	
		
			
				|  |  |  from celery.datastructures import ExceptionInfo
 | 
	
		
			
				|  |  |  from celery.exceptions import RetryTaskError
 | 
	
	
		
			
				|  | @@ -46,11 +47,58 @@ RETRY = states.RETRY
 | 
	
		
			
				|  |  |  FAILURE = states.FAILURE
 | 
	
		
			
				|  |  |  EXCEPTION_STATES = states.EXCEPTION_STATES
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -try:
 | 
	
		
			
				|  |  | -    _tasks = default_app._tasks
 | 
	
		
			
				|  |  | -except AttributeError:
 | 
	
		
			
				|  |  | -    # Windows: will be set later by concurrency.processes.
 | 
	
		
			
				|  |  | -    pass
 | 
	
		
			
				|  |  | +#: set by :func:`setup_worker_optimizations`
 | 
	
		
			
				|  |  | +_tasks = None
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +def setup_worker_optimizations(app):
 | 
	
		
			
				|  |  | +    global _tasks
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    # make sure custom Task.__call__ methods that calls super
 | 
	
		
			
				|  |  | +    # will not mess up the request/task stack.
 | 
	
		
			
				|  |  | +    _install_stack_protection()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    # all new threads start without a current app, so if an app is not
 | 
	
		
			
				|  |  | +    # passed on to the thread it will fall back to the "default app",
 | 
	
		
			
				|  |  | +    # which then could be the wrong app.  So for the worker
 | 
	
		
			
				|  |  | +    # we set this to always return our app.  This is a hack,
 | 
	
		
			
				|  |  | +    # and means that only a single app can be used for workers
 | 
	
		
			
				|  |  | +    # running in the same process.
 | 
	
		
			
				|  |  | +    set_default_app(app)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    # evaluate all task classes by finalizing the app.
 | 
	
		
			
				|  |  | +    app.finalize()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    # set fast shortcut to task registry
 | 
	
		
			
				|  |  | +    _tasks = app._tasks
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +def _install_stack_protection():
 | 
	
		
			
				|  |  | +    # Patches BaseTask.__call__ in the worker to handle the edge case
 | 
	
		
			
				|  |  | +    # where people override it and also call super.
 | 
	
		
			
				|  |  | +    #
 | 
	
		
			
				|  |  | +    # - The worker optimizes away BaseTask.__call__ and instead
 | 
	
		
			
				|  |  | +    #   calls task.run directly.
 | 
	
		
			
				|  |  | +    # - so with the addition of current_task and the request stack
 | 
	
		
			
				|  |  | +    #   BaseTask.__call__ now pushes to those stacks so that
 | 
	
		
			
				|  |  | +    #   they work when tasks are called directly.
 | 
	
		
			
				|  |  | +    #
 | 
	
		
			
				|  |  | +    # The worker only optimizes away __call__ in the case
 | 
	
		
			
				|  |  | +    # where it has not been overridden, so the request/task stack
 | 
	
		
			
				|  |  | +    # will blow if a custom task class defines __call__ and also
 | 
	
		
			
				|  |  | +    # calls super().
 | 
	
		
			
				|  |  | +    if not getattr(BaseTask, '_stackprotected', False):
 | 
	
		
			
				|  |  | +        orig = BaseTask.__call__
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        def __protected_call__(self, *args, **kwargs):
 | 
	
		
			
				|  |  | +            req, stack = self.request, self.request_stack
 | 
	
		
			
				|  |  | +            if not req._protected and len(stack) == 2 and \
 | 
	
		
			
				|  |  | +                    not req.called_directly:
 | 
	
		
			
				|  |  | +                req._protected = 1
 | 
	
		
			
				|  |  | +                return self.run(*args, **kwargs)
 | 
	
		
			
				|  |  | +            return orig(self, *args, **kwargs)
 | 
	
		
			
				|  |  | +        BaseTask.__call__ = __protected_call__
 | 
	
		
			
				|  |  | +        BaseTask._stackprotected = True
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  def mro_lookup(cls, attr, stop=(), monkey_patched=[]):
 |