浏览代码

Worker now preserves exit code. Closes #2024

Ask Solem 11 年之前
父节点
当前提交
dc28e8a54d

+ 8 - 6
celery/apps/worker.py

@@ -30,7 +30,7 @@ from celery.exceptions import (
 )
 )
 from celery.five import string, string_t
 from celery.five import string, string_t
 from celery.loaders.app import AppLoader
 from celery.loaders.app import AppLoader
-from celery.platforms import check_privileges
+from celery.platforms import EX_FAILURE, EX_OK, check_privileges
 from celery.utils import cry, isatty
 from celery.utils import cry, isatty
 from celery.utils.imports import qualname
 from celery.utils.imports import qualname
 from celery.utils.log import get_logger, in_sighandler, set_in_sighandler
 from celery.utils.log import get_logger, in_sighandler, set_in_sighandler
@@ -277,7 +277,7 @@ class Worker(WorkController):
 
 
 
 
 def _shutdown_handler(worker, sig='TERM', how='Warm',
 def _shutdown_handler(worker, sig='TERM', how='Warm',
-                      exc=WorkerShutdown, callback=None):
+                      exc=WorkerShutdown, callback=None, exitcode=EX_OK):
 
 
     def _handle_request(*args):
     def _handle_request(*args):
         with in_sighandler():
         with in_sighandler():
@@ -288,9 +288,9 @@ def _shutdown_handler(worker, sig='TERM', how='Warm',
                 safe_say('worker: {0} shutdown (MainProcess)'.format(how))
                 safe_say('worker: {0} shutdown (MainProcess)'.format(how))
             if active_thread_count() > 1:
             if active_thread_count() > 1:
                 setattr(state, {'Warm': 'should_stop',
                 setattr(state, {'Warm': 'should_stop',
-                                'Cold': 'should_terminate'}[how], True)
+                                'Cold': 'should_terminate'}[how], exitcode)
             else:
             else:
-                raise exc()
+                raise exc(exitcode)
     _handle_request.__name__ = str('worker_{0}'.format(how))
     _handle_request.__name__ = str('worker_{0}'.format(how))
     platforms.signals[sig] = _handle_request
     platforms.signals[sig] = _handle_request
 install_worker_term_handler = partial(
 install_worker_term_handler = partial(
@@ -299,6 +299,7 @@ install_worker_term_handler = partial(
 if not is_jython:  # pragma: no cover
 if not is_jython:  # pragma: no cover
     install_worker_term_hard_handler = partial(
     install_worker_term_hard_handler = partial(
         _shutdown_handler, sig='SIGQUIT', how='Cold', exc=WorkerTerminate,
         _shutdown_handler, sig='SIGQUIT', how='Cold', exc=WorkerTerminate,
+        exitcode=EX_FAILURE,
     )
     )
 else:  # pragma: no cover
 else:  # pragma: no cover
     install_worker_term_handler = \
     install_worker_term_handler = \
@@ -310,7 +311,8 @@ def on_SIGINT(worker):
     install_worker_term_hard_handler(worker, sig='SIGINT')
     install_worker_term_hard_handler(worker, sig='SIGINT')
 if not is_jython:  # pragma: no cover
 if not is_jython:  # pragma: no cover
     install_worker_int_handler = partial(
     install_worker_int_handler = partial(
-        _shutdown_handler, sig='SIGINT', callback=on_SIGINT
+        _shutdown_handler, sig='SIGINT', callback=on_SIGINT,
+        exitcode=EX_FAILURE,
     )
     )
 else:  # pragma: no cover
 else:  # pragma: no cover
     install_worker_int_handler = lambda *a, **kw: None
     install_worker_int_handler = lambda *a, **kw: None
@@ -332,7 +334,7 @@ def install_worker_restart_handler(worker, sig='SIGHUP'):
         import atexit
         import atexit
         atexit.register(_reload_current_worker)
         atexit.register(_reload_current_worker)
         from celery.worker import state
         from celery.worker import state
-        state.should_stop = True
+        state.should_stop = EX_OK
     platforms.signals[sig] = restart_worker_sig_handler
     platforms.signals[sig] = restart_worker_sig_handler
 
 
 
 

+ 4 - 2
celery/bin/worker.py

@@ -205,12 +205,14 @@ class worker(Command):
                     loglevel, '|'.join(
                     loglevel, '|'.join(
                         l for l in LOG_LEVELS if isinstance(l, string_t))))
                         l for l in LOG_LEVELS if isinstance(l, string_t))))
 
 
