Browse Source

Much of the WorkController code is now in bootsteps.Namespace

Ask Solem 12 years ago
parent
commit
96479fc168

+ 2 - 0
celery/apps/worker.py

@@ -102,6 +102,8 @@ class Worker(WorkController):
         self.redirect_stdouts_level = redirect_stdouts_level
         self.redirect_stdouts_level = redirect_stdouts_level
 
 
     def on_start(self):
     def on_start(self):
+        WorkController.on_start(self)
+
         # apply task execution optimizations
         # apply task execution optimizations
         trace.setup_worker_optimizations(self.app)
         trace.setup_worker_optimizations(self.app)
 
 

+ 16 - 15
celery/tests/worker/test_worker.py

@@ -24,10 +24,11 @@ from celery.task import periodic_task as periodic_task_dec
 from celery.utils import uuid
 from celery.utils import uuid
 from celery.worker import WorkController
 from celery.worker import WorkController
 from celery.worker.components import Queues, Timers, EvLoop, Pool
 from celery.worker.components import Queues, Timers, EvLoop, Pool
+from celery.worker.bootsteps import RUN, CLOSE, TERMINATE
 from celery.worker.buckets import FastQueue
 from celery.worker.buckets import FastQueue
 from celery.worker.job import Request
 from celery.worker.job import Request
 from celery.worker.consumer import BlockingConsumer
 from celery.worker.consumer import BlockingConsumer
-from celery.worker.consumer import QoS, RUN, PREFETCH_COUNT_MAX, CLOSE
+from celery.worker.consumer import QoS, PREFETCH_COUNT_MAX
 from celery.utils.serialization import pickle
 from celery.utils.serialization import pickle
 from celery.utils.timer2 import Timer
 from celery.utils.timer2 import Timer
 
 
@@ -770,7 +771,7 @@ class test_WorkController(AppCase):
 
 
     def create_worker(self, **kw):
     def create_worker(self, **kw):
         worker = self.app.WorkController(concurrency=1, loglevel=0, **kw)
         worker = self.app.WorkController(concurrency=1, loglevel=0, **kw)
-        worker._shutdown_complete.set()
+        worker.namespace.shutdown_complete.set()
         return worker
         return worker
 
 
     @patch('celery.platforms.create_pidlock')
     @patch('celery.platforms.create_pidlock')
@@ -839,17 +840,17 @@ class test_WorkController(AppCase):
     def test_dont_stop_or_terminate(self):
     def test_dont_stop_or_terminate(self):
         worker = WorkController(concurrency=1, loglevel=0)
         worker = WorkController(concurrency=1, loglevel=0)
         worker.stop()
         worker.stop()
-        self.assertNotEqual(worker._state, worker.CLOSE)
+        self.assertNotEqual(worker.namespace.state, CLOSE)
         worker.terminate()
         worker.terminate()
-        self.assertNotEqual(worker._state, worker.CLOSE)
+        self.assertNotEqual(worker.namespace.state, CLOSE)
 
 
         sigsafe, worker.pool.signal_safe = worker.pool.signal_safe, False
         sigsafe, worker.pool.signal_safe = worker.pool.signal_safe, False
         try:
         try:
-            worker._state = worker.RUN
+            worker.namespace.state = RUN
             worker.stop(in_sighandler=True)
             worker.stop(in_sighandler=True)
-            self.assertNotEqual(worker._state, worker.CLOSE)
+            self.assertNotEqual(worker.namespace.state, CLOSE)
             worker.terminate(in_sighandler=True)
             worker.terminate(in_sighandler=True)
-            self.assertNotEqual(worker._state, worker.CLOSE)
+            self.assertNotEqual(worker.namespace.state, CLOSE)
         finally:
         finally:
             worker.pool.signal_safe = sigsafe
             worker.pool.signal_safe = sigsafe
 
 
@@ -892,10 +893,10 @@ class test_WorkController(AppCase):
                            kwargs={})
                            kwargs={})
         task = Request.from_message(m, m.decode())
         task = Request.from_message(m, m.decode())
         worker.components = []
         worker.components = []
-        worker._state = worker.RUN
+        worker.namespace.state = RUN
         with self.assertRaises(KeyboardInterrupt):
         with self.assertRaises(KeyboardInterrupt):
             worker.process_task(task)
             worker.process_task(task)
