Browse Source

platforms.signals: New interface for managing process signals.

* install_signal_handler("USR1", handler) -> signals["USR1"] = handler
* install_signal_handler(USR1=handler)    -> signals.update(USR1=handler)
* ignore_signal("USR1", "USR2")           -> signals.ignore("USR1", "USR2")
* reset_signal("USR1", "USR2")            -> signals.reset("USR1", "USR2")
* get_signal("USR1")                      -> signals.signum("USR1")
Ask Solem 14 years ago
parent
commit
f0962a90db

+ 1 - 1
celery/apps/beat.py

@@ -122,4 +122,4 @@ class Beat(object):
             beat.sync()
             raise SystemExit()
 
-        platforms.install_signal_handler(SIGTERM=_sync, SIGINT=_sync)
+        platforms.signals.update(SIGTERM=_sync, SIGINT=_sync)

+ 7 - 7
celery/apps/worker.py

@@ -286,7 +286,7 @@ def install_worker_int_handler(worker):
             worker.stop(in_sighandler=True)
         raise SystemExit()
 
-    platforms.install_signal_handler(SIGINT=_stop)
+    platforms.signals["SIGINT"] = _stop
 
 
 def install_worker_int_again_handler(worker):
@@ -301,7 +301,7 @@ def install_worker_int_again_handler(worker):
             worker.terminate(in_sighandler=True)
         raise SystemTerminate()
 
-    platforms.install_signal_handler(SIGINT=_stop)
+    platforms.signals["SIGINT"] =_stop
 
 
 def install_worker_term_handler(worker):
@@ -316,7 +316,7 @@ def install_worker_term_handler(worker):
             worker.stop(in_sighandler=True)
         raise SystemExit()
 
-    platforms.install_signal_handler(SIGTERM=_stop)
+    platforms.signals["SIGTERM"] = _stop
 
 
 def install_worker_restart_handler(worker):
@@ -328,7 +328,7 @@ def install_worker_restart_handler(worker):
         worker.stop(in_sighandler=True)
         os.execv(sys.executable, [sys.executable] + sys.argv)
 
-    platforms.install_signal_handler(SIGHUP=restart_worker_sig_handler)
+    platforms.signals["SIGHUP"] = restart_worker_sig_handler
 
 
 def install_cry_handler(logger):
@@ -341,7 +341,7 @@ def install_cry_handler(logger):
             """Signal handler logging the stacktrace of all active threads."""
             logger.error("\n" + cry())
 
-        platforms.install_signal_handler(SIGUSR1=cry_handler)
+        platforms.signals["SIGUSR1"] = cry_handler
 
 
 def install_rdb_handler():  # pragma: no cover
@@ -352,7 +352,7 @@ def install_rdb_handler():  # pragma: no cover
         rdb.set_trace(frame)
 
     if os.environ.get("CELERY_RDBSIG"):
-        platforms.install_signal_handler(SIGUSR2=rdb_handler)
+        platforms.signals["SIGUSR2"] = rdb_handler
 
 
 def install_HUP_not_supported_handler(worker):
@@ -361,4 +361,4 @@ def install_HUP_not_supported_handler(worker):
         worker.logger.error("SIGHUP not supported: "
             "Restarting with HUP is unstable on this platform!")
 
-    platforms.install_signal_handler(SIGHUP=warn_on_HUP_handler)
+    platforms.signals["SIGHUP"] = warn_on_HUP_handler

+ 0 - 2
celery/backends/base.py

@@ -19,8 +19,6 @@ class BaseBackend(object):
 
     TimeoutError = TimeoutError
 
-    can_get_many = False
-
     def __init__(self, *args, **kwargs):
         from celery.app import app_or_default
         self.app = app_or_default(kwargs.get("app"))

+ 1 - 1
celery/beat.py

@@ -420,7 +420,7 @@ if multiprocessing is not None:
             self.name = "Beat"
 
         def run(self):
-            platforms.reset_signal("SIGTERM")
+            platforms.signals.reset("SIGTERM")
             self.service.start(embedded_process=True)
 
         def stop(self):

+ 52 - 41
celery/platforms.py

@@ -3,7 +3,7 @@ from __future__ import absolute_import
 import os
 import sys
 import errno
-import signal
+import signal as _signal
 
 from celery.local import try_import
 