-        return self.app.Worker(
+        worker = self.app.Worker(
             hostname=hostname, pool_cls=pool_cls, loglevel=loglevel,
             hostname=hostname, pool_cls=pool_cls, loglevel=loglevel,
             logfile=logfile,  # node format handled by celery.app.log.setup
             logfile=logfile,  # node format handled by celery.app.log.setup
             pidfile=self.node_format(pidfile, hostname),
             pidfile=self.node_format(pidfile, hostname),
             state_db=self.node_format(state_db, hostname), **kwargs
             state_db=self.node_format(state_db, hostname), **kwargs
-        ).start()
+        )
+        worker.start()
+        return worker.exitcode
 
 
     def with_pool_option(self, argv):
     def with_pool_option(self, argv):
         # this command support custom pools
         # this command support custom pools

+ 23 - 17
celery/tests/bin/test_worker.py

@@ -17,6 +17,7 @@ from celery.bin.worker import worker, main as worker_main
 from celery.exceptions import (
 from celery.exceptions import (
     ImproperlyConfigured, WorkerShutdown, WorkerTerminate,
     ImproperlyConfigured, WorkerShutdown, WorkerTerminate,
 )
 )
+from celery.platforms import EX_FAILURE, EX_OK
 from celery.utils.log import ensure_process_aware_logger
 from celery.utils.log import ensure_process_aware_logger
 from celery.worker import state
 from celery.worker import state
 
 
@@ -490,8 +491,8 @@ class test_signal_handlers(WorkerAppCase):
         worker = self._Worker()
         worker = self._Worker()
         handlers = self.psig(cd.install_worker_int_handler, worker)
         handlers = self.psig(cd.install_worker_int_handler, worker)
         next_handlers = {}
         next_handlers = {}
-        state.should_stop = False
-        state.should_terminate = False
+        state.should_stop = None
+        state.should_terminate = None
 
 
         class Signals(platforms.Signals):
         class Signals(platforms.Signals):
 
 
@@ -504,15 +505,17 @@ class test_signal_handlers(WorkerAppCase):
             try:
             try:
                 handlers['SIGINT']('SIGINT', object())
                 handlers['SIGINT']('SIGINT', object())
                 self.assertTrue(state.should_stop)
                 self.assertTrue(state.should_stop)
+                self.assertEqual(state.should_stop, EX_FAILURE)
             finally:
             finally:
                 platforms.signals = p
                 platforms.signals = p
-                state.should_stop = False
+                state.should_stop = None
 
 
             try:
             try:
                 next_handlers['SIGINT']('SIGINT', object())
                 next_handlers['SIGINT']('SIGINT', object())
                 self.assertTrue(state.should_terminate)
                 self.assertTrue(state.should_terminate)
+                self.assertEqual(state.should_terminate, EX_FAILURE)
             finally:
             finally:
-                state.should_terminate = False
+                state.should_terminate = None
 
 
         with patch('celery.apps.worker.active_thread_count') as c:
         with patch('celery.apps.worker.active_thread_count') as c:
             c.return_value = 1
             c.return_value = 1
@@ -543,7 +546,7 @@ class test_signal_handlers(WorkerAppCase):
                 self.assertTrue(state.should_stop)
                 self.assertTrue(state.should_stop)
             finally:
             finally:
                 process.name = name
                 process.name = name
-                state.should_stop = False
+                state.should_stop = None
 
 
         with patch('celery.apps.worker.active_thread_count') as c:
         with patch('celery.apps.worker.active_thread_count') as c:
             c.return_value = 1
             c.return_value = 1
@@ -554,7 +557,7 @@ class test_signal_handlers(WorkerAppCase):
                     handlers['SIGINT']('SIGINT', object())
                     handlers['SIGINT']('SIGINT', object())
             finally:
             finally:
                 process.name = name
                 process.name = name
-                state.should_stop = False
+                state.should_stop = None
 
 
     @disable_stdouts
     @disable_stdouts
     def test_install_HUP_not_supported_handler(self):
     def test_install_HUP_not_supported_handler(self):
@@ -580,14 +583,17 @@ class test_signal_handlers(WorkerAppCase):
                     handlers['SIGQUIT']('SIGQUIT', object())
                     handlers['SIGQUIT']('SIGQUIT', object())
                     self.assertTrue(state.should_terminate)
                     self.assertTrue(state.should_terminate)
                 finally:
                 finally:
