Browse Source

Merge branch '3.0'

Conflicts:
	celery/apps/worker.py
	celery/task/trace.py
	celery/tests/bin/test_celeryd.py
	celery/tests/tasks/test_chord.py
	celery/worker/__init__.py
Ask Solem 12 years ago
parent
commit
877b420ee2

+ 3 - 0
celery/apps/worker.py

@@ -102,6 +102,9 @@ class Worker(WorkController):
         self.redirect_stdouts_level = redirect_stdouts_level
 
     def on_start(self):
+        # apply task execution optimizations
+        trace.setup_worker_optimizations(self.app)
+
         # this signal can be used to e.g. change queues after
         # the -Q option has been applied.
         signals.celeryd_after_setup.send(sender=self.hostname, instance=self,

+ 75 - 67
celery/task/trace.py

@@ -25,7 +25,7 @@ from kombu.utils import kwdict
 
 from celery import current_app
 from celery import states, signals
-from celery._state import _task_stack, default_app
+from celery._state import _task_stack
 from celery.app import set_default_app
 from celery.app.task import Task as BaseTask, Context
 from celery.datastructures import ExceptionInfo
@@ -52,72 +52,6 @@ _tasks = None
 _patched = {}
 
 
-def setup_worker_optimizations(app):
-    global _tasks
-    global trace_task_ret
-
-    # make sure custom Task.__call__ methods that calls super
-    # will not mess up the request/task stack.
-    _install_stack_protection()
-
-    # all new threads start without a current app, so if an app is not
-    # passed on to the thread it will fall back to the "default app",
-    # which then could be the wrong app.  So for the worker
-    # we set this to always return our app.  This is a hack,
-    # and means that only a single app can be used for workers
-    # running in the same process.
-    set_default_app(app)
-
-    # evaluate all task classes by finalizing the app.
-    app.finalize()
-
-    # set fast shortcut to task registry
-    _tasks = app._tasks
-
-    trace_task_ret = _fast_trace_task
-
-
-def reset_worker_optimizations():
-    global trace_task_ret
-    trace_task_ret = trace_task
-    try:
-        delattr(BaseTask, '_stackprotected')
-    except AttributeError:
-        pass
-    try:
-        BaseTask.__call__ = _patched.pop('BaseTask.__call__')
-    except KeyError:
-        pass
-
-
-def _install_stack_protection():
-    # Patches BaseTask.__call__ in the worker to handle the edge case
-    # where people override it and also call super.
-    #
-    # - The worker optimizes away BaseTask.__call__ and instead
-    #   calls task.run directly.
-    # - so with the addition of current_task and the request stack
-    #   BaseTask.__call__ now pushes to those stacks so that
-    #   they work when tasks are called directly.
-    #
-    # The worker only optimizes away __call__ in the case
-    # where it has not been overridden, so the request/task stack
-    # will blow if a custom task class defines __call__ and also
-    # calls super().
-    if not getattr(BaseTask, '_stackprotected', False):
-        _patched['BaseTask.__call__'] = orig = BaseTask.__call__
-
-        def __protected_call__(self, *args, **kwargs):
-            req, stack = self.request, self.request_stack
-            if not req._protected and len(stack) == 2 and \
-                    not req.called_directly:
-                req._protected = 1
-                return self.run(*args, **kwargs)
-            return orig(self, *args, **kwargs)
-        BaseTask.__call__ = __protected_call__
-        BaseTask._stackprotected = True
-
-
 def mro_lookup(cls, attr, stop=(), monkey_patched=[]):
     """Returns the first node by MRO order that defines an attribute.
 
@@ -382,3 +316,77 @@ def report_internal_error(task, exc):
         return exc_info
     finally:
         del(_tb)
+
+
+def setup_worker_optimizations(app):
+    global _tasks
+    global trace_task_ret
+
+    # make sure custom Task.__call__ methods that calls super
+    # will not mess up the request/task stack.
+    _install_stack_protection()
+
+    # all new threads start without a current app, so if an app is not
+    # passed on to the thread it will fall back to the "default app",
+    # which then could be the wrong app.  So for the worker
+    # we set this to always return our app.  This is a hack,
+    # and means that only a single app can be used for workers
+    # running in the same process.
+    set_default_app(app)
+
+    # evaluate all task classes by finalizing the app.
+    app.finalize()
+
+    # set fast shortcut to task registry
+    _tasks = app._tasks
+
+    trace_task_ret = _fast_trace_task
+    try:
+        sys.modules['celery.worker.job'].trace_task_ret = _fast_trace_task
+    except KeyError:
+        pass
+
+
+def reset_worker_optimizations():
+    global trace_task_ret
+    trace_task_ret = _trace_task_ret
+    try:
+        delattr(BaseTask, '_stackprotected')
+    except AttributeError:
+        pass
+    try:
+        BaseTask.__call__ = _patched.pop('BaseTask.__call__')
+    except KeyError:
+        pass
+    try:
+        sys.modules['celery.worker.job'].trace_task_ret = _trace_task_ret
+    except KeyError:
+        pass
+
+
+def _install_stack_protection():
+    # Patches BaseTask.__call__ in the worker to handle the edge case
+    # where people override it and also call super.
+    #
+    # - The worker optimizes away BaseTask.__call__ and instead
+    #   calls task.run directly.
+    # - so with the addition of current_task and the request stack
+    #   BaseTask.__call__ now pushes to those stacks so that
+    #   they work when tasks are called directly.
+    #
+    # The worker only optimizes away __call__ in the case
+    # where it has not been overridden, so the request/task stack
+    # will blow if a custom task class defines __call__ and also
+    # calls super().
+    if not getattr(BaseTask, '_stackprotected', False):
+        _patched['BaseTask.__call__'] = orig = BaseTask.__call__
+
+        def __protected_call__(self, *args, **kwargs):
+            req, stack = self.request, self.request_stack
+            if not req._protected and len(stack) == 2 and \
+                    not req.called_directly:
+                req._protected = 1
+                return self.run(*args, **kwargs)
+            return orig(self, *args, **kwargs)
+        BaseTask.__call__ = __protected_call__
+        BaseTask._stackprotected = True

+ 0 - 2
celery/tests/bin/test_celeryd.py

@@ -33,7 +33,6 @@ from celery.tests.utils import (
 ensure_process_aware_logger()
 
 
-
 class WorkerAppCase(AppCase):
 
     def tearDown(self):
@@ -71,7 +70,6 @@ class Worker(cd.Worker):
 
 
 class test_Worker(WorkerAppCase):
-
     Worker = Worker
 
     def teardown(self):

+ 1 - 0
celery/tests/tasks/test_chord.py

@@ -132,6 +132,7 @@ class test_chord(AppCase):
             x = chord(add.s(i, i) for i in xrange(10))
             body = add.s(2)
             result = x(body)
+            self.assertTrue(result.id)
             # does not modify original subtask
             with self.assertRaises(KeyError):
                 body.options['task_id']

+ 31 - 5
celery/tests/worker/test_request.py

@@ -27,10 +27,12 @@ from celery.exceptions import (
 )
 from celery.task.trace import (
     trace_task,
-    trace_task_ret,
+    _trace_task_ret,
     TraceInfo,
     mro_lookup,
     build_tracer,
+    setup_worker_optimizations,
+    reset_worker_optimizations,
 )
 from celery.result import AsyncResult
 from celery.signals import task_revoked
@@ -41,7 +43,7 @@ from celery.worker import job as module
 from celery.worker.job import Request, TaskRequest
 from celery.worker.state import revoked
 
-from celery.tests.utils import Case, assert_signal_called
+from celery.tests.utils import AppCase, Case, assert_signal_called
 
 scratch = {'ACK': False}
 some_kwargs_scratchpad = {}
@@ -231,7 +233,7 @@ class MockEventDispatcher(object):
         self.sent.append(event)
 
 
-class test_TaskRequest(Case):
+class test_TaskRequest(AppCase):
 
     def test_task_wrapper_repr(self):
         tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
@@ -570,10 +572,34 @@ class test_TaskRequest(Case):
         finally:
             mytask.ignore_result = False
 
+    def test_fast_trace_task(self):
+        from celery.task import trace
+        setup_worker_optimizations(self.app)
+        self.assertIs(trace.trace_task_ret, trace._fast_trace_task)
+        try:
+            mytask.__trace__ = build_tracer(mytask.name, mytask,
+                                            self.app.loader, 'test')
+            res = trace.trace_task_ret(mytask.name, uuid(), [4], {})
+            self.assertEqual(res, 4 ** 4)
+        finally:
+            reset_worker_optimizations()
+            self.assertIs(trace.trace_task_ret, trace._trace_task_ret)
+        delattr(mytask, '__trace__')
+        res = trace.trace_task_ret(mytask.name, uuid(), [4], {})
+        self.assertEqual(res, 4 ** 4)
+
     def test_trace_task_ret(self):
         mytask.__trace__ = build_tracer(mytask.name, mytask,
-                                        current_app.loader, 'test')
-        res = trace_task_ret(mytask.name, uuid(), [4], {})
+                                        self.app.loader, 'test')
+        res = _trace_task_ret(mytask.name, uuid(), [4], {})
+        self.assertEqual(res, 4 ** 4)
+
+    def test_trace_task_ret__no_trace(self):
+        try:
+            delattr(mytask, '__trace__')
+        except AttributeError:
+            pass
+        res = _trace_task_ret(mytask.name, uuid(), [4], {})
         self.assertEqual(res, 4 ** 4)
 
     def test_execute_safe_catches_exception(self):