소스 검색

make set_in_sighandler a contextmanager

Ask Solem 12 년 전
부모
커밋
581daa11cb
2개의 변경된 파일18개의 추가작업 그리고 22개의 파일을 삭제
  1. 5 17
      celery/apps/worker.py
  2. 13 5
      celery/utils/log.py

+ 5 - 17
celery/apps/worker.py

@@ -26,7 +26,7 @@ from celery.exceptions import SystemTerminate
 from celery.loaders.app import AppLoader
 from celery.loaders.app import AppLoader
 from celery.utils import cry, isatty
 from celery.utils import cry, isatty
 from celery.utils.imports import qualname
 from celery.utils.imports import qualname
-from celery.utils.log import get_logger, set_in_sighandler
+from celery.utils.log import get_logger, in_sighandler, set_in_sighandler
 from celery.utils.text import pluralize
 from celery.utils.text import pluralize
 from celery.worker import WorkController
 from celery.worker import WorkController
 
 
@@ -215,8 +215,7 @@ def _shutdown_handler(worker, sig='TERM', how='Warm', exc=SystemExit,
         callback=None):
         callback=None):
 
 
     def _handle_request(signum, frame):
     def _handle_request(signum, frame):
-        set_in_sighandler(True)
-        try:
+        with in_sighandler():
             from celery.worker import state
             from celery.worker import state
             if current_process()._name == 'MainProcess':
             if current_process()._name == 'MainProcess':
                 if callback:
                 if callback:
@@ -227,8 +226,6 @@ def _shutdown_handler(worker, sig='TERM', how='Warm', exc=SystemExit,
                                 'Cold': 'should_terminate'}[how], True)
                                 'Cold': 'should_terminate'}[how], True)
             else:
             else:
                 raise exc()
                 raise exc()
-        finally:
-            set_in_sighandler(False)
     _handle_request.__name__ = 'worker_' + how
     _handle_request.__name__ = 'worker_' + how
     platforms.signals[sig] = _handle_request
     platforms.signals[sig] = _handle_request
 install_worker_term_handler = partial(
 install_worker_term_handler = partial(
@@ -274,11 +271,8 @@ def install_cry_handler():
 
 
     def cry_handler(signum, frame):
     def cry_handler(signum, frame):
         """Signal handler logging the stacktrace of all active threads."""
         """Signal handler logging the stacktrace of all active threads."""
-        set_in_sighandler(True)
-        try:
+        with in_sighandler():
             safe_say(cry())
             safe_say(cry())
-        finally:
-            set_in_sighandler(False)
     platforms.signals['SIGUSR1'] = cry_handler
     platforms.signals['SIGUSR1'] = cry_handler
 
 
 
 
@@ -287,12 +281,9 @@ def install_rdb_handler(envvar='CELERY_RDBSIG',
 
 
     def rdb_handler(signum, frame):
     def rdb_handler(signum, frame):
         """Signal handler setting a rdb breakpoint at the current frame."""
         """Signal handler setting a rdb breakpoint at the current frame."""
-        set_in_sighandler(True)
-        try:
+        with in_sighandler():
             from celery.contrib import rdb
             from celery.contrib import rdb
             rdb.set_trace(frame)
             rdb.set_trace(frame)
-        finally:
-            set_in_sighandler(False)
     if os.environ.get(envvar):
     if os.environ.get(envvar):
         platforms.signals[sig] = rdb_handler
         platforms.signals[sig] = rdb_handler
 
 
@@ -300,10 +291,7 @@ def install_rdb_handler(envvar='CELERY_RDBSIG',
 def install_HUP_not_supported_handler(worker, sig='SIGHUP'):
 def install_HUP_not_supported_handler(worker, sig='SIGHUP'):
 
 
     def warn_on_HUP_handler(signum, frame):
     def warn_on_HUP_handler(signum, frame):
-        set_in_sighandler(True)
-        try:
+        with in_sighandler():
             safe_say('{sig} not supported: Restarting with {sig} is '
             safe_say('{sig} not supported: Restarting with {sig} is '
                      'unstable on this platform!'.format(sig=sig))
                      'unstable on this platform!'.format(sig=sig))
-        finally:
-            set_in_sighandler(False)
     platforms.signals[sig] = warn_on_HUP_handler
     platforms.signals[sig] = warn_on_HUP_handler

+ 13 - 5
celery/utils/log.py

@@ -14,6 +14,7 @@ import sys
 import threading
 import threading
 import traceback
 import traceback
 
 
+from contextlib import contextmanager
 from billiard import current_process, util as mputil
 from billiard import current_process, util as mputil
 from kombu.log import get_logger as _get_logger, LOG_LEVELS
 from kombu.log import get_logger as _get_logger, LOG_LEVELS
 
 
@@ -34,12 +35,19 @@ MP_LOG = os.environ.get('MP_LOG', False)
 base_logger = logger = _get_logger('celery')
 base_logger = logger = _get_logger('celery')
 mp_logger = _get_logger('multiprocessing')
 mp_logger = _get_logger('multiprocessing')
 
 
-in_sighandler = False
+_in_sighandler = False
 
 
 
 
 def set_in_sighandler(value):
 def set_in_sighandler(value):
-    global in_sighandler
-    in_sighandler = value
+    global _in_sighandler
+    _in_sighandler = value
+
+
+@contextmanager
+def in_sighandler():
+    set_in_sighandler(True)
+    yield
+    set_in_sighandler(False)
 
 
 
 
 def get_logger(name):
 def get_logger(name):
@@ -146,7 +154,7 @@ class LoggingProxy(object):
 
 
     def write(self, data):
     def write(self, data):
         """Write message to logging object."""
         """Write message to logging object."""
-        if in_sighandler:
+        if _in_sighandler:
             print(safe_str(data), file=sys.__stderr__)
             print(safe_str(data), file=sys.__stderr__)
         if getattr(self._thread, 'recurse_protection', False):
         if getattr(self._thread, 'recurse_protection', False):
             # Logger is logging back to this file, so stop recursing.
             # Logger is logging back to this file, so stop recursing.
@@ -233,7 +241,7 @@ def _patch_logger_class():
                 _signal_safe = True
                 _signal_safe = True
 
 
                 def log(self, *args, **kwargs):
                 def log(self, *args, **kwargs):
-                    if in_sighandler:
+                    if _in_sighandler:
                         print('CANNOT LOG IN SIGHANDLER',  # noqa
                         print('CANNOT LOG IN SIGHANDLER',  # noqa
                                 file=sys.__stderr__)
                                 file=sys.__stderr__)
                         return
                         return