-        self.assertEqual(worker._state, worker.TERMINATE)
+        self.assertEqual(worker.namespace.state, TERMINATE)
 
 
     def test_process_task_raise_SystemTerminate(self):
     def test_process_task_raise_SystemTerminate(self):
         worker = self.worker
         worker = self.worker
@@ -906,10 +907,10 @@ class test_WorkController(AppCase):
                            kwargs={})
                            kwargs={})
         task = Request.from_message(m, m.decode())
         task = Request.from_message(m, m.decode())
         worker.components = []
         worker.components = []
-        worker._state = worker.RUN
+        worker.namespace.state = RUN
         with self.assertRaises(SystemExit):
         with self.assertRaises(SystemExit):
             worker.process_task(task)
             worker.process_task(task)
-        self.assertEqual(worker._state, worker.TERMINATE)
+        self.assertEqual(worker.namespace.state, TERMINATE)
 
 
     def test_process_task_raise_regular(self):
     def test_process_task_raise_regular(self):
         worker = self.worker
         worker = self.worker
@@ -986,7 +987,7 @@ class test_WorkController(AppCase):
 
 
     def test_start__stop(self):
     def test_start__stop(self):
         worker = self.worker
         worker = self.worker
-        worker._shutdown_complete.set()
+        worker.namespace.shutdown_complete.set()
         worker.components = [Mock(), Mock(), Mock(), Mock()]
         worker.components = [Mock(), Mock(), Mock(), Mock()]
 
 
         worker.start()
         worker.start()
@@ -1020,7 +1021,7 @@ class test_WorkController(AppCase):
 
 
     def test_start__terminate(self):
     def test_start__terminate(self):
         worker = self.worker
         worker = self.worker
-        worker._shutdown_complete.set()
+        worker.namespace.shutdown_complete.set()
         worker.components = [Mock(), Mock(), Mock(), Mock(), Mock()]
         worker.components = [Mock(), Mock(), Mock(), Mock(), Mock()]
         for component in worker.components[:3]:
         for component in worker.components[:3]:
             component.terminate = None
             component.terminate = None
@@ -1028,8 +1029,8 @@ class test_WorkController(AppCase):
         worker.start()
         worker.start()
         for w in worker.components[:3]:
         for w in worker.components[:3]:
             self.assertTrue(w.start.call_count)
             self.assertTrue(w.start.call_count)
-        self.assertTrue(worker._running, len(worker.components))
-        self.assertEqual(worker._state, RUN)
+        self.assertTrue(worker.namespace.started, len(worker.components))
+        self.assertEqual(worker.namespace.state, RUN)
         worker.terminate()
         worker.terminate()
         for component in worker.components[:3]:
         for component in worker.components[:3]:
             self.assertTrue(component.stop.call_count)
             self.assertTrue(component.stop.call_count)

+ 35 - 89
celery/worker/__init__.py

@@ -15,8 +15,6 @@ import socket
 import sys
 import sys
 import traceback
 import traceback
 
 
-from threading import Event
-
 from billiard import cpu_count
 from billiard import cpu_count
 from kombu.syn import detect_environment
 from kombu.syn import detect_environment
 from kombu.utils.finalize import Finalize
 from kombu.utils.finalize import Finalize
@@ -30,23 +28,12 @@ from celery.exceptions import (
     ImproperlyConfigured, SystemTerminate, TaskRevokedError,
     ImproperlyConfigured, SystemTerminate, TaskRevokedError,
 )
 )
 from celery.utils import worker_direct
 from celery.utils import worker_direct
-from celery.utils.imports import qualname, reload_from_cwd
+from celery.utils.imports import reload_from_cwd
 from celery.utils.log import mlevel, worker_logger as logger
 from celery.utils.log import mlevel, worker_logger as logger
 
 
 from . import bootsteps
 from . import bootsteps
 from . import state
 from . import state
 
 
-try:
-    from greenlet import GreenletExit
-    IGNORE_ERRORS = (GreenletExit, )
-except ImportError:  # pragma: no cover
-    IGNORE_ERRORS = ()
-
-#: Worker states
-RUN = 0x1
-CLOSE = 0x2
-TERMINATE = 0x3
-
 UNKNOWN_QUEUE = """\
 UNKNOWN_QUEUE = """\
 Trying to select queue subset of {0!r}, but queue {1} is not
 Trying to select queue subset of {0!r}, but queue {1} is not
 defined in the CELERY_QUEUES setting.
 defined in the CELERY_QUEUES setting.
