Browse Source

Much of the WorkController code is now in bootsteps.Namespace

Ask Solem 12 years ago
parent
commit
96479fc168

+ 2 - 0
celery/apps/worker.py

@@ -102,6 +102,8 @@ class Worker(WorkController):
         self.redirect_stdouts_level = redirect_stdouts_level
 
     def on_start(self):
+        WorkController.on_start(self)
+
         # apply task execution optimizations
         trace.setup_worker_optimizations(self.app)
 

+ 16 - 15
celery/tests/worker/test_worker.py

@@ -24,10 +24,11 @@ from celery.task import periodic_task as periodic_task_dec
 from celery.utils import uuid
 from celery.worker import WorkController
 from celery.worker.components import Queues, Timers, EvLoop, Pool
+from celery.worker.bootsteps import RUN, CLOSE, TERMINATE
 from celery.worker.buckets import FastQueue
 from celery.worker.job import Request
 from celery.worker.consumer import BlockingConsumer
-from celery.worker.consumer import QoS, RUN, PREFETCH_COUNT_MAX, CLOSE
+from celery.worker.consumer import QoS, PREFETCH_COUNT_MAX
 from celery.utils.serialization import pickle
 from celery.utils.timer2 import Timer
 
@@ -770,7 +771,7 @@ class test_WorkController(AppCase):
 
     def create_worker(self, **kw):
         worker = self.app.WorkController(concurrency=1, loglevel=0, **kw)
-        worker._shutdown_complete.set()
+        worker.namespace.shutdown_complete.set()
         return worker
 
     @patch('celery.platforms.create_pidlock')
@@ -839,17 +840,17 @@ class test_WorkController(AppCase):
     def test_dont_stop_or_terminate(self):
         worker = WorkController(concurrency=1, loglevel=0)
         worker.stop()
-        self.assertNotEqual(worker._state, worker.CLOSE)
+        self.assertNotEqual(worker.namespace.state, CLOSE)
         worker.terminate()
-        self.assertNotEqual(worker._state, worker.CLOSE)
+        self.assertNotEqual(worker.namespace.state, CLOSE)
 
         sigsafe, worker.pool.signal_safe = worker.pool.signal_safe, False
         try:
-            worker._state = worker.RUN
+            worker.namespace.state = RUN
             worker.stop(in_sighandler=True)
-            self.assertNotEqual(worker._state, worker.CLOSE)
+            self.assertNotEqual(worker.namespace.state, CLOSE)
             worker.terminate(in_sighandler=True)
-            self.assertNotEqual(worker._state, worker.CLOSE)
+            self.assertNotEqual(worker.namespace.state, CLOSE)
         finally:
             worker.pool.signal_safe = sigsafe
 
@@ -892,10 +893,10 @@ class test_WorkController(AppCase):
                            kwargs={})
         task = Request.from_message(m, m.decode())
         worker.components = []
-        worker._state = worker.RUN
+        worker.namespace.state = RUN
         with self.assertRaises(KeyboardInterrupt):
             worker.process_task(task)
-        self.assertEqual(worker._state, worker.TERMINATE)
+        self.assertEqual(worker.namespace.state, TERMINATE)
 
     def test_process_task_raise_SystemTerminate(self):
         worker = self.worker
@@ -906,10 +907,10 @@ class test_WorkController(AppCase):
                            kwargs={})
         task = Request.from_message(m, m.decode())
         worker.components = []
-        worker._state = worker.RUN
+        worker.namespace.state = RUN
         with self.assertRaises(SystemExit):
             worker.process_task(task)
-        self.assertEqual(worker._state, worker.TERMINATE)
+        self.assertEqual(worker.namespace.state, TERMINATE)
 
     def test_process_task_raise_regular(self):
         worker = self.worker
@@ -986,7 +987,7 @@ class test_WorkController(AppCase):
 
     def test_start__stop(self):
         worker = self.worker
-        worker._shutdown_complete.set()
+        worker.namespace.shutdown_complete.set()
         worker.components = [Mock(), Mock(), Mock(), Mock()]
 
         worker.start()