-                    state.should_terminate = False
+                    state.should_terminate = None
             with patch('celery.apps.worker.active_thread_count') as c:
             with patch('celery.apps.worker.active_thread_count') as c:
                 c.return_value = 1
                 c.return_value = 1
                 worker = self._Worker()
                 worker = self._Worker()
                 handlers = self.psig(
                 handlers = self.psig(
                     cd.install_worker_term_hard_handler, worker)
                     cd.install_worker_term_hard_handler, worker)
-                with self.assertRaises(WorkerTerminate):
-                    handlers['SIGQUIT']('SIGQUIT', object())
+                try:
+                    with self.assertRaises(WorkerTerminate):
+                        handlers['SIGQUIT']('SIGQUIT', object())
+                finally:
+                    state.should_terminate = None
         finally:
         finally:
             process.name = name
             process.name = name
 
 
@@ -599,9 +605,9 @@ class test_signal_handlers(WorkerAppCase):
             handlers = self.psig(cd.install_worker_term_handler, worker)
             handlers = self.psig(cd.install_worker_term_handler, worker)
             try:
             try:
                 handlers['SIGTERM']('SIGTERM', object())
                 handlers['SIGTERM']('SIGTERM', object())
-                self.assertTrue(state.should_stop)
+                self.assertEqual(state.should_stop, EX_OK)
             finally:
             finally:
-                state.should_stop = False
+                state.should_stop = None
 
 
     @disable_stdouts
     @disable_stdouts
     def test_worker_term_handler_when_single_thread(self):
     def test_worker_term_handler_when_single_thread(self):
@@ -613,7 +619,7 @@ class test_signal_handlers(WorkerAppCase):
                 with self.assertRaises(WorkerShutdown):
                 with self.assertRaises(WorkerShutdown):
                     handlers['SIGTERM']('SIGTERM', object())
                     handlers['SIGTERM']('SIGTERM', object())
             finally:
             finally:
-                state.should_stop = False
+                state.should_stop = None
 
 
     @patch('sys.__stderr__')
     @patch('sys.__stderr__')
     @skip_if_pypy
     @skip_if_pypy
@@ -637,7 +643,7 @@ class test_signal_handlers(WorkerAppCase):
                 worker = self._Worker()
                 worker = self._Worker()
                 handlers = self.psig(cd.install_worker_term_handler, worker)
                 handlers = self.psig(cd.install_worker_term_handler, worker)
                 handlers['SIGTERM']('SIGTERM', object())
                 handlers['SIGTERM']('SIGTERM', object())
-                self.assertTrue(state.should_stop)
+                self.assertEqual(state.should_stop, EX_OK)
             with patch('celery.apps.worker.active_thread_count') as c:
             with patch('celery.apps.worker.active_thread_count') as c:
                 c.return_value = 1
                 c.return_value = 1
                 worker = self._Worker()
                 worker = self._Worker()
@@ -646,7 +652,7 @@ class test_signal_handlers(WorkerAppCase):
                     handlers['SIGTERM']('SIGTERM', object())
                     handlers['SIGTERM']('SIGTERM', object())
         finally:
         finally:
             process.name = name
             process.name = name
-            state.should_stop = False
+            state.should_stop = None
 
 
     @disable_stdouts
     @disable_stdouts
     @patch('celery.platforms.close_open_fds')
     @patch('celery.platforms.close_open_fds')
@@ -665,14 +671,14 @@ class test_signal_handlers(WorkerAppCase):
             worker = self._Worker()
             worker = self._Worker()
             handlers = self.psig(cd.install_worker_restart_handler, worker)
             handlers = self.psig(cd.install_worker_restart_handler, worker)
             handlers['SIGHUP']('SIGHUP', object())
             handlers['SIGHUP']('SIGHUP', object())
-            self.assertTrue(state.should_stop)
+            self.assertEqual(state.should_stop, EX_OK)
             self.assertTrue(register.called)
             self.assertTrue(register.called)
             callback = register.call_args[0][0]
             callback = register.call_args[0][0]
             callback()
             callback()
             self.assertTrue(argv)
             self.assertTrue(argv)
         finally:
         finally:
             os.execv = execv
             os.execv = execv