@@ -55,9 +42,6 @@ If you want to automatically declare unknown queues you can
 enable the CELERY_CREATE_MISSING_QUEUES setting.
 enable the CELERY_CREATE_MISSING_QUEUES setting.
 """
 """
 
 
-#: Default socket timeout at shutdown.
-SHUTDOWN_SOCKET_TIMEOUT = 5.0
-
 
 
 class Namespace(bootsteps.Namespace):
 class Namespace(bootsteps.Namespace):
     """This is the boot-step namespace of the :class:`WorkController`.
     """This is the boot-step namespace of the :class:`WorkController`.
@@ -79,10 +63,6 @@ class Namespace(bootsteps.Namespace):
 
 
 class WorkController(configurated):
 class WorkController(configurated):
     """Unmanaged worker instance."""
     """Unmanaged worker instance."""
-    RUN = RUN
-    CLOSE = CLOSE
-    TERMINATE = TERMINATE
-
     app = None
     app = None
     concurrency = from_config()
     concurrency = from_config()
     loglevel = from_config('log_level')
     loglevel = from_config('log_level')
@@ -108,8 +88,6 @@ class WorkController(configurated):
     disable_rate_limits = from_config()
     disable_rate_limits = from_config()
     worker_lost_wait = from_config()
     worker_lost_wait = from_config()
 
 
-    _state = None
-    _running = 0
     pidlock = None
     pidlock = None
 
 
     def __init__(self, app=None, hostname=None, **kwargs):
     def __init__(self, app=None, hostname=None, **kwargs):
@@ -118,18 +96,8 @@ class WorkController(configurated):
         self.on_before_init(**kwargs)
         self.on_before_init(**kwargs)
 
 
         self._finalize = Finalize(self, self.stop, exitpriority=1)
         self._finalize = Finalize(self, self.stop, exitpriority=1)
-        self._shutdown_complete = Event()
         self.setup_instance(**self.prepare_args(**kwargs))
         self.setup_instance(**self.prepare_args(**kwargs))
 
 
-    def on_before_init(self, **kwargs):
-        pass
-
-    def on_start(self):
-        pass
-
-    def on_consumer_ready(self, consumer):
-        pass
-
     def setup_instance(self, queues=None, ready_callback=None,
     def setup_instance(self, queues=None, ready_callback=None,
             pidfile=None, include=None, **kwargs):
             pidfile=None, include=None, **kwargs):
         self.pidfile = pidfile
         self.pidfile = pidfile
@@ -155,7 +123,32 @@ class WorkController(configurated):
         # Initialize boot steps
         # Initialize boot steps
         self.pool_cls = _concurrency.get_implementation(self.pool_cls)
         self.pool_cls = _concurrency.get_implementation(self.pool_cls)
         self.components = []
         self.components = []
-        self.namespace = Namespace(app=self.app).apply(self, **kwargs)
+        self.namespace = Namespace(app=self.app,
+                                   on_start=self.on_start,
+                                   on_close=self.on_close,
+                                   on_stopped=self.on_stopped)
+        self.namespace.apply(self, **kwargs)
+
+
+    def on_before_init(self, **kwargs):
+        pass
+
+    def on_start(self):
+        if self.pidfile:
+            self.pidlock = platforms.create_pidlock(self.pidfile)
+
+    def on_consumer_ready(self, consumer):
+        pass
+
+    def on_close(self):
+        self.app.loader.shutdown_worker()
+
+    def on_stopped(self):
+        self.timer.stop()
+        self.consumer.close_connection()
+
+        if self.pidlock:
+            self.pidlock.release()
 
 
     def setup_queues(self, queues):
     def setup_queues(self, queues):
         if isinstance(queues, basestring):
         if isinstance(queues, basestring):
@@ -187,34 +180,16 @@ class WorkController(configurated):
 
 
     def start(self):
     def start(self):
         """Starts the workers main loop."""
         """Starts the workers main loop."""
-        self.on_start()
-        self._state = self.RUN
-        if self.pidfile:
-            self.pidlock = platforms.create_pidlock(self.pidfile)
         try:
         try:
