Browse Source

Worker: Now uses WorkerShutdown instead of SystemTerminate so that gevent/eventlet can distinguish from "WorkerLostError"

Ask Solem 11 years ago
parent
commit
167c00f7e1

+ 6 - 4
celery/apps/worker.py

@@ -25,7 +25,9 @@ from kombu.utils.encoding import safe_str
 
 
 from celery import VERSION_BANNER, platforms, signals
 from celery import VERSION_BANNER, platforms, signals
 from celery.app import trace
 from celery.app import trace
-from celery.exceptions import CDeprecationWarning, SystemTerminate
+from celery.exceptions import (
+    CDeprecationWarning, WorkerShutdown, WorkerTerminate,
+)
 from celery.five import string, string_t
 from celery.five import string, string_t
 from celery.loaders.app import AppLoader
 from celery.loaders.app import AppLoader
 from celery.platforms import check_privileges
 from celery.platforms import check_privileges
@@ -275,7 +277,7 @@ class Worker(WorkController):
 
 
 
 
 def _shutdown_handler(worker, sig='TERM', how='Warm',
 def _shutdown_handler(worker, sig='TERM', how='Warm',
-                      exc=SystemExit, callback=None):
+                      exc=WorkerShutdown, callback=None):
 
 
     def _handle_request(*args):
     def _handle_request(*args):
         with in_sighandler():
         with in_sighandler():
