Browse Source

make set_in_sighandler a contextmanager

Ask Solem 12 years ago
parent
commit
581daa11cb
2 changed files with 18 additions and 22 deletions
  1. 5 17
      celery/apps/worker.py
  2. 13 5
      celery/utils/log.py

+ 5 - 17
celery/apps/worker.py

@@ -26,7 +26,7 @@ from celery.exceptions import SystemTerminate
 from celery.loaders.app import AppLoader
 from celery.utils import cry, isatty
 from celery.utils.imports import qualname
-from celery.utils.log import get_logger, set_in_sighandler
+from celery.utils.log import get_logger, in_sighandler, set_in_sighandler
 from celery.utils.text import pluralize
 from celery.worker import WorkController
 
@@ -215,8 +215,7 @@ def _shutdown_handler(worker, sig='TERM', how='Warm', exc=SystemExit,
         callback=None):
 
     def _handle_request(signum, frame):
-        set_in_sighandler(True)
-        try:
+        with in_sighandler():
             from celery.worker import state
             if current_process()._name == 'MainProcess':
                 if callback:
@@ -227,8 +226,6 @@ def _shutdown_handler(worker, sig='TERM', how='Warm', exc=SystemExit,
                                 'Cold': 'should_terminate'}[how], True)
             else:
                 raise exc()
-        finally:
-            set_in_sighandler(False)
     _handle_request.__name__ = 'worker_' + how
     platforms.signals[sig] = _handle_request
 install_worker_term_handler = partial(
@@ -274,11 +271,8 @@ def install_cry_handler():
 
     def cry_handler(signum, frame):
         """Signal handler logging the stacktrace of all active threads."""
-        set_in_sighandler(True)
-        try:
+        with in_sighandler():
             safe_say(cry())
-        finally:
-            set_in_sighandler(False)
     platforms.signals['SIGUSR1'] = cry_handler
 
 
@@ -287,12 +281,9 @@ def install_rdb_handler(envvar='CELERY_RDBSIG',
 
     def rdb_handler(signum, frame):
         """Signal handler setting a rdb breakpoint at the current frame."""
-        set_in_sighandler(True)
-        try:
+        with in_sighandler():
             from celery.contrib import rdb
             rdb.set_trace(frame)
-        finally:
-            set_in_sighandler(False)
     if os.environ.get(envvar):
         platforms.signals[sig] = rdb_handler
 
@@ -300,10 +291,7 @@ def install_rdb_handler(envvar='CELERY_RDBSIG',
 def install_HUP_not_supported_handler(worker, sig='SIGHUP'):
 
     def warn_on_HUP_handler(signum, frame):
-        set_in_sighandler(True)
-        try:
+        with in_sighandler():
             safe_say('{sig} not supported: Restarting with {sig} is '
                      'unstable on this platform!'.format(sig=sig))
-        finally:
-            set_in_sighandler(False)
     platforms.signals[sig] = warn_on_HUP_handler

+ 13 - 5
celery/utils/log.py

@@ -14,6 +14,7 @@ import sys
 import threading
 import traceback
 
+from contextlib import contextmanager
 from billiard import current_process, util as mputil
 from kombu.log import get_logger as _get_logger, LOG_LEVELS
 
@@ -34,12 +35,19 @@ MP_LOG = os.environ.get('MP_LOG', False)
 base_logger = logger = _get_logger('celery')
 mp_logger = _get_logger('multiprocessing')
 
-in_sighandler = False
+_in_sighandler = False
 
 
 def set_in_sighandler(value):
-    global in_sighandler
-    in_sighandler = value
+    global _in_sighandler
+    _in_sighandler = value
+
+
+@contextmanager
+def in_sighandler():
+    set_in_sighandler(True)
+    yield
+    set_in_sighandler(False)
 
 
 def get_logger(name):
@@ -146,7 +154,7 @@ class LoggingProxy(object):
 
     def write(self, data):
         """Write message to logging object."""
-        if in_sighandler:
+        if _in_sighandler:
             print(safe_str(data), file=sys.__stderr__)
         if getattr(self._thread, 'recurse_protection', False):
             # Logger is logging back to this file, so stop recursing.
@@ -233,7 +241,7 @@ def _patch_logger_class():
                 _signal_safe = True
 
                 def log(self, *args, **kwargs):
-                    if in_sighandler:
+                    if _in_sighandler:
                         print('CANNOT LOG IN SIGHANDLER',  # noqa
                                 file=sys.__stderr__)
                         return