ソースを参照

Worker optimizations now set up in celery.apps.worker not WorkerController, and tests reset optimiations

Ask Solem 12 年 前
コミット
1514f01179

+ 4 - 0
celery/apps/worker.py

@@ -24,6 +24,7 @@ from billiard import current_process
 from celery import VERSION_BANNER, platforms, signals
 from celery.exceptions import SystemTerminate
 from celery.loaders.app import AppLoader
+from celery.task import trace
 from celery.utils import cry, isatty
 from celery.utils.imports import qualname
 from celery.utils.log import get_logger, in_sighandler, set_in_sighandler
@@ -82,6 +83,9 @@ class Worker(WorkController):
 
     def on_before_init(self, purge=False, redirect_stdouts=None,
             redirect_stdouts_level=None, **kwargs):
+        # apply task execution optimizations
+        trace.setup_worker_optimizations(self.app)
+
         # this signal can be used to set up configuration for
         # workers by name.
         conf = self.app.conf

+ 27 - 2
celery/task/trace.py

@@ -49,10 +49,12 @@ EXCEPTION_STATES = states.EXCEPTION_STATES
 
 #: set by :func:`setup_worker_optimizations`
 _tasks = None
+_patched = {}
 
 
 def setup_worker_optimizations(app):
     global _tasks
+    global trace_task_ret
 
     # make sure custom Task.__call__ methods that calls super
     # will not mess up the request/task stack.
@@ -72,6 +74,21 @@ def setup_worker_optimizations(app):
     # set fast shortcut to task registry
     _tasks = app._tasks
 
+    trace_task_ret = _fast_trace_task
+
+
+def reset_worker_optimizations():
+    global trace_task_ret
+    trace_task_ret = trace_task
+    try:
+        delattr(BaseTask, '_stackprotected')
+    except AttributeError:
+        pass
+    try:
+        BaseTask.__call__ = _patched.pop('BaseTask.__call__')
+    except KeyError:
+        pass
+
 
 def _install_stack_protection():
     # Patches BaseTask.__call__ in the worker to handle the edge case
@@ -88,7 +105,7 @@ def _install_stack_protection():
     # will blow if a custom task class defines __call__ and also
     # calls super().
     if not getattr(BaseTask, '_stackprotected', False):
-        orig = BaseTask.__call__
+        _patched['BaseTask.__call__'] = orig = BaseTask.__call__
 
         def __protected_call__(self, *args, **kwargs):
             req, stack = self.request, self.request_stack
@@ -336,7 +353,15 @@ def trace_task(task, uuid, args, kwargs, request={}, **opts):
         return report_internal_error(task, exc)
 
 
-def trace_task_ret(task, uuid, args, kwargs, request={}):
+def _trace_task_ret(name, uuid, args, kwargs, request={}, **opts):
+    return trace_task(current_app.tasks[name],
+                      uuid, args, kwargs, request, **opts)
+trace_task_ret = _trace_task_ret
+
+
+def _fast_trace_task(task, uuid, args, kwargs, request={}):
+    # setup_worker_optimizations will point trace_task_ret to here,
+    # so this is the function used in the worker.
     return _tasks[task].__trace__(uuid, args, kwargs, request)[0]
 
 

+ 12 - 3
celery/tests/bin/test_celeryd.py

@@ -19,6 +19,7 @@ from celery import current_app
 from celery.apps import worker as cd
 from celery.bin.celeryd import WorkerCommand, main as celeryd_main
 from celery.exceptions import ImproperlyConfigured, SystemTerminate
+from celery.task import trace
 from celery.utils.log import ensure_process_aware_logger
 from celery.worker import state
 
@@ -32,6 +33,14 @@ from celery.tests.utils import (
 ensure_process_aware_logger()
 
 
+
+class WorkerAppCase(AppCase):
+
+    def tearDown(self):
+        super(WorkerAppCase, self).tearDown()
+        trace.reset_worker_optimizations()
+
+
 def disable_stdouts(fun):
 
     @wraps(fun)
@@ -61,7 +70,7 @@ class Worker(cd.Worker):
         self.on_start()
 
 
-class test_Worker(AppCase):
+class test_Worker(WorkerAppCase):
 
     Worker = Worker
 
@@ -369,7 +378,7 @@ class test_Worker(AppCase):
         self.assertTrue(worker_ready_sent[0])
 
 
-class test_funs(AppCase):
+class test_funs(WorkerAppCase):
 
     def test_active_thread_count(self):
         self.assertTrue(cd.active_thread_count())
@@ -417,7 +426,7 @@ class test_funs(AppCase):
             sys.argv = s
 
 
-class test_signal_handlers(AppCase):
+class test_signal_handlers(WorkerAppCase):
 
     class _Worker(object):
         stopped = False

+ 0 - 2
celery/worker/__init__.py

@@ -29,7 +29,6 @@ from celery.app.abstract import configurated, from_config
 from celery.exceptions import (
     ImproperlyConfigured, SystemTerminate, TaskRevokedError,
 )
-from celery.task.trace import setup_worker_optimizations
 from celery.utils import worker_direct
 from celery.utils.imports import qualname, reload_from_cwd
 from celery.utils.log import mlevel, worker_logger as logger
@@ -115,7 +114,6 @@ class WorkController(configurated):
 
     def __init__(self, app=None, hostname=None, **kwargs):
         self.app = app_or_default(app or self.app)
-        setup_worker_optimizations(self.app)
         self.hostname = hostname or socket.gethostname()
         self.on_before_init(**kwargs)