@@ -292,11 +294,11 @@ def _shutdown_handler(worker, sig='TERM', how='Warm',
     _handle_request.__name__ = str('worker_{0}'.format(how))
     _handle_request.__name__ = str('worker_{0}'.format(how))
     platforms.signals[sig] = _handle_request
     platforms.signals[sig] = _handle_request
 install_worker_term_handler = partial(
 install_worker_term_handler = partial(
-    _shutdown_handler, sig='SIGTERM', how='Warm', exc=SystemExit,
+    _shutdown_handler, sig='SIGTERM', how='Warm', exc=WorkerShutdown,
 )
 )
 if not is_jython:  # pragma: no cover
 if not is_jython:  # pragma: no cover
     install_worker_term_hard_handler = partial(
     install_worker_term_hard_handler = partial(
-        _shutdown_handler, sig='SIGQUIT', how='Cold', exc=SystemTerminate,
+        _shutdown_handler, sig='SIGQUIT', how='Cold', exc=WorkerTerminate,
     )
     )
 else:  # pragma: no cover
 else:  # pragma: no cover
     install_worker_term_handler = \
     install_worker_term_handler = \

+ 3 - 0
celery/concurrency/base.py

@@ -16,6 +16,7 @@ from billiard.einfo import ExceptionInfo
 from billiard.exceptions import WorkerLostError
 from billiard.exceptions import WorkerLostError
 from kombu.utils.encoding import safe_repr
 from kombu.utils.encoding import safe_repr
 
 
+from celery.exceptions import WorkerShutdown, WorkerTerminate
 from celery.five import monotonic, reraise
 from celery.five import monotonic, reraise
 from celery.utils import timer2
 from celery.utils import timer2
 from celery.utils.text import truncate
 from celery.utils.text import truncate
@@ -37,6 +38,8 @@ def apply_target(target, args=(), kwargs={}, callback=None,
         raise
         raise
     except Exception:
     except Exception:
         raise
         raise
+    except (WorkerShutdown, WorkerTerminate):
+        raise
     except BaseException as exc:
     except BaseException as exc:
         try:
         try:
             reraise(WorkerLostError, WorkerLostError(repr(exc)),
             reraise(WorkerLostError, WorkerLostError(repr(exc)),

+ 10 - 3
celery/exceptions.py

@@ -14,7 +14,8 @@ from billiard.exceptions import (  # noqa
     SoftTimeLimitExceeded, TimeLimitExceeded, WorkerLostError, Terminated,
     SoftTimeLimitExceeded, TimeLimitExceeded, WorkerLostError, Terminated,
 )
 )
 
 
-__all__ = ['SecurityError', 'Ignore', 'SystemTerminate', 'QueueNotFound',
+__all__ = ['SecurityError', 'Ignore', 'QueueNotFound',
+           'WorkerShutdown', 'WorkerTerminate',
            'ImproperlyConfigured', 'NotRegistered', 'AlreadyRegistered',
            'ImproperlyConfigured', 'NotRegistered', 'AlreadyRegistered',
            'TimeoutError', 'MaxRetriesExceededError', 'Retry',
            'TimeoutError', 'MaxRetriesExceededError', 'Retry',
            'TaskRevokedError', 'NotConfigured', 'AlwaysEagerIgnored',
            'TaskRevokedError', 'NotConfigured', 'AlwaysEagerIgnored',
@@ -52,8 +53,14 @@ class Reject(Exception):
         return 'reject requeue=%s: %s' % (self.requeue, self.reason)
         return 'reject requeue=%s: %s' % (self.requeue, self.reason)
 
 
 
 
-class SystemTerminate(SystemExit):
-    """Signals that the worker should terminate."""
+class WorkerTerminate(SystemExit):
+    """Signals that the worker should terminate immediately."""
+SystemTerminate = WorkerTerminate  # XXX compat
+
+
+class WorkerShutdown(SystemExit):
+    """Signals that the worker should perform a warm shutdown."""
+
 
 
 
 
 class QueueNotFound(KeyError):
 class QueueNotFound(KeyError):

+ 10 - 8
celery/tests/bin/test_worker.py

@@ -14,7 +14,9 @@ from celery import signals
 from celery.app import trace
 from celery.app import trace
 from celery.apps import worker as cd
 from celery.apps import worker as cd
 from celery.bin.worker import worker, main as worker_main
 from celery.bin.worker import worker, main as worker_main
-from celery.exceptions import ImproperlyConfigured, SystemTerminate
+from celery.exceptions import (
+    ImproperlyConfigured, WorkerShutdown, WorkerTerminate,
+)
 from celery.utils.log import ensure_process_aware_logger
 from celery.utils.log import ensure_process_aware_logger
 from celery.worker import state
 from celery.worker import state
 
 
@@ -514,12 +516,12 @@ class test_signal_handlers(WorkerAppCase):
             c.return_value = 1
             c.return_value = 1
             p, platforms.signals = platforms.signals, Signals()
             p, platforms.signals = platforms.signals, Signals()
             try:
             try:
-                with self.assertRaises(SystemExit):
+                with self.assertRaises(WorkerShutdown):
                     handlers['SIGINT']('SIGINT', object())
                     handlers['SIGINT']('SIGINT', object())
             finally:
             finally:
                 platforms.signals = p
                 platforms.signals = p
 
 
-            with self.assertRaises(SystemTerminate):
+            with self.assertRaises(WorkerTerminate):
                 next_handlers['SIGINT']('SIGINT', object())
                 next_handlers['SIGINT']('SIGINT', object())
 
 
     @disable_stdouts
     @disable_stdouts
@@ -546,7 +548,7 @@ class test_signal_handlers(WorkerAppCase):
             try:
             try:
                 worker = self._Worker()
                 worker = self._Worker()
                 handlers = self.psig(cd.install_worker_int_handler, worker)
                 handlers = self.psig(cd.install_worker_int_handler, worker)
-                with self.assertRaises(SystemExit):
+                with self.assertRaises(WorkerShutdown):
                     handlers['SIGINT']('SIGINT', object())
                     handlers['SIGINT']('SIGINT', object())
             finally:
             finally:
                 process.name = name
                 process.name = name
@@ -582,7 +584,7 @@ class test_signal_handlers(WorkerAppCase):
                 worker = self._Worker()
                 worker = self._Worker()
                 handlers = self.psig(
                 handlers = self.psig(
                     cd.install_worker_term_hard_handler, worker)
                     cd.install_worker_term_hard_handler, worker)
-                with self.assertRaises(SystemTerminate):
+                with self.assertRaises(WorkerTerminate):
                     handlers['SIGQUIT']('SIGQUIT', object())
                     handlers['SIGQUIT']('SIGQUIT', object())
         finally:
         finally:
             process.name = name
             process.name = name
@@ -606,7 +608,7 @@ class test_signal_handlers(WorkerAppCase):
             worker = self._Worker()
             worker = self._Worker()
             handlers = self.psig(cd.install_worker_term_handler, worker)
             handlers = self.psig(cd.install_worker_term_handler, worker)
             try:
             try:
-                with self.assertRaises(SystemExit):
+                with self.assertRaises(WorkerShutdown):
                     handlers['SIGTERM']('SIGTERM', object())
                     handlers['SIGTERM']('SIGTERM', object())
             finally:
             finally:
                 state.should_stop = False
                 state.should_stop = False
@@ -638,7 +640,7 @@ class test_signal_handlers(WorkerAppCase):
                 c.return_value = 1
                 c.return_value = 1
                 worker = self._Worker()
                 worker = self._Worker()
                 handlers = self.psig(cd.install_worker_term_handler, worker)
                 handlers = self.psig(cd.install_worker_term_handler, worker)
-                with self.assertRaises(SystemExit):
+                with self.assertRaises(WorkerShutdown):
                     handlers['SIGTERM']('SIGTERM', object())
                     handlers['SIGTERM']('SIGTERM', object())
         finally:
         finally:
             process.name = name
             process.name = name
@@ -688,5 +690,5 @@ class test_signal_handlers(WorkerAppCase):
             c.return_value = 1
             c.return_value = 1
             worker = self._Worker()
             worker = self._Worker()
             handlers = self.psig(cd.install_worker_term_hard_handler, worker)
             handlers = self.psig(cd.install_worker_term_hard_handler, worker)
-            with self.assertRaises(SystemTerminate):
+            with self.assertRaises(WorkerTerminate):
                 handlers['SIGQUIT']('SIGQUIT', object())
                 handlers['SIGQUIT']('SIGQUIT', object())

+ 4 - 4
celery/tests/worker/test_loops.py

@@ -5,7 +5,7 @@ import socket
 from kombu.async import Hub, READ, WRITE, ERR
 from kombu.async import Hub, READ, WRITE, ERR
 
 
 from celery.bootsteps import CLOSE, RUN
 from celery.bootsteps import CLOSE, RUN
-from celery.exceptions import InvalidTaskError, SystemTerminate
+from celery.exceptions import InvalidTaskError, WorkerShutdown, WorkerTerminate
 from celery.five import Empty
 from celery.five import Empty
 from celery.worker import state
 from celery.worker import state
 from celery.worker.consumer import Consumer
 from celery.worker.consumer import Consumer
@@ -174,7 +174,7 @@ class test_asynloop(AppCase):
         # XXX why aren't the errors propagated?!?
         # XXX why aren't the errors propagated?!?
         state.should_terminate = True
         state.should_terminate = True
         try:
         try:
-            with self.assertRaises(SystemTerminate):
+            with self.assertRaises(WorkerTerminate):
                 asynloop(*x.args)
                 asynloop(*x.args)
         finally:
         finally:
             state.should_terminate = False
             state.should_terminate = False
@@ -185,7 +185,7 @@ class test_asynloop(AppCase):
         state.should_terminate = True
         state.should_terminate = True
         x.hub.close.side_effect = MemoryError()
         x.hub.close.side_effect = MemoryError()
         try:
         try:
-            with self.assertRaises(SystemTerminate):
+            with self.assertRaises(WorkerTerminate):
                 asynloop(*x.args)
                 asynloop(*x.args)
         finally:
         finally:
             state.should_terminate = False
             state.should_terminate = False
@@ -194,7 +194,7 @@ class test_asynloop(AppCase):
         x = X(self.app)
         x = X(self.app)
         state.should_stop = True
         state.should_stop = True
         try:
         try:
-            with self.assertRaises(SystemExit):
+            with self.assertRaises(WorkerShutdown):
                 asynloop(*x.args)
                 asynloop(*x.args)
         finally:
         finally:
             state.should_stop = False
             state.should_stop = False

+ 3 - 3
celery/tests/worker/test_state.py

@@ -5,7 +5,7 @@ import pickle
 from time import time
 from time import time
 
 
 from celery.datastructures import LimitedSet
 from celery.datastructures import LimitedSet
-from celery.exceptions import SystemTerminate
+from celery.exceptions import WorkerShutdown, WorkerTerminate
 from celery.worker import state
 from celery.worker import state
 
 
 from celery.tests.case import AppCase, Mock, patch
 from celery.tests.case import AppCase, Mock, patch
@@ -53,12 +53,12 @@ class test_maybe_shutdown(AppCase):
 
 
     def test_should_stop(self):
     def test_should_stop(self):
         state.should_stop = True
         state.should_stop = True
-        with self.assertRaises(SystemExit):
+        with self.assertRaises(WorkerShutdown):
             state.maybe_shutdown()
             state.maybe_shutdown()
 
 
     def test_should_terminate(self):
     def test_should_terminate(self):
         state.should_terminate = True
         state.should_terminate = True
-        with self.assertRaises(SystemTerminate):
+        with self.assertRaises(WorkerTerminate):
             state.maybe_shutdown()
             state.maybe_shutdown()
 
 
 
 

+ 9 - 7
celery/tests/worker/test_worker.py

@@ -16,7 +16,9 @@ from celery.app.defaults import DEFAULTS
 from celery.bootsteps import RUN, CLOSE, StartStopStep
 from celery.bootsteps import RUN, CLOSE, StartStopStep
 from celery.concurrency.base import BasePool
 from celery.concurrency.base import BasePool
 from celery.datastructures import AttributeDict
 from celery.datastructures import AttributeDict
-from celery.exceptions import SystemTerminate, TaskRevokedError
+from celery.exceptions import (
+    WorkerShutdown, WorkerTerminate, TaskRevokedError,
+)
 from celery.five import Empty, range, Queue as FastQueue
 from celery.five import Empty, range, Queue as FastQueue
 from celery.utils import uuid
 from celery.utils import uuid
 from celery.worker import components
 from celery.worker import components
@@ -268,9 +270,9 @@ class test_Consumer(AppCase):
         l.event_dispatcher = mock_event_dispatcher()
         l.event_dispatcher = mock_event_dispatcher()
         l.task_consumer = Mock()
         l.task_consumer = Mock()
         l.connection = Mock()
         l.connection = Mock()
-        l.connection.drain_events.side_effect = SystemExit()
+        l.connection.drain_events.side_effect = WorkerShutdown()
 
 
-        with self.assertRaises(SystemExit):
+        with self.assertRaises(WorkerShutdown):
             l.loop(*l.loop_args())
             l.loop(*l.loop_args())
         self.assertTrue(l.task_consumer.register_callback.called)
         self.assertTrue(l.task_consumer.register_callback.called)
         return l.task_consumer.register_callback.call_args[0][0]
         return l.task_consumer.register_callback.call_args[0][0]
@@ -918,10 +920,10 @@ class test_WorkController(AppCase):
         with self.assertRaises(KeyboardInterrupt):
         with self.assertRaises(KeyboardInterrupt):
             worker._process_task(task)
             worker._process_task(task)
 
 
-    def test_process_task_raise_SystemTerminate(self):
+    def test_process_task_raise_WorkerTerminate(self):
         worker = self.worker
         worker = self.worker
         worker.pool = Mock()
         worker.pool = Mock()
-        worker.pool.apply_async.side_effect = SystemTerminate()
+        worker.pool.apply_async.side_effect = WorkerTerminate()
         backend = Mock()
         backend = Mock()
         m = create_message(backend, task=self.foo_task.name, args=[4, 8, 10],
         m = create_message(backend, task=self.foo_task.name, args=[4, 8, 10],
                            kwargs={})
                            kwargs={})
@@ -946,7 +948,7 @@ class test_WorkController(AppCase):
         worker1 = self.create_worker()
         worker1 = self.create_worker()
         worker1.blueprint.state = RUN
         worker1.blueprint.state = RUN
         stc = MockStep()
         stc = MockStep()
-        stc.start.side_effect = SystemTerminate()
+        stc.start.side_effect = WorkerTerminate()
         worker1.steps = [stc]
         worker1.steps = [stc]
         worker1.start()
         worker1.start()
         stc.start.assert_called_with(worker1)
         stc.start.assert_called_with(worker1)
@@ -955,7 +957,7 @@ class test_WorkController(AppCase):
         worker2 = self.create_worker()
         worker2 = self.create_worker()
         worker2.blueprint.state = RUN
         worker2.blueprint.state = RUN
         sec = MockStep()
         sec = MockStep()
-        sec.start.side_effect = SystemExit()
+        sec.start.side_effect = WorkerShutdown()
         sec.terminate = None
         sec.terminate = None
         worker2.steps = [sec]
         worker2.steps = [sec]
         worker2.start()
         worker2.start()

+ 2 - 2
celery/worker/__init__.py

@@ -29,7 +29,7 @@ from celery import concurrency as _concurrency
 from celery import platforms
 from celery import platforms
 from celery import signals
 from celery import signals
 from celery.exceptions import (
 from celery.exceptions import (
-    ImproperlyConfigured, SystemTerminate, TaskRevokedError,
+    ImproperlyConfigured, WorkerTerminate, TaskRevokedError,
 )
 )
 from celery.five import string_t, values
 from celery.five import string_t, values
 from celery.utils import default_nodename, worker_direct
 from celery.utils import default_nodename, worker_direct
@@ -204,7 +204,7 @@ class WorkController(object):
         """Starts the workers main loop."""
         """Starts the workers main loop."""
         try:
         try:
             self.blueprint.start(self)
             self.blueprint.start(self)
-        except SystemTerminate:
+        except WorkerTerminate:
             self.terminate()
             self.terminate()
         except Exception as exc:
         except Exception as exc:
             logger.error('Unrecoverable error: %r', exc, exc_info=True)
             logger.error('Unrecoverable error: %r', exc, exc_info=True)

+ 2 - 1
celery/worker/control.py

@@ -13,6 +13,7 @@ import tempfile
 
 
 from kombu.utils.encoding import safe_repr
 from kombu.utils.encoding import safe_repr
 
 
+from celery.exceptions import WorkerShutdown
 from celery.five import UserDict, items
 from celery.five import UserDict, items
 from celery.platforms import signals as _signals
 from celery.platforms import signals as _signals
 from celery.utils import timeutils
 from celery.utils import timeutils
@@ -336,7 +337,7 @@ def autoscale(state, max=None, min=None):
 @Panel.register
 @Panel.register
 def shutdown(state, msg='Got shutdown from remote', **kwargs):
 def shutdown(state, msg='Got shutdown from remote', **kwargs):
     logger.warning(msg)
     logger.warning(msg)
-    raise SystemExit(msg)
+    raise WorkerShutdown(msg)
 
 
 
 
 @Panel.register
 @Panel.register

+ 3 - 3
celery/worker/loops.py

@@ -10,7 +10,7 @@ from __future__ import absolute_import
 import socket
 import socket
 
 
 from celery.bootsteps import RUN
 from celery.bootsteps import RUN
-from celery.exceptions import SystemTerminate, WorkerLostError
+from celery.exceptions import WorkerShutdown, WorkerTerminate, WorkerLostError
 from celery.utils.log import get_logger
 from celery.utils.log import get_logger
 
 
 from . import state
 from . import state
@@ -57,9 +57,9 @@ def asynloop(obj, connection, consumer, blueprint, hub, qos,
         while blueprint.state == RUN and obj.connection:
         while blueprint.state == RUN and obj.connection:
             # shutdown if signal handlers told us to.
             # shutdown if signal handlers told us to.
             if state.should_stop:
             if state.should_stop:
-                raise SystemExit()
+                raise WorkerShutdown()
             elif state.should_terminate:
             elif state.should_terminate:
-                raise SystemTerminate()
+                raise WorkerTerminate()
 
 
             # We only update QoS when there is no more messages to read.
             # We only update QoS when there is no more messages to read.
             # This groups together qos calls, and makes sure that remote
             # This groups together qos calls, and makes sure that remote

+ 3 - 3
celery/worker/state.py

@@ -22,7 +22,7 @@ from kombu.utils import cached_property
 
 
 from celery import __version__
 from celery import __version__
 from celery.datastructures import LimitedSet
 from celery.datastructures import LimitedSet
-from celery.exceptions import SystemTerminate
+from celery.exceptions import WorkerShutdown, WorkerTerminate
 from celery.five import Counter
 from celery.five import Counter
 
 
 __all__ = ['SOFTWARE_INFO', 'reserved_requests', 'active_requests',
 __all__ = ['SOFTWARE_INFO', 'reserved_requests', 'active_requests',
@@ -66,9 +66,9 @@ should_terminate = False
 
 
 def maybe_shutdown():
 def maybe_shutdown():
     if should_stop:
     if should_stop:
-        raise SystemExit()
+        raise WorkerShutdown()
     elif should_terminate:
     elif should_terminate:
-        raise SystemTerminate()
+        raise WorkerTerminate()
 
 
 
 
 def task_accepted(request, _all_total_count=all_total_count):
 def task_accepted(request, _all_total_count=all_total_count):