@@ -1020,7 +1021,7 @@ class test_WorkController(AppCase):
 
     def test_start__terminate(self):
         worker = self.worker
-        worker._shutdown_complete.set()
+        worker.namespace.shutdown_complete.set()
         worker.components = [Mock(), Mock(), Mock(), Mock(), Mock()]
         for component in worker.components[:3]:
             component.terminate = None
@@ -1028,8 +1029,8 @@ class test_WorkController(AppCase):
         worker.start()
         for w in worker.components[:3]:
             self.assertTrue(w.start.call_count)
-        self.assertTrue(worker._running, len(worker.components))
-        self.assertEqual(worker._state, RUN)
+        self.assertTrue(worker.namespace.started, len(worker.components))
+        self.assertEqual(worker.namespace.state, RUN)
         worker.terminate()
         for component in worker.components[:3]:
             self.assertTrue(component.stop.call_count)

+ 35 - 89
celery/worker/__init__.py

@@ -15,8 +15,6 @@ import socket
 import sys
 import traceback
 
-from threading import Event
-
 from billiard import cpu_count
 from kombu.syn import detect_environment
 from kombu.utils.finalize import Finalize
@@ -30,23 +28,12 @@ from celery.exceptions import (
     ImproperlyConfigured, SystemTerminate, TaskRevokedError,
 )
 from celery.utils import worker_direct
-from celery.utils.imports import qualname, reload_from_cwd
+from celery.utils.imports import reload_from_cwd
 from celery.utils.log import mlevel, worker_logger as logger
 
 from . import bootsteps
 from . import state
 
-try:
-    from greenlet import GreenletExit
-    IGNORE_ERRORS = (GreenletExit, )
-except ImportError:  # pragma: no cover
-    IGNORE_ERRORS = ()
-
-#: Worker states
-RUN = 0x1
-CLOSE = 0x2
-TERMINATE = 0x3
-
 UNKNOWN_QUEUE = """\
 Trying to select queue subset of {0!r}, but queue {1} is not
 defined in the CELERY_QUEUES setting.
@@ -55,9 +42,6 @@ If you want to automatically declare unknown queues you can
 enable the CELERY_CREATE_MISSING_QUEUES setting.
 """
 