-            for i, component in enumerate(self.components):
-                logger.debug('Starting %s...', qualname(component))
-                self._running = i + 1
-                if component:
-                    component.start()
-                logger.debug('%s OK!', qualname(component))
+            self.namespace.start(self)
         except SystemTerminate:
         except SystemTerminate:
             self.terminate()
             self.terminate()
         except Exception as exc:
         except Exception as exc:
-            logger.error('Unrecoverable error: %r', exc,
-                         exc_info=True)
+            logger.error('Unrecoverable error: %r', exc, exc_info=True)
             self.stop()
             self.stop()
         except (KeyboardInterrupt, SystemExit):
         except (KeyboardInterrupt, SystemExit):
             self.stop()
             self.stop()
 
 
-        try:
-            # Will only get here if running green,
-            # makes sure all greenthreads have exited.
-            self._shutdown_complete.wait()
-        except IGNORE_ERRORS:
-            pass
-    run = start   # XXX Compat
-
     def process_task_sem(self, req):
     def process_task_sem(self, req):
         return self._quick_acquire(self.process_task, req)
         return self._quick_acquire(self.process_task, req)
 
 
@@ -260,41 +235,8 @@ class WorkController(configurated):
             self._shutdown(warm=False)
             self._shutdown(warm=False)
 
 
     def _shutdown(self, warm=True):
     def _shutdown(self, warm=True):
-        what = 'Stopping' if warm else 'Terminating'
-        socket_timeout = socket.getdefaulttimeout()
-        socket.setdefaulttimeout(SHUTDOWN_SOCKET_TIMEOUT)  # Issue 975
-
-        if self._state in (self.CLOSE, self.TERMINATE):
-            return
-
-        self.app.loader.shutdown_worker()
-
-        if self.pool:
-            self.pool.close()
-
-        if self._state != self.RUN or self._running != len(self.components):
-            # Not fully started, can safely exit.
-            self._state = self.TERMINATE
-            self._shutdown_complete.set()
-            return
-        self._state = self.CLOSE
-
-        for component in reversed(self.components):
-            logger.debug('%s %s...', what, qualname(component))
-            if component:
-                stop = component.stop
-                if not warm:
-                    stop = getattr(component, 'terminate', None) or stop
-                stop()
-
-        self.timer.stop()
-        self.consumer.close_connection()
-
-        if self.pidlock:
-            self.pidlock.release()
-        self._state = self.TERMINATE
-        socket.setdefaulttimeout(socket_timeout)
-        self._shutdown_complete.set()
+        self.namespace.stop(self, terminate=not warm)
+        self.namespace.join()
 
 
     def reload(self, modules=None, reload=False, reloader=None):
     def reload(self, modules=None, reload=False, reloader=None):
         modules = self.app.loader.task_modules if modules is None else modules
         modules = self.app.loader.task_modules if modules is None else modules
@@ -309,6 +251,10 @@ class WorkController(configurated):
                 reload_from_cwd(sys.modules[module], reloader)
                 reload_from_cwd(sys.modules[module], reloader)
         self.pool.restart()
         self.pool.restart()
 
 
+    @property
+    def _state(self):
+        return self.namespace.state
+
     @property
     @property
     def state(self):
     def state(self):
         return state
         return state

+ 89 - 3
celery/worker/bootsteps.py

@@ -8,13 +8,30 @@
 """
 """
 from __future__ import absolute_import
 from __future__ import absolute_import
 
 
+import socket
+
 from collections import defaultdict
 from collections import defaultdict
 from importlib import import_module
 from importlib import import_module
+from threading import Event
 
 
 from celery.datastructures import DependencyGraph
 from celery.datastructures import DependencyGraph
-from celery.utils.imports import instantiate
+from celery.utils.imports import instantiate, qualname
 from celery.utils.log import get_logger
 from celery.utils.log import get_logger
 
 
+try:
+    from greenlet import GreenletExit
+    IGNORE_ERRORS = (GreenletExit, )
+except ImportError:  # pragma: no cover
+    IGNORE_ERRORS = ()
+
+#: Default socket timeout at shutdown.
+SHUTDOWN_SOCKET_TIMEOUT = 5.0
+
+#: States
+RUN = 0x1
+CLOSE = 0x2
+TERMINATE = 0x3
+
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 
 
 
@@ -32,13 +49,79 @@ class Namespace(object):
 
 
     """
     """
     name = None
     name = None
