Procházet zdrojové kódy

Make sure BaseTask.__call__ does not modify the stack even if an overriding __call__ calls super() (as discussed in #947)

Ask Solem před 12 roky
rodič
revize
389caf247f

+ 1 - 0
celery/app/task.py

@@ -56,6 +56,7 @@ class Context(object):
     errbacks = None
     timeouts = None
     _children = None   # see property
+    _protected = 0
 
     def __init__(self, *args, **kwargs):
         self.update(*args, **kwargs)

+ 2 - 1
celery/canvas.py

@@ -354,10 +354,11 @@ class chord(Signature):
     def __call__(self, body=None, **kwargs):
         _chord = self.Chord
         body = (body or self.kwargs['body']).clone()
+        kwargs = dict(self.kwargs, body=body, **kwargs)
         if _chord.app.conf.CELERY_ALWAYS_EAGER:
             return self.apply((), kwargs)
         callback_id = body.options.setdefault('task_id', uuid())
-        _chord(**dict(self.kwargs, body=body, **kwargs))
+        _chord(**kwargs)
         return _chord.AsyncResult(callback_id)
 
     def clone(self, *args, **kwargs):

+ 53 - 5
celery/task/trace.py

@@ -26,6 +26,7 @@ from kombu.utils import kwdict
 from celery import current_app
 from celery import states, signals
 from celery._state import _task_stack, default_app
+from celery.app import set_default_app
 from celery.app.task import Task as BaseTask, Context
 from celery.datastructures import ExceptionInfo
 from celery.exceptions import RetryTaskError
@@ -46,11 +47,58 @@ RETRY = states.RETRY
 FAILURE = states.FAILURE
 EXCEPTION_STATES = states.EXCEPTION_STATES
 
-try:
-    _tasks = default_app._tasks
-except AttributeError:
-    # Windows: will be set later by concurrency.processes.
-    pass
+#: set by :func:`setup_worker_optimizations`
+_tasks = None
+
+
+def setup_worker_optimizations(app):
+    global _tasks
+
+    # make sure custom Task.__call__ methods that call super
+    # will not mess up the request/task stack.
+    _install_stack_protection()
+
+    # all new threads start without a current app, so if an app is not
+    # passed on to the thread it will fall back to the "default app",
+    # which then could be the wrong app.  So for the worker
+    # we set this to always return our app.  This is a hack,
+    # and means that only a single app can be used for workers
+    # running in the same process.
+    set_default_app(app)
+
+    # evaluate all task classes by finalizing the app.
+    app.finalize()
+
+    # set fast shortcut to task registry
+    _tasks = app._tasks
+
+
+def _install_stack_protection():
+    # Patches BaseTask.__call__ in the worker to handle the edge case
+    # where people override it and also call super.
+    #
+    # - The worker optimizes away BaseTask.__call__ and instead
+    #   calls task.run directly.
+    # - so with the addition of current_task and the request stack
+    #   BaseTask.__call__ now pushes to those stacks so that
+    #   they work when tasks are called directly.
+    #
+    # The worker only optimizes away __call__ in the case
+    # where it has not been overridden, so the request/task stack
+    # will blow up if a custom task class defines __call__ and also
+    # calls super().
+    if not getattr(BaseTask, '_stackprotected', False):
+        orig = BaseTask.__call__
+
+        def __protected_call__(self, *args, **kwargs):
+            req, stack = self.request, self.request_stack
+            if not req._protected and len(stack) == 2 and \
+                    not req.called_directly:
+                req._protected = 1
+                return self.run(*args, **kwargs)
+            return orig(self, *args, **kwargs)
+        BaseTask.__call__ = __protected_call__
+        BaseTask._stackprotected = True
 
 
 def mro_lookup(cls, attr, stop=(), monkey_patched=[]):

+ 3 - 1
celery/tests/tasks/test_chord.py

@@ -132,7 +132,9 @@ class test_chord(AppCase):
             x = chord(add.s(i, i) for i in xrange(10))
             body = add.s(2)
             result = x(body)
-            self.assertEqual(result.id, body.options['task_id'])
+            # does not modify original subtask
+            with self.assertRaises(KeyError):
+                body.options['task_id']
             self.assertTrue(chord.Chord.called)
         finally:
             chord.Chord = prev

+ 7 - 0
celery/utils/threads.py

@@ -213,6 +213,10 @@ class _LocalStack(object):
         else:
             return stack.pop()
 
+    def __len__(self):
+        stack = getattr(self._local, 'stack', None)
+        return len(stack) if stack else 0
+
     @property
     def stack(self):
         """get_current_worker_task uses this to find
@@ -294,6 +298,9 @@ class _FastLocalStack(threading.local):
         except (AttributeError, IndexError):
             return None
 
+    def __len__(self):
+        return len(self.stack)
+
 if USE_FAST_LOCALS:
     LocalStack = _FastLocalStack
 else:

+ 3 - 11
celery/worker/__init__.py

@@ -24,12 +24,12 @@ from kombu.utils.finalize import Finalize
 from celery import concurrency as _concurrency
 from celery import platforms
 from celery import signals
-from celery.app import app_or_default, set_default_app
+from celery.app import app_or_default
 from celery.app.abstract import configurated, from_config
 from celery.exceptions import (
     ImproperlyConfigured, SystemTerminate, TaskRevokedError,
 )
-from celery.task import trace
+from celery.task.trace import setup_worker_optimizations
 from celery.utils import worker_direct
 from celery.utils.imports import qualname, reload_from_cwd
 from celery.utils.log import mlevel, worker_logger as logger
@@ -115,15 +115,7 @@ class WorkController(configurated):
 
     def __init__(self, app=None, hostname=None, **kwargs):
         self.app = app_or_default(app or self.app)
-        # all new threads start without a current app, so if an app is not
-        # passed on to the thread it will fall back to the "default app",
-        # which then could be the wrong app.  So for the worker
-        # we set this to always return our app.  This is a hack,
-        # and means that only a single app can be used for workers
-        # running in the same process.
-        set_default_app(self.app)
-        self.app.finalize()
-        trace._tasks = self.app._tasks   # optimization
+        setup_worker_optimizations(self.app)
         self.hostname = hostname or socket.gethostname()
         self.on_before_init(**kwargs)