-#: Default socket timeout at shutdown.
-SHUTDOWN_SOCKET_TIMEOUT = 5.0
-
 
 class Namespace(bootsteps.Namespace):
     """This is the boot-step namespace of the :class:`WorkController`.
@@ -79,10 +63,6 @@ class Namespace(bootsteps.Namespace):
 
 class WorkController(configurated):
     """Unmanaged worker instance."""
-    RUN = RUN
-    CLOSE = CLOSE
-    TERMINATE = TERMINATE
-
     app = None
     concurrency = from_config()
     loglevel = from_config('log_level')
@@ -108,8 +88,6 @@ class WorkController(configurated):
     disable_rate_limits = from_config()
     worker_lost_wait = from_config()
 
-    _state = None
-    _running = 0
     pidlock = None
 
     def __init__(self, app=None, hostname=None, **kwargs):
@@ -118,18 +96,8 @@ class WorkController(configurated):
         self.on_before_init(**kwargs)
 
         self._finalize = Finalize(self, self.stop, exitpriority=1)
-        self._shutdown_complete = Event()
         self.setup_instance(**self.prepare_args(**kwargs))
 
-    def on_before_init(self, **kwargs):
-        pass
-
-    def on_start(self):
-        pass
-
-    def on_consumer_ready(self, consumer):
-        pass
-
     def setup_instance(self, queues=None, ready_callback=None,
             pidfile=None, include=None, **kwargs):
         self.pidfile = pidfile
@@ -155,7 +123,32 @@ class WorkController(configurated):
         # Initialize boot steps
         self.pool_cls = _concurrency.get_implementation(self.pool_cls)
         self.components = []
-        self.namespace = Namespace(app=self.app).apply(self, **kwargs)
+        self.namespace = Namespace(app=self.app,
+                                   on_start=self.on_start,
+                                   on_close=self.on_close,
+                                   on_stopped=self.on_stopped)
+        self.namespace.apply(self, **kwargs)
+
+
+    def on_before_init(self, **kwargs):
+        pass
+
+    def on_start(self):
+        if self.pidfile:
+            self.pidlock = platforms.create_pidlock(self.pidfile)
+
+    def on_consumer_ready(self, consumer):
+        pass
+
+    def on_close(self):
+        self.app.loader.shutdown_worker()
+
+    def on_stopped(self):
+        self.timer.stop()
+        self.consumer.close_connection()
+
+        if self.pidlock:
+            self.pidlock.release()
 
     def setup_queues(self, queues):
         if isinstance(queues, basestring):
@@ -187,34 +180,16 @@ class WorkController(configurated):
 
     def start(self):
         """Starts the workers main loop."""
-        self.on_start()
-        self._state = self.RUN
-        if self.pidfile:
-            self.pidlock = platforms.create_pidlock(self.pidfile)
         try:
-            for i, component in enumerate(self.components):
-                logger.debug('Starting %s...', qualname(component))
-                self._running = i + 1
-                if component:
-                    component.start()
-                logger.debug('%s OK!', qualname(component))
+            self.namespace.start(self)
         except SystemTerminate:
             self.terminate()
         except Exception as exc:
-            logger.error('Unrecoverable error: %r', exc,
-                         exc_info=True)
+            logger.error('Unrecoverable error: %r', exc, exc_info=True)
             self.stop()
         except (KeyboardInterrupt, SystemExit):
             self.stop()
 
-        try:
-            # Will only get here if running green,
-            # makes sure all greenthreads have exited.
-            self._shutdown_complete.wait()
-        except IGNORE_ERRORS:
-            pass
-    run = start   # XXX Compat
-
     def process_task_sem(self, req):
         return self._quick_acquire(self.process_task, req)
 
@@ -260,41 +235,8 @@ class WorkController(configurated):
             self._shutdown(warm=False)
 
     def _shutdown(self, warm=True):
-        what = 'Stopping' if warm else 'Terminating'
-        socket_timeout = socket.getdefaulttimeout()
-        socket.setdefaulttimeout(SHUTDOWN_SOCKET_TIMEOUT)  # Issue 975
-
-        if self._state in (self.CLOSE, self.TERMINATE):
-            return
-
-        self.app.loader.shutdown_worker()
-
-        if self.pool:
-            self.pool.close()
-
-        if self._state != self.RUN or self._running != len(self.components):
-            # Not fully started, can safely exit.
-            self._state = self.TERMINATE
-            self._shutdown_complete.set()
-            return
-        self._state = self.CLOSE
-
-        for component in reversed(self.components):
-            logger.debug('%s %s...', what, qualname(component))
-            if component:
-                stop = component.stop
-                if not warm:
-                    stop = getattr(component, 'terminate', None) or stop
-                stop()
-
-        self.timer.stop()
-        self.consumer.close_connection()
-
-        if self.pidlock:
-            self.pidlock.release()
-        self._state = self.TERMINATE
-        socket.setdefaulttimeout(socket_timeout)
-        self._shutdown_complete.set()
+        self.namespace.stop(self, terminate=not warm)
+        self.namespace.join()
 
     def reload(self, modules=None, reload=False, reloader=None):
         modules = self.app.loader.task_modules if modules is None else modules
@@ -309,6 +251,10 @@ class WorkController(configurated):
                 reload_from_cwd(sys.modules[module], reloader)
         self.pool.restart()
 
+    @property
+    def _state(self):
+        return self.namespace.state
+
     @property
     def state(self):
         return state

+ 89 - 3
celery/worker/bootsteps.py

@@ -8,13 +8,30 @@
 """
 from __future__ import absolute_import
 
+import socket
+
 from collections import defaultdict
 from importlib import import_module
+from threading import Event
 
 from celery.datastructures import DependencyGraph
-from celery.utils.imports import instantiate
+from celery.utils.imports import instantiate, qualname
 from celery.utils.log import get_logger
 