-            state.should_stop = False
+            state.should_stop = None
 
 
     @disable_stdouts
     @disable_stdouts
     def test_worker_term_hard_handler_when_threaded(self):
     def test_worker_term_hard_handler_when_threaded(self):
@@ -684,7 +690,7 @@ class test_signal_handlers(WorkerAppCase):
                 handlers['SIGQUIT']('SIGQUIT', object())
                 handlers['SIGQUIT']('SIGQUIT', object())
                 self.assertTrue(state.should_terminate)
                 self.assertTrue(state.should_terminate)
             finally:
             finally:
-                state.should_terminate = False
+                state.should_terminate = None
 
 
     @disable_stdouts
     @disable_stdouts
     def test_worker_term_hard_handler_when_single_threaded(self):
     def test_worker_term_hard_handler_when_single_threaded(self):

+ 9 - 0
celery/tests/case.py

@@ -464,6 +464,15 @@ class AppCase(Case):
             self._threads_at_setup, list(threading.enumerate()),
             self._threads_at_setup, list(threading.enumerate()),
         )
         )
 
 
+        # Make sure no test left the shutdown flags enabled.
+        from celery.worker import state as worker_state
+        # check for EX_OK
+        self.assertIsNot(worker_state.should_stop, False)
+        self.assertIsNot(worker_state.should_terminate, False)
+        # check for other true values
+        self.assertFalse(worker_state.should_stop)
+        self.assertFalse(worker_state.should_terminate)
+
     def _get_test_name(self):
     def _get_test_name(self):
         return '.'.join([self.__class__.__name__, self._testMethodName])
         return '.'.join([self.__class__.__name__, self._testMethodName])
 
 

+ 6 - 5
celery/tests/worker/test_loops.py

@@ -7,6 +7,7 @@ from kombu.async import Hub, READ, WRITE, ERR
 from celery.bootsteps import CLOSE, RUN
 from celery.bootsteps import CLOSE, RUN
 from celery.exceptions import InvalidTaskError, WorkerShutdown, WorkerTerminate
 from celery.exceptions import InvalidTaskError, WorkerShutdown, WorkerTerminate
 from celery.five import Empty
 from celery.five import Empty
+from celery.platforms import EX_FAILURE
 from celery.worker import state
 from celery.worker import state
 from celery.worker.consumer import Consumer
 from celery.worker.consumer import Consumer
 from celery.worker.loops import asynloop, synloop
 from celery.worker.loops import asynloop, synloop
@@ -179,27 +180,27 @@ class test_asynloop(AppCase):
             with self.assertRaises(WorkerTerminate):
             with self.assertRaises(WorkerTerminate):
                 asynloop(*x.args)
                 asynloop(*x.args)
         finally:
         finally:
-            state.should_terminate = False
+            state.should_terminate = None
 
 
     def test_should_terminate_hub_close_raises(self):
     def test_should_terminate_hub_close_raises(self):
         x = X(self.app)
         x = X(self.app)
         # XXX why aren't the errors propagated?!?
         # XXX why aren't the errors propagated?!?
-        state.should_terminate = True
+        state.should_terminate = EX_FAILURE
         x.hub.close.side_effect = MemoryError()
         x.hub.close.side_effect = MemoryError()
         try:
         try:
             with self.assertRaises(WorkerTerminate):
             with self.assertRaises(WorkerTerminate):
                 asynloop(*x.args)
                 asynloop(*x.args)
         finally:
         finally:
-            state.should_terminate = False
+            state.should_terminate = None
 
 
     def test_should_stop(self):
     def test_should_stop(self):
         x = X(self.app)
         x = X(self.app)
-        state.should_stop = True
+        state.should_stop = 303
         try:
         try:
             with self.assertRaises(WorkerShutdown):
             with self.assertRaises(WorkerShutdown):
                 asynloop(*x.args)
                 asynloop(*x.args)
         finally:
         finally:
-            state.should_stop = False
+            state.should_stop = None
 
 
     def test_updates_qos(self):
     def test_updates_qos(self):
         x = X(self.app)
         x = X(self.app)

+ 31 - 2
celery/tests/worker/test_state.py

@@ -48,13 +48,42 @@ class MyPersistent(state.Persistent):
 class test_maybe_shutdown(AppCase):
 class test_maybe_shutdown(AppCase):
 
 
     def teardown(self):
     def teardown(self):