@@ -172,7 +172,7 @@ def detached(logfile=None, pidfile=None, uid=None, gid=None, umask=0,
         raise RuntimeError("This platform does not support detach.")
     workdir = os.getcwd() if workdir is None else workdir
 
-    reset_signal("SIGCLD")  # Make sure SIGCLD is using the default handler.
+    signals.reset("SIGCLD")  # Make sure SIGCLD is using the default handler.
     set_effective_user(uid=uid, gid=gid)
 
     # Since without stderr any errors will be silently suppressed,
@@ -258,61 +258,72 @@ def set_effective_user(uid=None, gid=None):
         gid and setegid(gid)
 
 
-def get_signal(signal_name):
-    """Get signal number from signal name."""
-    if not isinstance(signal_name, basestring) or not signal_name.isupper():
-        raise TypeError("signal name must be uppercase string.")
-    if not signal_name.startswith("SIG"):
-        signal_name = "SIG" + signal_name
-    return getattr(signal, signal_name)
+class Signals(object):
+    ignored = _signal.SIG_IGN
+    default = _signal.SIG_DFL
 
+    def supported(self, signal_name):
+        """Returns true value if ``signal_name`` exists on this platform."""
+        try:
+            return self.signum(signal_name)
+        except AttributeError:
+            pass
 
-def reset_signal(*signal_names):
-    """Reset signal to the default signal handler.
+    def signum(self, signal_name):
+        """Get signal number from signal name."""
+        if isinstance(signal_name, int):
+            return signal_name
+        if not isinstance(signal_name, basestring) or not signal_name.isupper():
+            raise TypeError("signal name must be uppercase string.")
+        if not signal_name.startswith("SIG"):
+            signal_name = "SIG" + signal_name
+        return getattr(_signal, signal_name)
 
-    Does nothing if the platform doesn't support signals,
-    or the specified signal in particular.
+    def reset(self, *signal_names):
+        """Reset signals to the default signal handler.
 
-    """
-    for signal_name in signal_names:
-        try:
-            signum = getattr(signal, signal_name)
-            signal.signal(signum, signal.SIG_DFL)
-        except (AttributeError, ValueError):
-            pass
+        Does nothing if the platform doesn't support signals,
+        or the specified signal in particular.
 
+        """
+        self.update((sig, self.default) for sig in signal_names)
 
-def ignore_signal(*signal_names):
-    """Ignore signal using :const:`SIG_IGN`.
+    def ignore(self, *signal_names):
+        """Ignore signal using :const:`SIG_IGN`.
 
-    Does nothing if the platform doesn't support signals,
-    or the specified signal in particular.
+        Does nothing if the platform doesn't support signals,
+        or the specified signal in particular.
 
-    """
-    for signal_name in signal_names:
-        try:
-            signum = getattr(signal, signal_name)
-            signal.signal(signum, signal.SIG_IGN)
-        except (AttributeError, ValueError):
-            pass
+        """
+        self.update((sig, self.ignored) for sig in signal_names)
 
+    def __getitem__(self, signal_name):
+        return _signal.getsignal(self.signum(signal_name))
 
-def install_signal_handler(signal_name=None, handler=None, **sigmap):
-    """Install signal handlers.
+    def __setitem__(self, signal_name, handler):
+        """Install signal handler.
 
-    Does nothing if the current platform doesn't support signals,
-    or the specified signal in particular.
+        Does nothing if the current platform doesn't support signals,
+        or the specified signal in particular.
 
-    """
-    if signal_name:
-        sigmap[signal_name] = handler
-    for signal_name, handler in sigmap.iteritems():
+        """
         try:
-            signum = getattr(signal, signal_name)
-            signal.signal(signum, handler)
+            _signal.signal(self.signum(signal_name), handler)
         except (AttributeError, ValueError):
             pass
 
+    def update(self, _d_=None, **sigmap):
+        """Set signal handlers from a mapping."""
+        for signal_name, handler in dict(_d_ or {}, **sigmap).iteritems():
+            self[signal_name] = handler
+
+
+signals = Signals()
+get_signal = signals.signum                   # compat
+install_signal_handler = signals.__setitem__  # compat
+reset_signal = signals.reset                  # compat
+ignore_signal = signals.ignore                # compat
+
 
 def strargv(argv):
     arg_start = 2 if "manage" in argv[0] else 1

+ 6 - 7
celery/tests/test_bin/test_celerybeat.py

@@ -81,18 +81,17 @@ class test_Beat(AppCase):
     def psig(self, fun, *args, **kwargs):
         handlers = {}
 
-        def i(sig=None, handler=None, **sigmap):
-            if sig:
-                sigmap[sig] = handler
-            handlers.update(sigmap)
+        class Signals(platforms.Signals):
 
-        p, platforms.install_signal_handler = \
-                platforms.install_signal_handler, i
+            def __setitem__(self, sig, handler):
+                handlers[sig] = handler
+
+        p, platforms.signals = platforms.signals, Signals()
         try:
             fun(*args, **kwargs)
             return handlers
         finally:
-            platforms.install_signal_handler = p
+            platforms.signals = p
 
     def test_install_sync_handler(self):
         b = beatapp.Beat()

+ 18 - 21
celery/tests/test_bin/test_celeryd.py

@@ -79,13 +79,13 @@ class test_Worker(AppCase):
     def test_run_worker(self):
         handlers = {}
 
-        def i(sig=None, handler=None, **sigmap):
-            if sig:
-                sigmap[sig] = handler
-            handlers.update(sigmap)
+        class Signals(platforms.Signals):
 
-        p = platforms.install_signal_handler
-        platforms.install_signal_handler = i
+            def __setitem__(self, sig, handler):
+                handlers[sig] = handler
+
+        p = platforms.signals
+        platforms.signals = Signals()
         try:
             w = self.Worker()
             w._isatty = False
@@ -101,7 +101,7 @@ class test_Worker(AppCase):
                 self.assertIn(sig, handlers)
             self.assertNotIn("SIGHUP", handlers)
         finally:
-            platforms.install_signal_handler = p
+            platforms.signals = p
 
     @disable_stdouts
     def test_startup_info(self):
@@ -380,18 +380,16 @@ class test_signal_handlers(AppCase):
     def psig(self, fun, *args, **kwargs):
         handlers = {}
 
-        def i(sig=None, handler=None, **sigmap):
-            if sig:
-                sigmap[sig] = handler
-            handlers.update(sigmap)
+        class Signals(platforms.Signals):
+            def __setitem__(self, sig, handler):
+                handlers[sig] = handler
 
-        p, platforms.install_signal_handler = \
-                platforms.install_signal_handler, i
+        p, platforms.signals = platforms.signals, Signals()
         try:
             fun(*args, **kwargs)
             return handlers
         finally:
-            platforms.install_signal_handler = p
+            platforms.signals = p
 
     @disable_stdouts
     def test_worker_int_handler(self):
@@ -399,19 +397,18 @@ class test_signal_handlers(AppCase):
         handlers = self.psig(cd.install_worker_int_handler, worker)
         next_handlers = {}
 
-        def i(sig=None, handler=None, **sigmap):
-            if sig:
-                sigmap[sig] = handler
-            next_handlers.update(sigmap)
+        class Signals(platforms.Signals):
+
+            def __setitem__(self, sig, handler):
+                next_handlers[sig] = handler
 
-        p = platforms.install_signal_handler
-        platforms.install_signal_handler = i
+        p, platforms.signals = platforms.signals, Signals()
         try:
             self.assertRaises(SystemExit, handlers["SIGINT"],
                               "SIGINT", object())
             self.assertTrue(worker.stopped)
         finally:
-            platforms.install_signal_handler = p
+            platforms.signals = p
 
         self.assertRaises(SystemExit, next_handlers["SIGINT"],
                           "SIGINT", object())

+ 4 - 6
celery/tests/test_worker/test_worker.py

@@ -525,11 +525,9 @@ class test_WorkController(AppCase):
         worker.logger = Mock()
         return worker
 
-    @patch("celery.platforms.reset_signal")
-    @patch("celery.platforms.ignore_signal")
+    @patch("celery.platforms.signals")
     @patch("celery.platforms.set_mp_process_title")
-    def test_process_initializer(self, set_mp_process_title, ignore_signal,
-            reset_signal):
+    def test_process_initializer(self, set_mp_process_title, _signals):
         from celery import Celery
         from celery import signals
         from celery.app import _tls
@@ -544,9 +542,9 @@ class test_WorkController(AppCase):
         app = Celery(loader=Mock(), set_as_current=False)
         process_initializer(app, "awesome.worker.com")
         self.assertIn((tuple(WORKER_SIGIGNORE), {}),
-                      ignore_signal.call_args_list)
+                      _signals.ignore.call_args_list)
         self.assertIn((tuple(WORKER_SIGRESET), {}),
-                      reset_signal.call_args_list)
+                      _signals.reset.call_args_list)
         self.assertTrue(app.loader.init_worker.call_count)
         self.assertTrue(on_worker_process_init.called)
         self.assertIs(_tls.current_app, app)

+ 2 - 2
celery/worker/__init__.py

@@ -41,8 +41,8 @@ def process_initializer(app, hostname):
     """
     app = app_or_default(app)
     app.set_current()
-    platforms.reset_signal(*WORKER_SIGRESET)
-    platforms.ignore_signal(*WORKER_SIGIGNORE)
+    platforms.signals.reset(*WORKER_SIGRESET)
+    platforms.signals.ignore(*WORKER_SIGIGNORE)
     platforms.set_mp_process_title("celeryd", hostname=hostname)
 
     # This is for Windows and other platforms not supporting

+ 2 - 2
celery/worker/control/builtins.py

@@ -2,7 +2,7 @@ import sys
 
 from datetime import datetime
 
-from celery.platforms import get_signal
+from celery.platforms import signals as _signals
 from celery.registry import tasks
 from celery.utils import timeutils
 from celery.worker import state
@@ -19,7 +19,7 @@ def revoke(panel, task_id, terminate=False, signal=None, **kwargs):
     revoked.add(task_id)
     action = "revoked"
     if terminate:
-        signum = get_signal(signal)
+        signum = _signals.signum(signal)
         for request in state.active_requests:
             if request.task_id == task_id:
                 action = "terminated (%s)" % (signum, )