Browse Source

Signal handlers are now barebones, and only sets a flag for the consumer to respect

Ask Solem 13 years ago
parent
commit
9059724ed8
3 changed files with 83 additions and 22 deletions
  1. 47 20
      celery/apps/worker.py
  2. 33 2
      celery/utils/log.py
  3. 3 0
      celery/worker/state.py

+ 47 - 20
celery/apps/worker.py

@@ -17,7 +17,7 @@ from celery.app.abstract import configurated, from_config
 from celery.exceptions import ImproperlyConfigured, SystemTerminate
 from celery.utils import cry, isatty
 from celery.utils.imports import qualname
-from celery.utils.log import LOG_LEVELS, get_logger, mlevel
+from celery.utils.log import LOG_LEVELS, get_logger, mlevel, set_in_sighandler
 from celery.utils.text import pluralize
 from celery.worker import WorkController
 
@@ -30,6 +30,10 @@ except ImportError:  # pragma: no cover
 logger = get_logger(__name__)
 
 
+def safe_say(msg):
+    sys.__stderr__.write(msg + "\n")
+
+
 BANNER = """
  -------------- celery@%(hostname)s v%(version)s
 ---- **** -----
@@ -226,28 +230,35 @@ class Worker(configurated):
                 hostname=self.hostname)
 
 
-def _shutdown_handler(worker, sig="TERM", how="stop", exc=SystemExit,
-        callback=None, types={"terminate": "Cold", "stop": "Warm"}):
+def _shutdown_handler(worker, sig="TERM", how="Warm", exc=SystemExit,
+        callback=None):
 
     def _handle_request(signum, frame):
-        if current_process()._name == "MainProcess":
-            if callback:
-                callback(worker)
-            print("celeryd: %s shutdown (MainProcess)" % types[how])
-            getattr(worker, how)(in_sighandler=True)
-        raise exc()
+        set_in_sighandler(True)
+        try:
+            from celery.worker import state
+            if current_process()._name == "MainProcess":
+                if callback:
+                    callback(worker)
+                safe_say("celeryd: %s shutdown (MainProcess)" % how)
+            if how == "Warm":
+                state.should_stop = True
+            elif how == "Cold":
+                state.should_terminate = True
+        finally:
+            set_in_sighandler(False)
     _handle_request.__name__ = "worker_" + how
     platforms.signals[sig] = _handle_request
 install_worker_term_handler = partial(
-    _shutdown_handler, sig="SIGTERM", how="stop", exc=SystemExit,
+    _shutdown_handler, sig="SIGTERM", how="Warm", exc=SystemExit,
 )
 install_worker_term_hard_handler = partial(
-    _shutdown_handler, sig="SIGQUIT", how="terminate", exc=SystemTerminate,
+    _shutdown_handler, sig="SIGQUIT", how="Cold", exc=SystemTerminate,
 )
 
 
 def on_SIGINT(worker):
-    print("celeryd: Hitting Ctrl+C again will terminate all running tasks!")
+    safe_say("celeryd: Hitting Ctrl+C again will terminate all running tasks!")
     install_worker_term_hard_handler(worker, sig="SIGINT")
 install_worker_int_handler = partial(
     _shutdown_handler, sig="SIGINT", callback=on_SIGINT
@@ -258,9 +269,13 @@ def install_worker_restart_handler(worker, sig="SIGHUP"):
 
     def restart_worker_sig_handler(signum, frame):
         """Signal handler restarting the current python program."""
-        print("Restarting celeryd (%s)" % (" ".join(sys.argv), ))
-        worker.stop(in_sighandler=True)
-        os.execv(sys.executable, [sys.executable] + sys.argv)
+        set_in_sighandler(True)
+        safe_say("Restarting celeryd (%s)" % (" ".join(sys.argv), ))
+        pid = os.fork()
+        if pid == 0:
+            os.execv(sys.executable, [sys.executable] + sys.argv)
+        from celery.worker import state
+        state.should_stop = True
     platforms.signals[sig] = restart_worker_sig_handler
 
 
@@ -273,7 +288,11 @@ def install_cry_handler():
 
     def cry_handler(signum, frame):
         """Signal handler logging the stacktrace of all active threads."""
-        logger.error("\n" + cry())
+        set_in_sighandler(True)
+        try:
+            safe_say("\n" + cry())
+        finally:
+            set_in_sighandler(False)
     platforms.signals["SIGUSR1"] = cry_handler
 
 
@@ -282,8 +301,12 @@ def install_rdb_handler(envvar="CELERY_RDBSIG",
 
     def rdb_handler(signum, frame):
         """Signal handler setting a rdb breakpoint at the current frame."""
-        from celery.contrib import rdb
-        rdb.set_trace(frame)
+        set_in_sighandler(True)
+        try:
+            from celery.contrib import rdb
+            rdb.set_trace(frame)
+        finally:
+            set_in_sighandler(False)
     if os.environ.get(envvar):
         platforms.signals[sig] = rdb_handler
 
@@ -291,6 +314,10 @@ def install_rdb_handler(envvar="CELERY_RDBSIG",
 def install_HUP_not_supported_handler(worker, sig="SIGHUP"):
 
     def warn_on_HUP_handler(signum, frame):
-        logger.error("%(sig)s not supported: Restarting with %(sig)s is "
-            "unstable on this platform!" % {"sig": sig})
+        set_in_sighandler(True)
+        try:
+            safe_say("%(sig)s not supported: Restarting with %(sig)s is "
+                     "unstable on this platform!" % {"sig": sig})
+        finally:
+            set_in_sighandler(False)
     platforms.signals[sig] = warn_on_HUP_handler

+ 33 - 2
celery/utils/log.py

@@ -24,6 +24,13 @@ is_py3k = sys.version_info[0] == 3
 base_logger = logger = _get_logger("celery")
 mp_logger = _get_logger("multiprocessing")
 
+in_sighandler = False
+
+
+def set_in_sighandler(value):
+    global in_sighandler
+    in_sighandler = value
+
 
 def get_logger(name):
     l = _get_logger(name)
@@ -121,10 +128,12 @@ class LoggingProxy(object):
         return map(wrap_handler, self.logger.handlers)
 
     def write(self, data):
+        """Write message to logging object."""
+        if in_sighandler:
+            return sys.__stderr__.write(safe_str(data))
         if getattr(self._thread, "recurse_protection", False):
             # Logger is logging back to this file, so stop recursing.
             return
-        """Write message to logging object."""
         data = data.strip()
         if data and not self.closed:
             self._thread.recurse_protection = True
@@ -187,9 +196,31 @@ def ensure_process_aware_logger():
 
 
 def get_multiprocessing_logger():
-    return mputil.get_logger() if mputil else None
+    return None; #mputil.get_logger() if mputil else None
 
 
 def reset_multiprocessing_logger():
     if mputil and hasattr(mputil, "_logger"):
         mputil._logger = None
+
+
+def _patch_logger_class():
+    """Make sure loggers don't log while in a signal handler."""
+
+    logging._acquireLock()
+    try:
+        OldLoggerClass = logging.getLoggerClass()
+        if not getattr(OldLoggerClass, '_signal_safe', False):
+
+            class SigSafeLogger(OldLoggerClass):
+                _signal_safe = True
+
+                def log(self, *args, **kwargs):
+                    if in_sighandler:
+                        sys.__stderr__.write("IN SIGHANDLER WON'T LOG")
+                        return
+                    return OldLoggerClass.log(self, *args, **kwargs)
+            logging.setLoggerClass(SigSafeLogger)
+    finally:
+        logging._releaseLock()
+_patch_logger_class()

+ 3 - 0
celery/worker/state.py

@@ -51,6 +51,9 @@ revoked = LimitedSet(maxlen=REVOKES_MAX, expires=REVOKE_EXPIRES)
 #: Updates global state when a task has been reserved.
 task_reserved = reserved_requests.add
 
+should_stop = False
+should_terminate = False
+
 
 def task_accepted(request):
     """Updates global state when a task has been accepted."""