-        state.should_stop = False
-        state.should_terminate = False
+        state.should_stop = None
+        state.should_terminate = None
 
 
     def test_should_stop(self):
     def test_should_stop(self):
         state.should_stop = True
         state.should_stop = True
         with self.assertRaises(WorkerShutdown):
         with self.assertRaises(WorkerShutdown):
             state.maybe_shutdown()
             state.maybe_shutdown()
+        state.should_stop = 0
+        with self.assertRaises(WorkerShutdown):
+            state.maybe_shutdown()
+        state.should_stop = False
+        try:
+            state.maybe_shutdown()
+        except SystemExit:
+            raise RuntimeError('should not have exited')
+        state.should_stop = None
+        try:
+            state.maybe_shutdown()
+        except SystemExit:
+            raise RuntimeError('should not have exited')
+
+        state.should_stop = 0
+        try:
+            state.maybe_shutdown()
+        except SystemExit as exc:
+            self.assertEqual(exc.code, 0)
+        else:
+            raise RuntimeError('should have exited')
+
+        state.should_stop = 303
+        try:
+            state.maybe_shutdown()
+        except SystemExit as exc:
+            self.assertEqual(exc.code, 303)
+        else:
+            raise RuntimeError('should have exited')
 
 
     def test_should_terminate(self):
     def test_should_terminate(self):
         state.should_terminate = True
         state.should_terminate = True

+ 3 - 2
celery/tests/worker/test_worker.py

@@ -20,6 +20,7 @@ from celery.exceptions import (
     WorkerShutdown, WorkerTerminate, TaskRevokedError, InvalidTaskError,
     WorkerShutdown, WorkerTerminate, TaskRevokedError, InvalidTaskError,
 )
 )
 from celery.five import Empty, range, Queue as FastQueue
 from celery.five import Empty, range, Queue as FastQueue
+from celery.platforms import EX_FAILURE
 from celery.utils import uuid
 from celery.utils import uuid
 from celery.worker import components
 from celery.worker import components
 from celery.worker import consumer
 from celery.worker import consumer
@@ -864,7 +865,7 @@ class test_WorkController(AppCase):
         self.worker.blueprint = None
         self.worker.blueprint = None
         self.worker._shutdown()
         self.worker._shutdown()
 
 
-    @patch('celery.platforms.create_pidlock')
+    @patch('celery.worker.create_pidlock')
     def test_use_pidfile(self, create_pidlock):
     def test_use_pidfile(self, create_pidlock):
         create_pidlock.return_value = Mock()
         create_pidlock.return_value = Mock()
         worker = self.create_worker(pidfile='pidfilelockfilepid')
         worker = self.create_worker(pidfile='pidfilelockfilepid')
@@ -1112,7 +1113,7 @@ class test_WorkController(AppCase):
         step.start.side_effect = TypeError()
         step.start.side_effect = TypeError()
         worker.stop = Mock()
         worker.stop = Mock()
         worker.start()
         worker.start()
-        worker.stop.assert_called_with()
+        worker.stop.assert_called_with(exitcode=EX_FAILURE)
 
 
     def test_state(self):
     def test_state(self):
         self.assertTrue(self.worker.state)
         self.assertTrue(self.worker.state)

+ 13 - 6
celery/worker/__init__.py

@@ -26,12 +26,12 @@ from kombu.syn import detect_environment
 from celery import bootsteps
 from celery import bootsteps
 from celery.bootsteps import RUN, TERMINATE
 from celery.bootsteps import RUN, TERMINATE
 from celery import concurrency as _concurrency
 from celery import concurrency as _concurrency
-from celery import platforms
 from celery import signals
 from celery import signals
 from celery.exceptions import (
 from celery.exceptions import (
     ImproperlyConfigured, WorkerTerminate, TaskRevokedError,
     ImproperlyConfigured, WorkerTerminate, TaskRevokedError,
 )
 )
 from celery.five import string_t, values
 from celery.five import string_t, values
+from celery.platforms import EX_FAILURE, create_pidlock
 from celery.utils import default_nodename, worker_direct
 from celery.utils import default_nodename, worker_direct
 from celery.utils.imports import reload_from_cwd
 from celery.utils.imports import reload_from_cwd
 from celery.utils.log import mlevel, worker_logger as logger
 from celery.utils.log import mlevel, worker_logger as logger
@@ -73,6 +73,9 @@ class WorkController(object):
     pool = None
     pool = None
     semaphore = None
     semaphore = None
 
 
