Bladeren bron

Make sure BaseTask.__call__ does not modify stack even if overridden calls super() (As discussed in #947)

Ask Solem 12 jaren geleden
bovenliggende
commit
389caf247f
6 gewijzigde bestanden met toevoegingen van 69 en 18 verwijderingen
  1. 1 0
      celery/app/task.py
  2. 2 1
      celery/canvas.py
  3. 53 5
      celery/task/trace.py
  4. 3 1
      celery/tests/tasks/test_chord.py
  5. 7 0
      celery/utils/threads.py
  6. 3 11
      celery/worker/__init__.py

+ 1 - 0
celery/app/task.py

@@ -56,6 +56,7 @@ class Context(object):
     errbacks = None
     errbacks = None
     timeouts = None
     timeouts = None
     _children = None   # see property
     _children = None   # see property
+    _protected = 0
 
 
     def __init__(self, *args, **kwargs):
     def __init__(self, *args, **kwargs):
         self.update(*args, **kwargs)
         self.update(*args, **kwargs)

+ 2 - 1
celery/canvas.py

@@ -354,10 +354,11 @@ class chord(Signature):
     def __call__(self, body=None, **kwargs):
     def __call__(self, body=None, **kwargs):
         _chord = self.Chord
         _chord = self.Chord
         body = (body or self.kwargs['body']).clone()
         body = (body or self.kwargs['body']).clone()
+        kwargs = dict(self.kwargs, body=body, **kwargs)
         if _chord.app.conf.CELERY_ALWAYS_EAGER:
         if _chord.app.conf.CELERY_ALWAYS_EAGER:
             return self.apply((), kwargs)
             return self.apply((), kwargs)
         callback_id = body.options.setdefault('task_id', uuid())
         callback_id = body.options.setdefault('task_id', uuid())
-        _chord(**dict(self.kwargs, body=body, **kwargs))
+        _chord(**kwargs)
         return _chord.AsyncResult(callback_id)
         return _chord.AsyncResult(callback_id)
 
 
     def clone(self, *args, **kwargs):
     def clone(self, *args, **kwargs):

+ 53 - 5
celery/task/trace.py

@@ -26,6 +26,7 @@ from kombu.utils import kwdict
 from celery import current_app
 from celery import current_app
 from celery import states, signals
 from celery import states, signals
 from celery._state import _task_stack, default_app
 from celery._state import _task_stack, default_app
+from celery.app import set_default_app
 from celery.app.task import Task as BaseTask, Context
 from celery.app.task import Task as BaseTask, Context
 from celery.datastructures import ExceptionInfo
 from celery.datastructures import ExceptionInfo
 from celery.exceptions import RetryTaskError
 from celery.exceptions import RetryTaskError
@@ -46,11 +47,58 @@ RETRY = states.RETRY
 FAILURE = states.FAILURE
 FAILURE = states.FAILURE
 EXCEPTION_STATES = states.EXCEPTION_STATES
 EXCEPTION_STATES = states.EXCEPTION_STATES
 
 
-try:
-    _tasks = default_app._tasks
-except AttributeError:
-    # Windows: will be set later by concurrency.processes.
-    pass
+#: set by :func:`setup_worker_optimizations`
+_tasks = None
+
+
+def setup_worker_optimizations(app):
+    global _tasks
+
+    # make sure custom Task.__call__ methods that call super
+    # will not mess up the request/task stack.
+    _install_stack_protection()
+
+    # all new threads start without a current app, so if an app is not
+    # passed on to the thread it will fall back to the "default app",
+    # which then could be the wrong app.  So for the worker
+    # we set this to always return our app.  This is a hack,
+    # and means that only a single app can be used for workers
+    # running in the same process.
+    set_default_app(app)
+
+    # evaluate all task classes by finalizing the app.
+    app.finalize()
+
+    # set fast shortcut to task registry
+    _tasks = app._tasks
+
+
+def _install_stack_protection():
+    # Patches BaseTask.__call__ in the worker to handle the edge case
+    # where people override it and also call super.
+    #
+    # - The worker optimizes away BaseTask.__call__ and instead
+    #   calls task.run directly.
+    # - so with the addition of current_task and the request stack
+    #   BaseTask.__call__ now pushes to those stacks so that
+    #   they work when tasks are called directly.
+    #
+    # The worker only optimizes away __call__ in the case
+    # where it has not been overridden, so the request/task stack
+    # will blow up if a custom task class defines __call__ and also
+    # calls super().
+    if not getattr(BaseTask, '_stackprotected', False):
+        orig = BaseTask.__call__
+
+        def __protected_call__(self, *args, **kwargs):
+            req, stack = self.request, self.request_stack
+            if not req._protected and len(stack) == 2 and \
+                    not req.called_directly:
+                req._protected = 1
+                return self.run(*args, **kwargs)
+            return orig(self, *args, **kwargs)
+        BaseTask.__call__ = __protected_call__
+        BaseTask._stackprotected = True
 
 
 
 
 def mro_lookup(cls, attr, stop=(), monkey_patched=[]):
 def mro_lookup(cls, attr, stop=(), monkey_patched=[]):

+ 3 - 1
celery/tests/tasks/test_chord.py

@@ -132,7 +132,9 @@ class test_chord(AppCase):
             x = chord(add.s(i, i) for i in xrange(10))
             x = chord(add.s(i, i) for i in xrange(10))
             body = add.s(2)
             body = add.s(2)
             result = x(body)
             result = x(body)
-            self.assertEqual(result.id, body.options['task_id'])
+            # does not modify original subtask
+            with self.assertRaises(KeyError):
+                body.options['task_id']
             self.assertTrue(chord.Chord.called)
             self.assertTrue(chord.Chord.called)
         finally:
         finally:
             chord.Chord = prev
             chord.Chord = prev

+ 7 - 0
celery/utils/threads.py

@@ -213,6 +213,10 @@ class _LocalStack(object):
         else:
         else:
             return stack.pop()
             return stack.pop()
 
 
+    def __len__(self):
+        stack = getattr(self._local, 'stack', None)
+        return len(stack) if stack else 0
+
     @property
     @property
     def stack(self):
     def stack(self):
         """get_current_worker_task uses this to find
         """get_current_worker_task uses this to find
@@ -294,6 +298,9 @@ class _FastLocalStack(threading.local):
         except (AttributeError, IndexError):
         except (AttributeError, IndexError):
             return None
             return None
 
 
+    def __len__(self):
+        return len(self.stack)
+
 if USE_FAST_LOCALS:
 if USE_FAST_LOCALS:
     LocalStack = _FastLocalStack
     LocalStack = _FastLocalStack
 else:
 else:

+ 3 - 11
celery/worker/__init__.py

@@ -24,12 +24,12 @@ from kombu.utils.finalize import Finalize
 from celery import concurrency as _concurrency
 from celery import concurrency as _concurrency
 from celery import platforms
 from celery import platforms
 from celery import signals
 from celery import signals
-from celery.app import app_or_default, set_default_app
+from celery.app import app_or_default
 from celery.app.abstract import configurated, from_config
 from celery.app.abstract import configurated, from_config
 from celery.exceptions import (
 from celery.exceptions import (
     ImproperlyConfigured, SystemTerminate, TaskRevokedError,
     ImproperlyConfigured, SystemTerminate, TaskRevokedError,
 )
 )
-from celery.task import trace
+from celery.task.trace import setup_worker_optimizations
 from celery.utils import worker_direct
 from celery.utils import worker_direct
 from celery.utils.imports import qualname, reload_from_cwd
 from celery.utils.imports import qualname, reload_from_cwd
 from celery.utils.log import mlevel, worker_logger as logger
 from celery.utils.log import mlevel, worker_logger as logger
@@ -115,15 +115,7 @@ class WorkController(configurated):
 
 
     def __init__(self, app=None, hostname=None, **kwargs):
     def __init__(self, app=None, hostname=None, **kwargs):
         self.app = app_or_default(app or self.app)
         self.app = app_or_default(app or self.app)
-        # all new threads start without a current app, so if an app is not
-        # passed on to the thread it will fall back to the "default app",
-        # which then could be the wrong app.  So for the worker
-        # we set this to always return our app.  This is a hack,
-        # and means that only a single app can be used for workers
-        # running in the same process.
-        set_default_app(self.app)
-        self.app.finalize()
-        trace._tasks = self.app._tasks   # optimization
+        setup_worker_optimizations(self.app)
         self.hostname = hostname or socket.gethostname()
         self.hostname = hostname or socket.gethostname()
         self.on_before_init(**kwargs)
         self.on_before_init(**kwargs)