+try:
+    from greenlet import GreenletExit
+    IGNORE_ERRORS = (GreenletExit, )
+except ImportError:  # pragma: no cover
+    IGNORE_ERRORS = ()
+
+#: Default socket timeout at shutdown.
+SHUTDOWN_SOCKET_TIMEOUT = 5.0
+
+#: States
+RUN = 0x1
+CLOSE = 0x2
+TERMINATE = 0x3
+
 logger = get_logger(__name__)
 
 
@@ -32,13 +49,79 @@ class Namespace(object):
 
     """
     name = None
+    state = None
+    started = 0
+
     _unclaimed = defaultdict(dict)
-    _started_count = 0
 
-    def __init__(self, name=None, app=None):
+    def __init__(self, name=None, app=None, on_start=None,
+            on_close=None, on_stopped=None):
         self.app = app
         self.name = name or self.name
+        self.on_start = on_start
+        self.on_close = on_close
+        self.on_stopped = on_stopped
         self.services = []
+        self.shutdown_complete = Event()
+
+    def start(self, parent):
+        self.state = RUN
+        if self.on_start:
+            self.on_start()
+        for i, component in enumerate(parent.components):
+            if component:
+                logger.debug('Starting %s...', qualname(component))
+                self.started = i + 1
+                component.start()
+                logger.debug('%s OK!', qualname(component))
+
+    def close(self, parent):
+        if self.on_close:
+            self.on_close()
+        for component in parent.components:
+            try:
+                close = component.close
+            except AttributeError:
+                pass
+            else:
+                close()
+
+    def stop(self, parent, terminate=False):
+        what = 'Terminating' if terminate else 'Stopping'
+        socket_timeout = socket.getdefaulttimeout()
+        socket.setdefaulttimeout(SHUTDOWN_SOCKET_TIMEOUT)  # Issue 975
+
+        if self.state in (CLOSE, TERMINATE):
+            return
+
+        if self.state != RUN or self.started != len(parent.components):
+            # Not fully started, can safely exit.
+            self.state = TERMINATE
+            self.shutdown_complete.set()
+            return
+
+        self.state = CLOSE
+        for component in reversed(parent.components):
+            if component:
+                logger.debug('%s %s...', what, qualname(component))
+                stop = component.stop
+                if terminate:
+                    stop = getattr(component, 'terminate', None) or stop
+                stop()
+
+        if self.on_stopped:
+            self.on_stopped()
+        self.state = TERMINATE
+        socket.setdefaulttimeout(socket_timeout)
+        self.shutdown_complete.set()
+
+    def join(self, timeout=None):
+        try:
+            # Will only get here if running green,
+            # makes sure all greenthreads have exited.
+            self.shutdown_complete.wait(timeout=timeout)
+        except IGNORE_ERRORS:
+            pass
 
     def modules(self):
         """Subclasses can override this to return a
@@ -200,6 +283,9 @@ class StartStopComponent(Component):
     def stop(self):
         return self.obj.stop()
 
+    def close(self):
+        pass
+
     def terminate(self):
         if self.terminable:
             return self.obj.terminate()

+ 4 - 0
celery/worker/components.py

@@ -54,6 +54,10 @@ class Pool(bootsteps.StartStopComponent):
             w.max_concurrency, w.min_concurrency = w.autoscale
         self.autoreload_enabled = autoreload
 
+    def close(self, w):
+        if w.pool:
+            w.pool.close()
+
     def on_poll_init(self, pool, hub):
         apply_after = hub.timer.apply_after
         apply_at = hub.timer.apply_at

+ 1 - 4
celery/worker/consumer.py

@@ -94,13 +94,10 @@ from celery.utils.log import get_logger
 from celery.utils.timeutils import humanize_seconds
 
 from . import state
-from .bootsteps import StartStopComponent
+from .bootsteps import StartStopComponent, RUN, CLOSE
 from .control import Panel
 from .heartbeat import Heart
 
-RUN = 0x1
-CLOSE = 0x2
-
 #: Heartbeat check is called every heartbeat_seconds' / rate'.
 AMQHEARTBEAT_RATE = 2.0