+    #: contains the exit code if a :exc:`SystemExit` event is handled.
+    exitcode = None
+
     class Blueprint(bootsteps.Blueprint):
     class Blueprint(bootsteps.Blueprint):
         """Worker bootstep blueprint."""
         """Worker bootstep blueprint."""
         name = 'Worker'
         name = 'Worker'
@@ -150,7 +153,7 @@ class WorkController(object):
 
 
     def on_start(self):
     def on_start(self):
         if self.pidfile:
         if self.pidfile:
-            self.pidlock = platforms.create_pidlock(self.pidfile)
+            self.pidlock = create_pidlock(self.pidfile)
 
 
     def on_consumer_ready(self, consumer):
     def on_consumer_ready(self, consumer):
         pass
         pass
@@ -207,9 +210,11 @@ class WorkController(object):
             self.terminate()
             self.terminate()
         except Exception as exc:
         except Exception as exc:
             logger.error('Unrecoverable error: %r', exc, exc_info=True)
             logger.error('Unrecoverable error: %r', exc, exc_info=True)
-            self.stop()
-        except (KeyboardInterrupt, SystemExit):
-            self.stop()
+            self.stop(exitcode=EX_FAILURE)
+        except SystemExit as exc:
+            self.stop(exitcode=exc.code)
+        except KeyboardInterrupt:
+            self.stop(exitcode=EX_FAILURE)
 
 
     def register_with_event_loop(self, hub):
     def register_with_event_loop(self, hub):
         self.blueprint.send_all(
         self.blueprint.send_all(
@@ -243,8 +248,10 @@ class WorkController(object):
         return (detect_environment() == 'default' and
         return (detect_environment() == 'default' and
                 self._conninfo.is_evented and not self.app.IS_WINDOWS)
                 self._conninfo.is_evented and not self.app.IS_WINDOWS)
 
 
-    def stop(self, in_sighandler=False):
+    def stop(self, in_sighandler=False, exitcode=None):
         """Graceful shutdown of the worker server."""
         """Graceful shutdown of the worker server."""
+        if exitcode is not None:
+            self.exitcode = exitcode
         if self.blueprint.state == RUN:
         if self.blueprint.state == RUN:
             self.signal_consumer_close()
             self.signal_consumer_close()
             if not in_sighandler or self.pool.signal_safe:
             if not in_sighandler or self.pool.signal_safe:

+ 8 - 4
celery/worker/loops.py

@@ -57,10 +57,14 @@ def asynloop(obj, connection, consumer, blueprint, hub, qos,
     try:
     try:
         while blueprint.state == RUN and obj.connection:
         while blueprint.state == RUN and obj.connection:
             # shutdown if signal handlers told us to.
             # shutdown if signal handlers told us to.
-            if state.should_stop:
-                raise WorkerShutdown()
-            elif state.should_terminate:
-                raise WorkerTerminate()
+            should_stop, should_terminate = (
+                state.should_stop, state.should_terminate,
+            )
+            # False == EX_OK, so must use is not False
+            if should_stop is not None and should_stop is not False:
+                raise WorkerShutdown(should_stop)
+            elif should_terminate is not None and should_stop is not False:
+                raise WorkerTerminate(should_terminate)
 
 
             # We only update QoS when there is no more messages to read.
             # We only update QoS when there is no more messages to read.
             # This groups together qos calls, and makes sure that remote
             # This groups together qos calls, and makes sure that remote

+ 6 - 6
celery/worker/state.py

@@ -60,15 +60,15 @@ revoked = LimitedSet(maxlen=REVOKES_MAX, expires=REVOKE_EXPIRES)
 #: Update global state when a task has been reserved.
 #: Update global state when a task has been reserved.
 task_reserved = reserved_requests.add
 task_reserved = reserved_requests.add
 
 
-should_stop = False
-should_terminate = False
+should_stop = None
+should_terminate = None
 
 
 
 
 def maybe_shutdown():
 def maybe_shutdown():
-    if should_stop:
-        raise WorkerShutdown()
-    elif should_terminate:
-        raise WorkerTerminate()
+    if should_stop is not None and should_stop is not False:
+        raise WorkerShutdown(should_stop)
+    elif should_terminate is not None and should_terminate is not False:
+        raise WorkerTerminate(should_terminate)
 
 
 
 
 def task_accepted(request, _all_total_count=all_total_count):
 def task_accepted(request, _all_total_count=all_total_count):