+    state = None
+    started = 0
+
     _unclaimed = defaultdict(dict)
     _unclaimed = defaultdict(dict)
-    _started_count = 0
 
 
-    def __init__(self, name=None, app=None):
+    def __init__(self, name=None, app=None, on_start=None,
+            on_close=None, on_stopped=None):
         self.app = app
         self.app = app
         self.name = name or self.name
         self.name = name or self.name
+        self.on_start = on_start
+        self.on_close = on_close
+        self.on_stopped = on_stopped
         self.services = []
         self.services = []
+        self.shutdown_complete = Event()
+
+    def start(self, parent):
+        self.state = RUN
+        if self.on_start:
+            self.on_start()
+        for i, component in enumerate(parent.components):
+            if component:
+                logger.debug('Starting %s...', qualname(component))
+                self.started = i + 1
+                component.start()
+                logger.debug('%s OK!', qualname(component))
+
+    def close(self, parent):
+        if self.on_close:
+            self.on_close()
+        for component in parent.components:
+            try:
+                close = component.close
+            except AttributeError:
+                pass
+            else:
+                close()
+
+    def stop(self, parent, terminate=False):
+        what = 'Terminating' if terminate else 'Stopping'
+        socket_timeout = socket.getdefaulttimeout()
+        socket.setdefaulttimeout(SHUTDOWN_SOCKET_TIMEOUT)  # Issue 975
+
+        if self.state in (CLOSE, TERMINATE):
+            return
+
+        if self.state != RUN or self.started != len(parent.components):
+            # Not fully started, can safely exit.
+            self.state = TERMINATE
+            self.shutdown_complete.set()
+            return
+
+        self.state = CLOSE
+        for component in reversed(parent.components):
+            if component:
+                logger.debug('%s %s...', what, qualname(component))
+                stop = component.stop
+                if terminate:
+                    stop = getattr(component, 'terminate', None) or stop
+                stop()
+
+        if self.on_stopped:
+            self.on_stopped()
+        self.state = TERMINATE
+        socket.setdefaulttimeout(socket_timeout)
+        self.shutdown_complete.set()
+
+    def join(self, timeout=None):
+        try:
+            # Will only get here if running green,
+            # makes sure all greenthreads have exited.
+            self.shutdown_complete.wait(timeout=timeout)
+        except IGNORE_ERRORS:
+            pass
 
 
     def modules(self):
     def modules(self):
         """Subclasses can override this to return a
         """Subclasses can override this to return a
@@ -200,6 +283,9 @@ class StartStopComponent(Component):
     def stop(self):
     def stop(self):
         return self.obj.stop()
         return self.obj.stop()
 
 
+    def close(self):
+        pass
+
     def terminate(self):
     def terminate(self):
         if self.terminable:
         if self.terminable:
             return self.obj.terminate()
             return self.obj.terminate()

+ 4 - 0
celery/worker/components.py

@@ -54,6 +54,10 @@ class Pool(bootsteps.StartStopComponent):
             w.max_concurrency, w.min_concurrency = w.autoscale
             w.max_concurrency, w.min_concurrency = w.autoscale
         self.autoreload_enabled = autoreload
         self.autoreload_enabled = autoreload
 
 
+    def close(self, w):
+        if w.pool:
+            w.pool.close()
+
     def on_poll_init(self, pool, hub):
     def on_poll_init(self, pool, hub):
         apply_after = hub.timer.apply_after
         apply_after = hub.timer.apply_after
         apply_at = hub.timer.apply_at
         apply_at = hub.timer.apply_at

+ 1 - 4
celery/worker/consumer.py

@@ -94,13 +94,10 @@ from celery.utils.log import get_logger
 from celery.utils.timeutils import humanize_seconds
 from celery.utils.timeutils import humanize_seconds
 
 
 from . import state
 from . import state
-from .bootsteps import StartStopComponent
+from .bootsteps import StartStopComponent, RUN, CLOSE
 from .control import Panel
 from .control import Panel
 from .heartbeat import Heart
 from .heartbeat import Heart
 
 
-RUN = 0x1
-CLOSE = 0x2
-
 #: Heartbeat check is called every heartbeat_seconds' / rate'.
 #: Heartbeat check is called every heartbeat_seconds' / rate'.
 AMQHEARTBEAT_RATE = 2.0
 AMQHEARTBEAT_RATE = 2.0